Tpetra parallel linear algebra — Version of the Day
Tpetra_Details_libGemm.hpp
Go to the documentation of this file.
1 /*
2 //@HEADER
3 // ************************************************************************
4 //
5 // Kokkos: Node API and Parallel Node Kernels
6 // Copyright (2008) Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ************************************************************************
41 //@HEADER
42 */
43 
44 #ifndef TPETRA_DETAILS_LIBGEMM_HPP
45 #define TPETRA_DETAILS_LIBGEMM_HPP
46 
58 
59 #include "TpetraCore_config.h"
60 #include "Tpetra_Details_Blas.hpp"
61 
62 namespace Tpetra {
63 namespace Details {
64 namespace Blas {
65 namespace Lib {
66 
70 template<class ViewType1,
71  class ViewType2,
72  class ViewType3,
73  class CoefficientType,
74  class IndexType>
75 struct GemmCanUseLib {
76  typedef typename ViewType1::non_const_value_type scalar_type;
77  typedef typename std::decay<CoefficientType>::type coeff_type;
78  typedef typename std::decay<IndexType>::type index_type;
79 
80  // All three Views must have the same entry types, the coefficient
81  // type must be the same as that, all these four types must be one
82  // of the four types that the BLAS library supports, and IndexType
83  // must be int.
84  //
85  // Modify this if you later add a TPL that can support more types
86  // than this.
87  static constexpr bool value =
88  std::is_same<scalar_type, typename ViewType2::non_const_value_type>::value &&
89  std::is_same<scalar_type, typename ViewType3::non_const_value_type>::value &&
90  std::is_same<scalar_type, coeff_type>::value &&
95  std::is_same<index_type, int>::value;
96 };
97 
98 namespace Impl {
99 
101 void
102 cgemm (const char char_transA,
103  const char char_transB,
104  const int m,
105  const int n,
106  const int k,
107  const ::Kokkos::complex<float>& alpha,
108  const ::Kokkos::complex<float> A[],
109  const int lda,
110  const ::Kokkos::complex<float> B[],
111  const int ldb,
112  const ::Kokkos::complex<float>& beta,
113  ::Kokkos::complex<float> C[],
114  const int ldc);
115 
/// \brief Wrapper for the BLAS routine dgemm (double-precision real GEMM).
///
/// Only declared here; the definition lives in a separate translation
/// unit. Arguments follow the BLAS xGEMM calling convention. A, B, and C
/// point to the first entry of each matrix, and lda/ldb/ldc are their
/// leading dimensions.
void dgemm (const char char_transA,
            const char char_transB,
            const int m,
            const int n,
            const int k,
            const double alpha,
            const double* A,
            const int lda,
            const double* B,
            const int ldb,
            const double beta,
            double* C,
            const int ldc);
131 
/// \brief Wrapper for the BLAS routine sgemm (single-precision real GEMM).
///
/// Only declared here; the definition lives in a separate translation
/// unit. Arguments follow the BLAS xGEMM calling convention. A, B, and C
/// point to the first entry of each matrix, and lda/ldb/ldc are their
/// leading dimensions.
void sgemm (const char char_transA,
            const char char_transB,
            const int m,
            const int n,
            const int k,
            const float alpha,
            const float* A,
            const int lda,
            const float* B,
            const int ldb,
            const float beta,
            float* C,
            const int ldc);
147 
149 void
150 zgemm (const char char_transA,
151  const char char_transB,
152  const int m,
153  const int n,
154  const int k,
155  const ::Kokkos::complex<double>& alpha,
156  const ::Kokkos::complex<double> A[],
157  const int lda,
158  const ::Kokkos::complex<double> B[],
159  const int ldb,
160  const ::Kokkos::complex<double>& beta,
161  ::Kokkos::complex<double> C[],
162  const int ldc);
163 
/// \brief Dispatcher over the BLAS GEMM wrappers, keyed on the matrix
///   entry ("scalar") type.
///
/// The primary template is deliberately empty. Only the explicit
/// specializations for float, double, Kokkos::complex<float>, and
/// Kokkos::complex<double> provide a static gemm() member, so
/// Gemm<T>::gemm does not compile for any other T.
template<typename ScalarType>
struct Gemm {
};
167 
168 template<>
169 struct Gemm< ::Kokkos::complex<float> > {
170  typedef ::Kokkos::complex<float> scalar_type;
171 
172  static void
173  gemm (const char transA,
174  const char transB,
175  const int m,
176  const int n,
177  const int k,
178  const scalar_type& alpha,
179  const scalar_type A[],
180  const int lda,
181  const scalar_type B[],
182  const int ldb,
183  const scalar_type& beta,
184  scalar_type C[],
185  const int ldc)
186  {
187  return cgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
188  }
189 };
190 
191 template<>
192 struct Gemm<double> {
193  typedef double scalar_type;
194 
195  static void
196  gemm (const char transA,
197  const char transB,
198  const int m,
199  const int n,
200  const int k,
201  const scalar_type& alpha,
202  const scalar_type A[],
203  const int lda,
204  const scalar_type B[],
205  const int ldb,
206  const scalar_type& beta,
207  scalar_type C[],
208  const int ldc)
209  {
210  return dgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
211  }
212 };
213 
214 template<>
215 struct Gemm<float> {
216  typedef float scalar_type;
217 
218  static void
219  gemm (const char transA,
220  const char transB,
221  const int m,
222  const int n,
223  const int k,
224  const scalar_type& alpha,
225  const scalar_type A[],
226  const int lda,
227  const scalar_type B[],
228  const int ldb,
229  const scalar_type& beta,
230  scalar_type C[],
231  const int ldc)
232  {
233  return sgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
234  }
235 };
236 
237 template<>
238 struct Gemm< ::Kokkos::complex<double> > {
239  typedef ::Kokkos::complex<double> scalar_type;
240 
241  static void
242  gemm (const char transA,
243  const char transB,
244  const int m,
245  const int n,
246  const int k,
247  const scalar_type& alpha,
248  const scalar_type A[],
249  const int lda,
250  const scalar_type B[],
251  const int ldb,
252  const scalar_type& beta,
253  scalar_type C[],
254  const int ldc)
255  {
256  return zgemm (transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
257  }
258 };
259 
260 } // namespace Impl
261 
262 template<class ViewType1,
263  class ViewType2,
264  class ViewType3,
265  class CoefficientType>
266 static void
267 gemm (const char transA,
268  const char transB,
269  const CoefficientType& alpha,
270  const ViewType1& A,
271  const ViewType2& B,
272  const CoefficientType& beta,
273  const ViewType3& C)
274 {
275  typedef CoefficientType scalar_type;
276  typedef Impl::Gemm<scalar_type> impl_type;
277  typedef int index_type;
278 
279  const index_type lda = getStride2DView<ViewType1, index_type> (A);
280  const index_type ldb = getStride2DView<ViewType2, index_type> (B);
281  const index_type ldc = getStride2DView<ViewType3, index_type> (C);
282 
283  const index_type m = static_cast<index_type> (C.dimension_0 ());
284  const index_type n = static_cast<index_type> (C.dimension_1 ());
285  const bool noTransA = (transA == 'N' || transA == 'n');
286  const index_type k = static_cast<index_type> (noTransA ?
287  A.dimension_1 () :
288  A.dimension_0 ());
289  impl_type::gemm (transA, transB, m, n, k,
290  alpha, A.data (), lda,
291  B.data (), ldb,
292  beta, C.data (), ldc);
293 }
294 
295 } // namespace Lib
296 } // namespace Blas
297 } // namespace Details
298 } // namespace Tpetra
299 
300 #endif // TPETRA_DETAILS_LIBGEMM_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.
Type traits for Tpetra's BLAS wrappers; an implementation detail of Tpetra::MultiVector.
Implementation details of Tpetra.
Do BLAS libraries (all that are compliant with the BLAS Standard) support the given "scalar" (matrix ...
For this set of template parameters, can we implement Gemm (see below) using any compliant BLAS libra...
Wrapper for the above wrappers, templated on scalar type (the type of each entry in the matrices)...
Do BLAS libraries (all that are compliant with the BLAS Standard) support the given Kokkos array layo...