// BLAS specialization for int ordinals and ::Kokkos::complex<float> values.
// The visible GEMV and GEMM members do no arithmetic themselves: each one
// creates a BLAS<int, std::complex<float> > and forwards its arguments to it.
// NOTE(review): this listing carries source-line-number residue and omits
// some of the original lines (return types, braces, the `template<>` header,
// the closing of the class); confirm details against the real header.
44 #ifndef KOKKOS_MV_GEMM_HPP 45 #define KOKKOS_MV_GEMM_HPP 52 #include <Teuchos_BLAS.hpp> 53 #include <Kokkos_Blas2_MV.hpp> 66 class BLAS<int, ::Kokkos::complex<float> > {
// mag_type:  real type underlying the complex value.
// val_type:  the Kokkos complex type used in this class's interface.
// impl_type: the std::complex type of the BLAS instance that does the work.
68 typedef float mag_type;
69 typedef ::Kokkos::complex<float> val_type;
70 typedef std::complex<float> impl_type;
// Copy constructor with an empty body; nothing is copied.
73 BLAS (
const BLAS<int, val_type>&) {}
// GEMV: forwards to BLAS<int, impl_type>::GEMV.
// trans, m, n, lda, incx and incy are passed through unchanged; alpha and
// beta are converted with static_cast<impl_type>; A, x and y are
// reinterpret_cast to impl_type pointers.
// NOTE(review): the pointer casts presumably rely on ::Kokkos::complex<float>
// and std::complex<float> sharing the same memory layout -- confirm.
87 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
88 const val_type* A,
const int lda,
const val_type* x,
const int incx,
89 const val_type beta, val_type* y,
const int incy)
const 91 BLAS<int, impl_type> blas;
92 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
93 reinterpret_cast<const impl_type*> (A), lda,
94 reinterpret_cast<const impl_type*> (x), incx,
95 static_cast<impl_type> (beta),
96 reinterpret_cast<impl_type*> (y), incy);
// GEMM: forwards to BLAS<int, impl_type>::GEMM using the same conversion
// scheme as GEMV (static_cast for alpha/beta, reinterpret_cast for A, B, C).
// NOTE(review): `ldc` is passed to blas.GEMM below, but its parameter
// declaration is not visible in this listing.
103 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
104 const val_type alpha,
const val_type* A,
const int lda,
105 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
108 BLAS<int, impl_type> blas;
109 blas.GEMM (transa, transb, m, n, k,
110 static_cast<impl_type> (alpha),
111 reinterpret_cast<const impl_type*> (A), lda,
112 reinterpret_cast<const impl_type*> (B), ldb,
113 static_cast<impl_type> (beta),
114 reinterpret_cast<impl_type*> (C), ldc);
// BLAS specialization for int ordinals and ::Kokkos::complex<double> values.
// The visible GEMV and GEMM members do no arithmetic themselves: each one
// creates a BLAS<int, std::complex<double> > and forwards its arguments to it.
// NOTE(review): this listing omits some of the original lines (return types,
// braces, the `template<>` header, the closing of the class).
124 class BLAS<int, ::Kokkos::complex<double> > {
// mag_type:  real type underlying the complex value.
// val_type:  the Kokkos complex type used in this class's interface.
// impl_type: the std::complex type of the BLAS instance that does the work.
126 typedef double mag_type;
127 typedef ::Kokkos::complex<double> val_type;
128 typedef std::complex<double> impl_type;
// Copy constructor with an empty body; nothing is copied.
131 BLAS (
const BLAS<int, val_type>&) {}
// GEMV: forwards to BLAS<int, impl_type>::GEMV.
// trans, m, n, lda, incx and incy are passed through unchanged; alpha and
// beta are converted with static_cast<impl_type>; A, x and y are
// reinterpret_cast to impl_type pointers.
// NOTE(review): the pointer casts presumably rely on ::Kokkos::complex<double>
// and std::complex<double> sharing the same memory layout -- confirm.
145 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
146 const val_type* A,
const int lda,
const val_type* x,
const int incx,
147 const val_type beta, val_type* y,
const int incy)
const 149 BLAS<int, impl_type> blas;
150 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
151 reinterpret_cast<const impl_type*> (A), lda,
152 reinterpret_cast<const impl_type*> (x), incx,
153 static_cast<impl_type> (beta),
154 reinterpret_cast<impl_type*> (y), incy);
// GEMM: forwards to BLAS<int, impl_type>::GEMM using the same conversion
// scheme as GEMV (static_cast for alpha/beta, reinterpret_cast for A, B, C).
// NOTE(review): `ldc` is passed to blas.GEMM below, but its parameter
// declaration is not visible in this listing.
161 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
162 const val_type alpha,
const val_type* A,
const int lda,
163 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
166 BLAS<int, impl_type> blas;
167 blas.GEMM (transa, transb, m, n, k,
168 static_cast<impl_type> (alpha),
169 reinterpret_cast<const impl_type*> (A), lda,
170 reinterpret_cast<const impl_type*> (B), ldb,
171 static_cast<impl_type> (beta),
172 reinterpret_cast<impl_type*> (C), ldc);
// Returns the column stride to use for the 2-D view A: stride[1] when A has
// more than one column, otherwise A.dimension_0 ().
// The view is taken by value.
// NOTE(review): the declaration and filling of `stride` are not visible in
// this listing; presumably it holds A's strides -- confirm.
187 template<
class ViewType>
188 size_t getStride2DView (ViewType A) {
191 return A.dimension_1 () > 1 ? stride[1] : A.dimension_0 ();
// Generic GEMM for LayoutLeft views of Scalar on any DeviceType.
// A and B are read-only 2-D views; C is the writable 2-D output view.
// Two code paths are visible:
//   * C has exactly one column and transB is NO_TRANS: call
//     Teuchos::BLAS<int,Scalar>::GEMV on the raw pointers of A, B and C.
//   * otherwise: translate transA/transB to BLAS-style characters and call
//     ::Tpetra::Details::Blas::gemm on the views.
// NOTE(review): the alpha/beta parameter declarations and the braces/else
// joining the two paths are not visible in this listing.
201 template <
typename Scalar,
typename DeviceType>
205 GEMM (
const Teuchos::ETransp transA,
206 const Teuchos::ETransp transB,
208 const View<const Scalar**, LayoutLeft, DeviceType>& A,
209 const View<const Scalar**, LayoutLeft, DeviceType>& B,
211 const View<Scalar**, LayoutLeft, DeviceType>& C)
// n is the number of columns of the output view C.
213 const int n =
static_cast<int> (C.dimension_1 ());
217 if (n == 1 && transB == Teuchos::NO_TRANS) {
// Leading dimension of A, taken from the view's stride helper.
218 const int lda =
static_cast<int> (Impl::getStride2DView (A));
219 Teuchos::BLAS<int,Scalar> blas;
// A's dimensions are passed as the m and n arguments without an explicit
// cast; both vectors use an increment of 1.
220 blas.GEMV (transA, A.dimension_0 (), A.dimension_1 (),
221 alpha, A.ptr_on_device (), lda,
222 B.ptr_on_device (),
static_cast<int> (1),
223 beta, C.ptr_on_device (),
static_cast<int> (1));
// Map the Teuchos enum to a character: CONJ_TRANS -> 'C', TRANS -> 'T',
// anything else -> 'N'.
226 const char ctransA = (transA == Teuchos::CONJ_TRANS ?
'C' :
227 (transA == Teuchos::TRANS ?
'T' :
'N'));
228 const char ctransB = (transB == Teuchos::CONJ_TRANS ?
'C' :
229 (transB == Teuchos::TRANS ?
'T' :
'N'));
230 ::Tpetra::Details::Blas::gemm (ctransA, ctransB, alpha, A, B, beta, C);
// GEMM for LayoutLeft views of double, compiled only when
// HAVE_KOKKOSKERNELS_MKL is defined.
// Two code paths are visible:
//   * C has exactly one column and transB is NO_TRANS: take column 0 of B
//     and of C as subviews and call KokkosBlas::gemv.
//   * otherwise: translate transA/transB to BLAS-style characters and call
//     ::Tpetra::Details::Blas::gemm on the views.
// NOTE(review): the alpha/beta parameters, the declaration of `trans`, and
// the bodies of the transA branches are not visible in this listing.
237 #ifdef HAVE_KOKKOSKERNELS_MKL 238 template <
typename DeviceType>
242 GEMM (
const Teuchos::ETransp transA,
243 const Teuchos::ETransp transB,
245 const View<const double**, LayoutLeft, DeviceType>& A,
246 const View<const double**, LayoutLeft, DeviceType>& B,
248 const View<double**, LayoutLeft, DeviceType>& C)
// n is the number of columns of the output view C.
250 const int n =
static_cast<int> (C.dimension_1 ());
254 if (n == 1 && transB == Teuchos::NO_TRANS) {
// Presumably these branches set `trans` according to transA; the
// assignments themselves are missing from this listing -- confirm.
256 if (transA == Teuchos::TRANS) {
259 else if (transA == Teuchos::CONJ_TRANS) {
// Column 0 of B is the input vector, column 0 of C the output vector.
262 auto B_0 = Kokkos::subview (B, Kokkos::ALL (), 0);
263 auto C_0 = Kokkos::subview (C, Kokkos::ALL (), 0);
264 KokkosBlas::gemv (&trans, alpha, A, B_0, beta, C_0);
// Map the Teuchos enum to a character: CONJ_TRANS -> 'C', TRANS -> 'T',
// anything else -> 'N'.
267 const char ctransA = (transA == Teuchos::CONJ_TRANS ?
'C' :
268 (transA == Teuchos::TRANS ?
'T' :
'N'));
269 const char ctransB = (transB == Teuchos::CONJ_TRANS ?
'C' :
270 (transB == Teuchos::TRANS ?
'T' :
'N'));
271 ::Tpetra::Details::Blas::gemm (ctransA, ctransB,
272 alpha, A, B, beta, C);
// Generic GEMM for Cuda views of an arbitrary Scalar, compiled only when
// KOKKOS_HAVE_CUDA is defined.
// It performs no multiplication: TEUCHOS_TEST_FOR_EXCEPTION is invoked with a
// constant `true` condition and std::logic_error, with a message naming the
// Scalar type via Teuchos::typeName(alpha).
// NOTE(review): the alpha/beta parameter declarations are not visible in
// this listing.
276 #endif // HAVE_KOKKOSKERNELS_MKL 278 #ifdef KOKKOS_HAVE_CUDA 279 template <
typename Scalar>
283 GEMM (
const Teuchos::ETransp transA,
284 const Teuchos::ETransp transB,
286 const View<const Scalar**, LayoutLeft, Cuda>& A,
287 const View<const Scalar**,LayoutLeft, Cuda>& B,
289 const View<Scalar**,LayoutLeft,Cuda>& C)
291 TEUCHOS_TEST_FOR_EXCEPTION
292 (
true, std::logic_error,
"DeviceGEMM: Kokkos::Cuda has no support " 293 "for GEMM operations over Scalar=" << Teuchos::typeName(alpha) <<
".");
// GEMM for LayoutLeft Cuda views of float.
// Translates transA/transB to characters and forwards the views to
// ::Tpetra::Details::Blas::gemm.
// NOTE(review): the alpha/beta parameter declarations are not visible in
// this listing.
301 GEMM (
const Teuchos::ETransp transA,
302 const Teuchos::ETransp transB,
304 const View<const float**,LayoutLeft,Cuda>& A,
305 const View<const float**,LayoutLeft,Cuda>& B,
307 const View<float**,LayoutLeft,Cuda>& C)
// NO_TRANS maps to 'N'; every other value (including CONJ_TRANS) maps to
// 'T'. There is no 'C' case here, unlike the generic GEMM above.
309 const char ctransA = (transA == Teuchos::NO_TRANS ?
'N' :
'T');
310 const char ctransB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
312 ::Tpetra::Details::Blas::gemm (ctransA, ctransB,
313 alpha, A, B, beta, C);
// GEMM for LayoutLeft Cuda views of double.
// Translates transA/transB to characters and forwards the views to
// ::Tpetra::Details::Blas::gemm.
// NOTE(review): the alpha/beta parameter declarations are not visible in
// this listing.
321 GEMM (
const Teuchos::ETransp transA,
322 const Teuchos::ETransp transB,
324 const View<const double**, LayoutLeft, Cuda>& A,
325 const View<const double**, LayoutLeft, Cuda>& B,
327 const View<double**, LayoutLeft, Cuda>& C)
// NO_TRANS maps to 'N'; every other value (including CONJ_TRANS) maps to
// 'T'. There is no 'C' case here, unlike the generic GEMM above.
329 const char ctransA = (transA == Teuchos::NO_TRANS ?
'N' :
'T');
330 const char ctransB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
332 ::Tpetra::Details::Blas::gemm (ctransA, ctransB,
333 alpha, A, B, beta, C);
336 #endif // KOKKOS_HAVE_CUDA 339 #endif // KOKKOS_MV_GEMM_HPP KOKKOS_INLINE_FUNCTION void GEMV(const CoeffType &alpha, const BlkType &A, const VecType1 &x, const VecType2 &y)
y := y + alpha * A * x (dense matrix-vector multiply)
KOKKOS_INLINE_FUNCTION void GEMM(const char transA[], const char transB[], const CoefficientType &alpha, const ViewType1 &A, const ViewType2 &B, const CoefficientType &beta, const ViewType3 &C)
Small dense matrix-matrix multiply: C := alpha*A*B + beta*C
Class that provides GEMM for a particular Kokkos Device.
Declaration and definition of Tpetra::Details::Blas::gemm, an implementation detail of Tpetra::MultiVector.