Tpetra parallel linear algebra  Version of the Day
Tpetra_Details_defaultGemm.hpp
Go to the documentation of this file.
1 /*
2 //@HEADER
3 // ************************************************************************
4 //
5 // Kokkos: Node API and Parallel Node Kernels
6 // Copyright (2008) Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ************************************************************************
41 //@HEADER
42 */
43 
44 #ifndef TPETRA_DETAILS_DEFAULTGEMM_HPP
45 #define TPETRA_DETAILS_DEFAULTGEMM_HPP
46 
54 
55 #include "TpetraCore_config.h"
56 #include "Kokkos_ArithTraits.hpp"
57 #include "Kokkos_Complex.hpp"
58 #include <type_traits>
59 
60 namespace Tpetra {
61 namespace Details {
62 namespace Blas {
63 namespace Default {
64 
75 template<class ViewType1,
76  class ViewType2,
77  class ViewType3,
78  class CoefficientType,
79  class IndexType = int>
80 void
81 gemm (const char transA,
82  const char transB,
83  const CoefficientType& alpha,
84  const ViewType1& A,
85  const ViewType2& B,
86  const CoefficientType& beta,
87  const ViewType3& C)
88 {
89  // Assert that A, B, and C are in fact matrices
90  static_assert (ViewType1::rank == 2, "GEMM: A must have rank 2 (be a matrix).");
91  static_assert (ViewType2::rank == 2, "GEMM: B must have rank 2 (be a matrix).");
92  static_assert (ViewType3::rank == 2, "GEMM: C must have rank 2 (be a matrix).");
93 
94  typedef typename ViewType3::non_const_value_type c_value_type;
95  typedef Kokkos::Details::ArithTraits<CoefficientType> STS;
96  const CoefficientType ZERO = STS::zero ();
97  const CoefficientType ONE = STS::one ();
98 
99  // Get the dimensions
100  IndexType m, n, k;
101  if (transA == 'N' || transA == 'n') {
102  m = static_cast<IndexType> (A.dimension_0 ());
103  n = static_cast<IndexType> (A.dimension_1 ());
104  }
105  else {
106  m = static_cast<IndexType> (A.dimension_1 ());
107  n = static_cast<IndexType> (A.dimension_0 ());
108  }
109  k = static_cast<IndexType> (C.dimension_1 ());
110 
111  // quick return if possible
112  if (alpha == ZERO && beta == ONE) {
113  return;
114  }
115 
116  // And if alpha equals zero...
117  if (alpha == ZERO) {
118  if (beta == ZERO) {
119  for (IndexType i = 0; i < m; ++i) {
120  for (IndexType j = 0; j < k; ++j) {
121  C(i,j) = ZERO;
122  }
123  }
124  }
125  else {
126  for (IndexType i = 0; i < m; ++i) {
127  for (IndexType j = 0; j < k; ++j) {
128  C(i,j) = beta*C(i,j);
129  }
130  }
131  }
132  }
133 
134  // Start the operations
135  if (transB == 'n' || transB == 'N') {
136  if (transA == 'n' || transA == 'N') {
137  // Form C = alpha*A*B + beta*C
138  for (IndexType j = 0; j < n; ++j) {
139  if (beta == ZERO) {
140  for (IndexType i = 0; i < m; ++i) {
141  C(i,j) = ZERO;
142  }
143  }
144  else if (beta != ONE) {
145  for (IndexType i = 0; i < m; ++i) {
146  C(i,j) = beta*C(i,j);
147  }
148  }
149  for (IndexType l = 0; l < k; ++l) {
150  // Don't use c_value_type here, since it unnecessarily
151  // forces type conversion before we assign to C(i,j).
152  auto temp = alpha*B(l,j);
153  for (IndexType i = 0; i < m; ++i) {
154  C(i,j) = C(i,j) + temp*A(i,l);
155  }
156  }
157  }
158  }
159  else {
160  // Form C = alpha*A**T*B + beta*C
161  for (IndexType j = 0; j < n; ++j) {
162  for (IndexType i = 0; i < m; ++i) {
163  c_value_type temp = ZERO;
164  for (IndexType l = 0; l < k; ++l) {
165  temp = temp + A(l,i)*B(l,j);
166  }
167  if (beta == ZERO) {
168  C(i,j) = alpha*temp;
169  }
170  else {
171  C(i,j) = alpha*temp + beta*C(i,j);
172  }
173  }
174  }
175  }
176  }
177  else {
178  if (transA == 'n' || transA == 'N') {
179  // Form C = alpha*A*B**T + beta*C
180  for (IndexType j = 0; j < n; ++j) {
181  if (beta == ZERO) {
182  for (IndexType i = 0; i < m; ++i) {
183  C(i,j) = ZERO;
184  }
185  }
186  else if (beta != ONE) {
187  for (IndexType i = 0; i < m; ++i) {
188  C(i,j) = beta*C(i,j);
189  }
190  }
191  for (IndexType l = 0; l < k; ++l) {
192  // Don't use c_value_type here, since it unnecessarily
193  // forces type conversion before we assign to C(i,j).
194  auto temp = alpha*B(j,l);
195  for (IndexType i = 0; i < m; ++i) {
196  C(i,j) = C(i,j) + temp*A(i,l);
197  }
198  }
199  }
200  }
201  else {
202  // Form C = alpha*A**T*B**T + beta*C
203  for (IndexType j = 0; j < n; ++j) {
204  for (IndexType i = 0; i < m; ++i) {
205  c_value_type temp = ZERO;
206  for (IndexType l = 0; l < k; ++l) {
207  temp = temp + A(l,i)*B(j,l);
208  }
209  if (beta == ZERO) {
210  C(i,j) = alpha*temp;
211  }
212  else {
213  C(i,j) = alpha*temp + beta*C(i,j);
214  }
215  }
216  }
217  }
218  }
219 }
220 
221 } // namespace Default
222 } // namespace Blas
223 } // namespace Details
224 } // namespace Tpetra
225 
226 #endif // TPETRA_DETAILS_DEFAULTGEMM_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.
Implementation details of Tpetra.
void gemm(const char transA, const char transB, const CoefficientType &alpha, const ViewType1 &A, const ViewType2 &B, const CoefficientType &beta, const ViewType3 &C)
Default implementation of dense matrix-matrix multiply on a single MPI process: C := alpha*op(A)*op(B) + beta*C.
Replace old values with zero.