44 #ifndef TPETRA_DETAILS_CUBLASGEMM_HPP 45 #define TPETRA_DETAILS_CUBLASGEMM_HPP 63 #include "TpetraCore_config.h" 74 template<
class ViewType1,
77 class CoefficientType,
80 #ifdef KOKKOS_HAVE_CUDA 81 static constexpr
bool value =
83 CoefficientType, IndexType>::value &&
84 std::is_same<typename ViewType1::execution_space, ::Kokkos::Cuda>::value &&
85 std::is_same<typename ViewType2::execution_space, ::Kokkos::Cuda>::value &&
86 std::is_same<typename ViewType3::execution_space, ::Kokkos::Cuda>::value;
87 #else // NOT KOKKOS_HAVE_CUDA 88 static constexpr
bool value =
false;
89 #endif // KOKKOS_HAVE_CUDA 98 cgemm (
const char char_transA,
99 const char char_transB,
103 const ::Kokkos::complex<float>& alpha,
104 const ::Kokkos::complex<float> A[],
106 const ::Kokkos::complex<float> B[],
108 const ::Kokkos::complex<float>& beta,
109 ::Kokkos::complex<float> C[],
116 dgemm (
const char char_transA,
117 const char char_transB,
134 sgemm (
const char char_transA,
135 const char char_transB,
152 zgemm (
const char char_transA,
153 const char char_transB,
157 const ::Kokkos::complex<double>& alpha,
158 const ::Kokkos::complex<double> A[],
160 const ::Kokkos::complex<double> B[],
162 const ::Kokkos::complex<double>& beta,
163 ::Kokkos::complex<double> C[],
168 template<
class ScalarType>
struct Gemm {};
171 struct Gemm< ::Kokkos::complex<float> > {
172 typedef ::Kokkos::complex<float> scalar_type;
175 gemm (
const char transA,
180 const scalar_type& alpha,
181 const scalar_type A[],
183 const scalar_type B[],
185 const scalar_type& beta,
189 return cgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
194 struct Gemm<double> {
195 typedef double scalar_type;
198 gemm (
const char transA,
203 const scalar_type& alpha,
204 const scalar_type A[],
206 const scalar_type B[],
208 const scalar_type& beta,
212 return dgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
218 typedef float scalar_type;
221 gemm (
const char transA,
226 const scalar_type& alpha,
227 const scalar_type A[],
229 const scalar_type B[],
231 const scalar_type& beta,
235 return sgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
240 struct Gemm< ::Kokkos::complex<double> > {
241 typedef ::Kokkos::complex<double> scalar_type;
244 gemm (
const char transA,
249 const scalar_type& alpha,
250 const scalar_type A[],
252 const scalar_type B[],
254 const scalar_type& beta,
258 return zgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
264 template<
class ViewType1,
267 class CoefficientType>
269 gemm (
const char transA,
271 const CoefficientType& alpha,
274 const CoefficientType& beta,
277 typedef CoefficientType scalar_type;
279 typedef int index_type;
281 const index_type lda = getStride2DView<ViewType1, index_type> (A);
282 const index_type ldb = getStride2DView<ViewType2, index_type> (B);
283 const index_type ldc = getStride2DView<ViewType3, index_type> (C);
285 const index_type m =
static_cast<index_type
> (C.dimension_0 ());
286 const index_type n =
static_cast<index_type
> (C.dimension_1 ());
287 const bool noTransA = (transA ==
'N' || transA ==
'n');
288 const index_type k =
static_cast<index_type
> (noTransA ?
291 impl_type::gemm (transA, transB, m, n, k,
292 alpha, A.data (), lda,
294 beta, C.data (), ldc);
302 #endif // TPETRA_DETAILS_CUBLASGEMM_HPP Namespace Tpetra contains the class and methods constituting the Tpetra library.
Type traits for Tpetra's BLAS wrappers; an implementation detail of Tpetra::MultiVector.
Wrappers for the BLAS library's implementation of _GEMM; implementation detail of Tpetra::MultiVector...
Implementation details of Tpetra.
Wrapper for the above wrappers, templated on scalar type (the type of each entry in the matrices)...
For this set of template parameters, can and should we implement Gemm (see below) using cuBLAS...
For this set of template parameters, can we implement Gemm (see below) using any compliant BLAS libra...