Tpetra parallel linear algebra — Version of the Day
Tpetra_Details_cublasGemm.cpp
1 /*
2 //@HEADER
3 // ************************************************************************
4 //
5 // Kokkos: Node API and Parallel Node Kernels
6 // Copyright (2008) Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39 //
40 // ************************************************************************
41 //@HEADER
42 */
43 
45 #include "Kokkos_Macros.hpp"
46 #ifdef KOKKOS_ENABLE_CUDA
47 # include <cublas.h>
48 #endif // KOKKOS_ENABLE_CUDA
49 #include <sstream>
50 #include <stdexcept>
51 
52 namespace Tpetra {
53 namespace Details {
54 namespace Blas {
55 namespace Cublas {
56 namespace Impl {
57 
58 void
59 cgemm (const char transA,
60  const char transB,
61  const int m,
62  const int n,
63  const int k,
64  const Kokkos::complex<float>& alpha,
65  const Kokkos::complex<float> A[],
66  const int lda,
67  const Kokkos::complex<float> B[],
68  const int ldb,
69  const Kokkos::complex<float>& beta,
70  Kokkos::complex<float> C[],
71  const int ldc)
72 {
73 #ifdef KOKKOS_ENABLE_CUDA
74  const cuComplex alpha_c = make_cuFloatComplex (alpha.real (), alpha.imag ());
75  const cuComplex beta_c = make_cuFloatComplex (beta.real (), beta.imag ());
76  cublasCgemm (transA, transB,
77  m, n, k,
78  alpha_c, reinterpret_cast<const cuComplex*> (A), lda,
79  reinterpret_cast<const cuComplex*> (B), ldb,
80  beta_c, reinterpret_cast<cuComplex*> (C), ldc);
81  cublasStatus info = cublasGetError ();
82  if (info != CUBLAS_STATUS_SUCCESS) {
83  std::ostringstream err;
84  err << "cublasCgemm failed with status " << info << ".";
85  throw std::runtime_error (err.str ());
86  }
87 #else // NOT KOKKOS_ENABLE_CUDA
88  throw std::runtime_error ("You must enable CUDA in your Trilinos build in "
89  "order to invoke cuBLAS functions in Tpetra.");
90 #endif // KOKKOS_ENABLE_CUDA
91 }
92 
93 void
94 dgemm (const char char_transA,
95  const char char_transB,
96  const int m,
97  const int n,
98  const int k,
99  const double alpha,
100  const double A[],
101  const int lda,
102  const double B[],
103  const int ldb,
104  const double beta,
105  double C[],
106  const int ldc)
107 {
108 #ifdef KOKKOS_ENABLE_CUDA
109  ::cublasDgemm (char_transA, char_transB, m, n, k,
110  alpha, A, lda, B, ldb, beta, C, ldc);
111  cublasStatus info = cublasGetError ();
112  if (info != CUBLAS_STATUS_SUCCESS) {
113  std::ostringstream err;
114  err << "cublasDgemm failed with status " << info << ".";
115  throw std::runtime_error (err.str ());
116  }
117 #else // NOT KOKKOS_ENABLE_CUDA
118  throw std::runtime_error ("You must enable CUDA in your Trilinos build in "
119  "order to invoke cuBLAS functions in Tpetra.");
120 #endif // KOKKOS_ENABLE_CUDA
121 }
122 
123 void
124 sgemm (const char char_transA,
125  const char char_transB,
126  const int m,
127  const int n,
128  const int k,
129  const float alpha,
130  const float A[],
131  const int lda,
132  const float B[],
133  const int ldb,
134  const float beta,
135  float C[],
136  const int ldc)
137 {
138 #ifdef KOKKOS_ENABLE_CUDA
139  ::cublasSgemm (char_transA, char_transB, m, n, k,
140  alpha, A, lda, B, ldb, beta, C, ldc);
141  cublasStatus info = cublasGetError ();
142  if (info != CUBLAS_STATUS_SUCCESS) {
143  std::ostringstream err;
144  err << "cublasSgemm failed with status " << info << ".";
145  throw std::runtime_error (err.str ());
146  }
147 #else // NOT KOKKOS_ENABLE_CUDA
148  throw std::runtime_error ("You must enable CUDA in your Trilinos build in "
149  "order to invoke cuBLAS functions in Tpetra.");
150 #endif // KOKKOS_ENABLE_CUDA
151 }
152 
153 void
154 zgemm (const char transA,
155  const char transB,
156  const int m,
157  const int n,
158  const int k,
159  const Kokkos::complex<double>& alpha,
160  const Kokkos::complex<double> A[],
161  const int lda,
162  const Kokkos::complex<double> B[],
163  const int ldb,
164  const Kokkos::complex<double>& beta,
165  Kokkos::complex<double> C[],
166  const int ldc)
167 {
168 #ifdef KOKKOS_ENABLE_CUDA
169  const cuDoubleComplex alpha_c =
170  make_cuDoubleComplex (alpha.real (), alpha.imag ());
171  const cuDoubleComplex beta_c =
172  make_cuDoubleComplex (beta.real (), beta.imag ());
173  cublasZgemm (transA, transB,
174  m, n, k,
175  alpha_c, reinterpret_cast<const cuDoubleComplex*> (A), lda,
176  reinterpret_cast<const cuDoubleComplex*> (B), ldb,
177  beta_c, reinterpret_cast<cuDoubleComplex*> (C), ldc);
178  cublasStatus info = cublasGetError ();
179  if (info != CUBLAS_STATUS_SUCCESS) {
180  std::ostringstream err;
181  err << "cublasCgemm failed with status " << info << ".";
182  throw std::runtime_error (err.str ());
183  }
184 #else // NOT KOKKOS_ENABLE_CUDA
185  throw std::runtime_error ("You must enable CUDA in your Trilinos build in "
186  "order to invoke cuBLAS functions in Tpetra.");
187 #endif // KOKKOS_ENABLE_CUDA
188 }
189 
190 } // namespace Impl
191 } // namespace Cublas
192 } // namespace Blas
193 } // namespace Details
194 } // namespace Tpetra
Namespace Tpetra contains the class and methods constituting the Tpetra library.
Implementation details of Tpetra.
Implementation detail of Tpetra::MultiVector.