44 #ifndef TPETRA_DETAILS_MKLGEMM_HPP 45 #define TPETRA_DETAILS_MKLGEMM_HPP 57 #include "TpetraCore_config.h" 68 template<
class ViewType1,
71 class CoefficientType =
typename ViewType1::non_const_value_type,
72 class IndexType =
int>
74 #ifdef HAVE_KOKKOSKERNELS_MKL 75 # ifdef KOKKOS_HAVE_CUDA 76 static constexpr
bool value =
78 CoefficientType, IndexType>::value &&
79 ! std::is_same<typename ViewType1::execution_space, ::Kokkos::Cuda>::value &&
80 ! std::is_same<typename ViewType2::execution_space, ::Kokkos::Cuda>::value &&
81 ! std::is_same<typename ViewType3::execution_space, ::Kokkos::Cuda>::value;
82 # else // NOT KOKKOS_HAVE_CUDA 83 static constexpr
bool value =
85 CoefficientType, IndexType>::value;
86 # endif // KOKKOS_HAVE_CUDA 87 #else // NOT HAVE_KOKKOSKERNELS_MKL 88 static constexpr
bool value =
false;
89 #endif // NOT HAVE_KOKKOSKERNELS_MKL 98 cgemm (
const char char_transA,
99 const char char_transB,
103 const ::Kokkos::complex<float>& alpha,
104 const ::Kokkos::complex<float> A[],
106 const ::Kokkos::complex<float> B[],
108 const ::Kokkos::complex<float>& beta,
109 ::Kokkos::complex<float> C[],
116 dgemm (
const char char_transA,
117 const char char_transB,
134 sgemm (
const char char_transA,
135 const char char_transB,
152 zgemm (
const char char_transA,
153 const char char_transB,
157 const ::Kokkos::complex<double>& alpha,
158 const ::Kokkos::complex<double> A[],
160 const ::Kokkos::complex<double> B[],
162 const ::Kokkos::complex<double>& beta,
163 ::Kokkos::complex<double> C[],
/// \brief Primary template of the GEMM wrapper that dispatches on scalar type.
///
/// This primary template is intentionally empty.  Only the explicit
/// specializations for the supported scalar types (Kokkos::complex<float>,
/// double, float, Kokkos::complex<double>) provide a static \c gemm member
/// that forwards to the matching c/d/s/zgemm wrapper.  Referring to
/// \c Gemm<T>::gemm for any other \c T is therefore a compile-time error,
/// while merely naming \c Gemm<T> is still allowed.
///
/// \tparam ScalarType Type of each entry in the matrices A, B, and C.
template <typename ScalarType>
struct Gemm {
  // Deliberately no members; see the per-scalar-type specializations.
};
171 struct Gemm< ::Kokkos::complex<float> > {
172 typedef ::Kokkos::complex<float> scalar_type;
175 gemm (
const char transA,
180 const scalar_type& alpha,
181 const scalar_type A[],
183 const scalar_type B[],
185 const scalar_type& beta,
189 return cgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
194 struct Gemm<double> {
195 typedef double scalar_type;
198 gemm (
const char transA,
203 const scalar_type& alpha,
204 const scalar_type A[],
206 const scalar_type B[],
208 const scalar_type& beta,
212 return dgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
218 typedef float scalar_type;
221 gemm (
const char transA,
226 const scalar_type& alpha,
227 const scalar_type A[],
229 const scalar_type B[],
231 const scalar_type& beta,
235 return sgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
240 struct Gemm< ::Kokkos::complex<double> > {
241 typedef ::Kokkos::complex<double> scalar_type;
244 gemm (
const char transA,
249 const scalar_type& alpha,
250 const scalar_type A[],
252 const scalar_type B[],
254 const scalar_type& beta,
258 return zgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
264 template<
class ViewType1,
267 class CoefficientType>
269 gemm (
const char transA,
271 const CoefficientType& alpha,
274 const CoefficientType& beta,
277 typedef CoefficientType scalar_type;
279 typedef int index_type;
281 const index_type lda = getStride2DView<ViewType1, index_type> (A);
282 const index_type ldb = getStride2DView<ViewType2, index_type> (B);
283 const index_type ldc = getStride2DView<ViewType3, index_type> (C);
285 const index_type m =
static_cast<index_type
> (C.dimension_0 ());
286 const index_type n =
static_cast<index_type
> (C.dimension_1 ());
287 const bool noTransA = (transA ==
'N' || transA ==
'n');
288 const index_type k =
static_cast<index_type
> (noTransA ?
291 impl_type::gemm (transA, transB, m, n, k,
292 alpha, A.data (), lda,
294 beta, C.data (), ldc);
302 #endif // TPETRA_DETAILS_MKLGEMM_HPP Namespace Tpetra contains the class and methods constituting the Tpetra library.
Type traits for Tpetra's BLAS wrappers; an implementation detail of Tpetra::MultiVector.
Wrappers for the BLAS library's implementation of _GEMM; implementation detail of Tpetra::MultiVector...
For this set of template parameters, can and should we implement Gemm (see below) using the MKL...
Implementation details of Tpetra.
For this set of template parameters, can we implement Gemm (see below) using any compliant BLAS libra...
Wrapper for the above wrappers, templated on scalar type (the type of each entry in the matrices)...