Tpetra parallel linear algebra  Version of the Day
Tpetra_Details_mklGemm.hpp
Go to the documentation of this file.
1 /*
2 //@HEADER
3 // ************************************************************************
4 //
5 // Kokkos: Node API and Parallel Node Kernels
6 // Copyright (2008) Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ************************************************************************
41 //@HEADER
42 */
43 
44 #ifndef TPETRA_DETAILS_MKLGEMM_HPP
45 #define TPETRA_DETAILS_MKLGEMM_HPP
46 
56 
#include "TpetraCore_config.h"
#include "Tpetra_Details_Blas.hpp"

#include <cstddef>
#include <limits>
#include <stdexcept>
#include <type_traits>
60 
61 namespace Tpetra {
62 namespace Details {
63 namespace Blas {
64 namespace Mkl {
65 
/// \brief For this set of template parameters, can and should we
///   implement Gemm (see below) using the MKL?
///
/// \c value is \c true only if all of the following hold:
///   - the MKL is enabled (\c HAVE_KOKKOSKERNELS_MKL is defined);
///   - the generic BLAS trait ::Tpetra::Details::Blas::Lib::GemmCanUseLib
///     accepts these template parameters;
///   - when the Kokkos CUDA backend is enabled, none of the three views has
///     ::Kokkos::Cuda as its execution space (the MKL is a host library and
///     must not be handed device data).
///
/// \tparam ViewType1 Type of the matrix A.
/// \tparam ViewType2 Type of the matrix B.
/// \tparam ViewType3 Type of the matrix C.
/// \tparam CoefficientType Type of the alpha and beta coefficients;
///   defaults to the (non-const) value type of \c ViewType1.
/// \tparam IndexType Index type passed on to GemmCanUseLib.
template<class ViewType1,
         class ViewType2,
         class ViewType3,
         class CoefficientType = typename ViewType1::non_const_value_type,
         class IndexType = int>
struct GemmCanUseMkl {
#ifdef HAVE_KOKKOSKERNELS_MKL
  static constexpr bool value =
    ::Tpetra::Details::Blas::Lib::GemmCanUseLib<ViewType1, ViewType2, ViewType3,
      CoefficientType, IndexType>::value
#  if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_HAVE_CUDA)
    // Test both the current and the deprecated Kokkos macro, so that this
    // exclusion does not silently vanish on Kokkos versions that define
    // only one of them.
    && ! std::is_same<typename ViewType1::execution_space, ::Kokkos::Cuda>::value
    && ! std::is_same<typename ViewType2::execution_space, ::Kokkos::Cuda>::value
    && ! std::is_same<typename ViewType3::execution_space, ::Kokkos::Cuda>::value
#  endif // KOKKOS_ENABLE_CUDA || KOKKOS_HAVE_CUDA
    ;
#else // NOT HAVE_KOKKOSKERNELS_MKL
  static constexpr bool value = false;
#endif // HAVE_KOKKOSKERNELS_MKL
};
91 
92 namespace Impl {
93 
97 void
98 cgemm (const char char_transA,
99  const char char_transB,
100  const int m,
101  const int n,
102  const int k,
103  const ::Kokkos::complex<float>& alpha,
104  const ::Kokkos::complex<float> A[],
105  const int lda,
106  const ::Kokkos::complex<float> B[],
107  const int ldb,
108  const ::Kokkos::complex<float>& beta,
109  ::Kokkos::complex<float> C[],
110  const int ldc);
111 
115 void
116 dgemm (const char char_transA,
117  const char char_transB,
118  const int m,
119  const int n,
120  const int k,
121  const double alpha,
122  const double A[],
123  const int lda,
124  const double B[],
125  const int ldb,
126  const double beta,
127  double C[],
128  const int ldc);
129 
133 void
134 sgemm (const char char_transA,
135  const char char_transB,
136  const int m,
137  const int n,
138  const int k,
139  const float alpha,
140  const float A[],
141  const int lda,
142  const float B[],
143  const int ldb,
144  const float beta,
145  float C[],
146  const int ldc);
147 
151 void
152 zgemm (const char char_transA,
153  const char char_transB,
154  const int m,
155  const int n,
156  const int k,
157  const ::Kokkos::complex<double>& alpha,
158  const ::Kokkos::complex<double> A[],
159  const int lda,
160  const ::Kokkos::complex<double> B[],
161  const int ldb,
162  const ::Kokkos::complex<double>& beta,
163  ::Kokkos::complex<double> C[],
164  const int ldc);
165 
168 template<class ScalarType> struct Gemm {};
169 
170 template<>
171 struct Gemm< ::Kokkos::complex<float> > {
172  typedef ::Kokkos::complex<float> scalar_type;
173 
174  static void
175  gemm (const char transA,
176  const char transB,
177  const int m,
178  const int n,
179  const int k,
180  const scalar_type& alpha,
181  const scalar_type A[],
182  const int lda,
183  const scalar_type B[],
184  const int ldb,
185  const scalar_type& beta,
186  scalar_type C[],
187  const int ldc)
188  {
189  return cgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
190  }
191 };
192 
193 template<>
194 struct Gemm<double> {
195  typedef double scalar_type;
196 
197  static void
198  gemm (const char transA,
199  const char transB,
200  const int m,
201  const int n,
202  const int k,
203  const scalar_type& alpha,
204  const scalar_type A[],
205  const int lda,
206  const scalar_type B[],
207  const int ldb,
208  const scalar_type& beta,
209  scalar_type C[],
210  const int ldc)
211  {
212  return dgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
213  }
214 };
215 
216 template<>
217 struct Gemm<float> {
218  typedef float scalar_type;
219 
220  static void
221  gemm (const char transA,
222  const char transB,
223  const int m,
224  const int n,
225  const int k,
226  const scalar_type& alpha,
227  const scalar_type A[],
228  const int lda,
229  const scalar_type B[],
230  const int ldb,
231  const scalar_type& beta,
232  scalar_type C[],
233  const int ldc)
234  {
235  return sgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
236  }
237 };
238 
239 template<>
240 struct Gemm< ::Kokkos::complex<double> > {
241  typedef ::Kokkos::complex<double> scalar_type;
242 
243  static void
244  gemm (const char transA,
245  const char transB,
246  const int m,
247  const int n,
248  const int k,
249  const scalar_type& alpha,
250  const scalar_type A[],
251  const int lda,
252  const scalar_type B[],
253  const int ldb,
254  const scalar_type& beta,
255  scalar_type C[],
256  const int ldc)
257  {
258  return zgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
259  }
260 };
261 
262 } // namespace Impl
263 
264 template<class ViewType1,
265  class ViewType2,
266  class ViewType3,
267  class CoefficientType>
268 static void
269 gemm (const char transA,
270  const char transB,
271  const CoefficientType& alpha,
272  const ViewType1& A,
273  const ViewType2& B,
274  const CoefficientType& beta,
275  const ViewType3& C)
276 {
277  typedef CoefficientType scalar_type;
278  typedef Impl::Gemm<scalar_type> impl_type;
279  typedef int index_type;
280 
281  const index_type lda = getStride2DView<ViewType1, index_type> (A);
282  const index_type ldb = getStride2DView<ViewType2, index_type> (B);
283  const index_type ldc = getStride2DView<ViewType3, index_type> (C);
284 
285  const index_type m = static_cast<index_type> (C.dimension_0 ());
286  const index_type n = static_cast<index_type> (C.dimension_1 ());
287  const bool noTransA = (transA == 'N' || transA == 'n');
288  const index_type k = static_cast<index_type> (noTransA ?
289  A.dimension_1 () :
290  A.dimension_0 ());
291  impl_type::gemm (transA, transB, m, n, k,
292  alpha, A.data (), lda,
293  B.data (), ldb,
294  beta, C.data (), ldc);
295 }
296 
297 } // namespace Mkl
298 } // namespace Blas
299 } // namespace Details
300 } // namespace Tpetra
301 
302 #endif // TPETRA_DETAILS_MKLGEMM_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.
Type traits for Tpetra's BLAS wrappers; an implementation detail of Tpetra::MultiVector.
Wrappers for the BLAS library's implementation of _GEMM; implementation detail of Tpetra::MultiVector...
For this set of template parameters, can and should we implement Gemm (see below) using the MKL...
Implementation details of Tpetra.
For this set of template parameters, can we implement Gemm (see below) using any compliant BLAS library?
Wrapper for the above wrappers, templated on scalar type (the type of each entry in the matrices)...