44 #ifndef TPETRA_DETAILS_DEFAULTGEMM_HPP 45 #define TPETRA_DETAILS_DEFAULTGEMM_HPP 55 #include "TpetraCore_config.h" 56 #include "Kokkos_ArithTraits.hpp" 57 #include "Kokkos_Complex.hpp" 58 #include <type_traits> 75 template<
class ViewType1,
78 class CoefficientType,
79 class IndexType =
int>
83 const CoefficientType& alpha,
86 const CoefficientType& beta,
90 static_assert (ViewType1::rank == 2,
"GEMM: A must have rank 2 (be a matrix).");
91 static_assert (ViewType2::rank == 2,
"GEMM: B must have rank 2 (be a matrix).");
92 static_assert (ViewType3::rank == 2,
"GEMM: C must have rank 2 (be a matrix).");
94 typedef typename ViewType3::non_const_value_type c_value_type;
95 typedef Kokkos::Details::ArithTraits<CoefficientType> STS;
96 const CoefficientType
ZERO = STS::zero ();
97 const CoefficientType ONE = STS::one ();
101 if (transA ==
'N' || transA ==
'n') {
102 m =
static_cast<IndexType
> (A.dimension_0 ());
103 n =
static_cast<IndexType
> (A.dimension_1 ());
106 m =
static_cast<IndexType
> (A.dimension_1 ());
107 n =
static_cast<IndexType
> (A.dimension_0 ());
109 k =
static_cast<IndexType
> (C.dimension_1 ());
112 if (alpha == ZERO && beta == ONE) {
119 for (IndexType i = 0; i < m; ++i) {
120 for (IndexType j = 0; j < k; ++j) {
126 for (IndexType i = 0; i < m; ++i) {
127 for (IndexType j = 0; j < k; ++j) {
128 C(i,j) = beta*C(i,j);
135 if (transB ==
'n' || transB ==
'N') {
136 if (transA ==
'n' || transA ==
'N') {
138 for (IndexType j = 0; j < n; ++j) {
140 for (IndexType i = 0; i < m; ++i) {
144 else if (beta != ONE) {
145 for (IndexType i = 0; i < m; ++i) {
146 C(i,j) = beta*C(i,j);
149 for (IndexType l = 0; l < k; ++l) {
152 auto temp = alpha*B(l,j);
153 for (IndexType i = 0; i < m; ++i) {
154 C(i,j) = C(i,j) + temp*A(i,l);
161 for (IndexType j = 0; j < n; ++j) {
162 for (IndexType i = 0; i < m; ++i) {
163 c_value_type temp =
ZERO;
164 for (IndexType l = 0; l < k; ++l) {
165 temp = temp + A(l,i)*B(l,j);
171 C(i,j) = alpha*temp + beta*C(i,j);
178 if (transA ==
'n' || transA ==
'N') {
180 for (IndexType j = 0; j < n; ++j) {
182 for (IndexType i = 0; i < m; ++i) {
186 else if (beta != ONE) {
187 for (IndexType i = 0; i < m; ++i) {
188 C(i,j) = beta*C(i,j);
191 for (IndexType l = 0; l < k; ++l) {
194 auto temp = alpha*B(j,l);
195 for (IndexType i = 0; i < m; ++i) {
196 C(i,j) = C(i,j) + temp*A(i,l);
203 for (IndexType j = 0; j < n; ++j) {
204 for (IndexType i = 0; i < m; ++i) {
205 c_value_type temp =
ZERO;
206 for (IndexType l = 0; l < k; ++l) {
207 temp = temp + A(l,i)*B(j,l);
213 C(i,j) = alpha*temp + beta*C(i,j);
226 #endif // TPETRA_DETAILS_DEFAULTGEMM_HPP
Namespace Tpetra contains the classes and methods constituting the Tpetra library.
Implementation details of Tpetra.
void gemm(const char transA, const char transB, const CoefficientType &alpha, const ViewType1 &A, const ViewType2 &B, const CoefficientType &beta, const ViewType3 &C)
Default implementation of dense matrix-matrix multiply on a single MPI process: C := alpha*A*B + beta*C, where A and B are each read either as given or transposed, as selected by transA and transB.
Replace old values with zero.