Tpetra parallel linear algebra  Version of the Day
Tpetra_Details_gemm.hpp
Go to the documentation of this file.
1 /*
2 //@HEADER
3 // ************************************************************************
4 //
5 // Kokkos: Node API and Parallel Node Kernels
6 // Copyright (2008) Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ************************************************************************
41 //@HEADER
42 */
43 
44 #ifndef TPETRA_DETAILS_GEMM_HPP
45 #define TPETRA_DETAILS_GEMM_HPP
46 
58 
59 #include "Tpetra_Details_Blas.hpp"
64 #include <type_traits>
65 
66 namespace Tpetra {
67 namespace Details {
68 namespace Blas {
69 namespace Impl {
70 
72 template<class ViewType1,
73  class ViewType2,
74  class ViewType3,
75  class CoefficientType,
76  class IndexType,
77  const bool canUseBlasLibrary =
78  Lib::GemmCanUseLib<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType>::value,
79  const bool canUseCublas =
80  Cublas::GemmCanUseCublas<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType>::value,
81  const bool canUseMkl =
82  Mkl::GemmCanUseMkl<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType>::value>
83 struct Gemm {
84  static void
85  gemm (const char transA,
86  const char transB,
87  const CoefficientType& alpha,
88  const ViewType1& A,
89  const ViewType2& B,
90  const CoefficientType& beta,
91  const ViewType3& C)
92  {
93  ::Tpetra::Details::Blas::Default::gemm (transA, transB, alpha, A, B, beta, C);
94  }
95 };
96 
98 template<class ViewType1,
99  class ViewType2,
100  class ViewType3,
101  class CoefficientType,
102  class IndexType>
103 struct Gemm<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType,
104  true, true, false> {
105  static void
106  gemm (const char transA,
107  const char transB,
108  const CoefficientType& alpha,
109  const ViewType1& A,
110  const ViewType2& B,
111  const CoefficientType& beta,
112  const ViewType3& C)
113  {
114  ::Tpetra::Details::Blas::Cublas::gemm (transA, transB, alpha, A, B, beta, C);
115  }
116 };
117 
119 template<class ViewType1,
120  class ViewType2,
121  class ViewType3,
122  class CoefficientType,
123  class IndexType>
124 struct Gemm<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType,
125  true, false, true> {
126  static void
127  gemm (const char transA,
128  const char transB,
129  const CoefficientType& alpha,
130  const ViewType1& A,
131  const ViewType2& B,
132  const CoefficientType& beta,
133  const ViewType3& C)
134  {
135  ::Tpetra::Details::Blas::Mkl::gemm (transA, transB, alpha, A, B, beta, C);
136  }
137 };
138 
140 template<class ViewType1,
141  class ViewType2,
142  class ViewType3,
143  class CoefficientType,
144  class IndexType>
145 struct Gemm<ViewType1, ViewType2, ViewType3, CoefficientType, IndexType,
146  true, false, false> {
147  static void
148  gemm (const char transA,
149  const char transB,
150  const CoefficientType& alpha,
151  const ViewType1& A,
152  const ViewType2& B,
153  const CoefficientType& beta,
154  const ViewType3& C)
155  {
156  ::Tpetra::Details::Blas::Lib::gemm (transA, transB, alpha, A, B, beta, C);
157  }
158 };
159 
160 } // namespace Impl
161 
162 //
163 // SKIP TO HERE FOR THE ACTUAL INTERFACE
164 //
165 
174 template<class ViewType1,
175  class ViewType2,
176  class ViewType3,
177  class CoefficientType,
178  class IndexType = int>
179 void
180 gemm (const char transA,
181  const char transB,
182  const CoefficientType& alpha,
183  const ViewType1& A,
184  const ViewType2& B,
185  const CoefficientType& beta,
186  const ViewType3& C)
187 {
188  typedef Impl::Gemm<ViewType1, ViewType2, ViewType3,
189  CoefficientType, IndexType> impl_type;
190  impl_type::gemm (transA, transB, alpha, A, B, beta, C);
191 }
192 
193 } // namespace Blas
194 } // namespace Details
195 } // namespace Tpetra
196 
197 #endif // TPETRA_DETAILS_GEMM_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.
Type traits for Tpetra's BLAS wrappers; an implementation detail of Tpetra::MultiVector.
Wrappers for the BLAS library's implementation of _GEMM; implementation detail of Tpetra::MultiVector...
Implementation details of Tpetra.
Implementation detail of Tpetra::MultiVector.
Default implementation of local (but process-global) GEMM (dense matrix-matrix multiply), for Tpetra::MultiVector.
Implementation of ::Tpetra::Details::Blas::gemm.
Implementation detail of Tpetra::MultiVector.