gemm.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
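5  * \brief Overloads binding the dense matrix-matrix product (gemm) to the CBLAS routines cblas_sgemm, cblas_dgemm, cblas_cgemm and cblas_zgemm.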
6  *
7  * \author O. Krause
8  * \date 2010
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_GEMM_HPP
33 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_GEMM_HPP
34 
35 #include "cblas_inc.hpp"
36 
37 namespace shark { namespace blas { namespace bindings {
38 
39 inline void gemm(
40  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
41  int M, int N, int K,
42  float alpha, float const *A, int lda,
43  float const *B, int ldb,
44  float beta, float *C, int ldc
45 ){
46  cblas_sgemm(
47  Order, TransA, TransB,
48  M, N, K,
49  alpha, A, lda,
50  B, ldb,
51  beta, C, ldc
52  );
53 }
54 
55 inline void gemm(
56  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
57  int M, int N, int K,
58  double alpha, double const *A, int lda,
59  double const *B, int ldb,
60  double beta, double *C, int ldc
61 ){
62  cblas_dgemm(
63  Order, TransA, TransB,
64  M, N, K,
65  alpha,
66  A, lda,
67  B, ldb,
68  beta,
69  C, ldc
70  );
71 }
72 
73 inline void gemm(
74  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
75  int M, int N, int K,
76  float alpha,
77  std::complex<float> const *A, int lda,
78  std::complex<float> const *B, int ldb,
79  float beta,
80  std::complex<float>* C, int ldc
81 ) {
82  std::complex<float> alphaArg(alpha,0);
83  std::complex<float> betaArg(beta,0);
84  cblas_cgemm(
85  Order, TransA, TransB,
86  M, N, K,
87  reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
88  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
89  reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
90  reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
91  reinterpret_cast<cblas_float_complex_type *>(C), ldc
92  );
93 }
94 
95 inline void gemm(
96  CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
97  int M, int N, int K,
98  double alpha,
99  std::complex<double> const *A, int lda,
100  std::complex<double> const *B, int ldb,
101  double beta,
102  std::complex<double>* C, int ldc
103 ) {
104  std::complex<double> alphaArg(alpha,0);
105  std::complex<double> betaArg(beta,0);
106  cblas_zgemm(
107  Order, TransA, TransB,
108  M, N, K,
109  reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
110  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
111  reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
112  reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
113  reinterpret_cast<cblas_double_complex_type *>(C), ldc
114  );
115 }
116 
117 // C <- alpha * A * B + beta * C
118 template <typename MatrA, typename MatrB, typename MatrC>
119 void gemm(
120  matrix_expression<MatrA> const &matA,
121  matrix_expression<MatrB> const &matB,
122  matrix_expression<MatrC>& matC,
123  typename MatrC::value_type alpha,
124  boost::mpl::true_
125 ) {
126  SIZE_CHECK(matA().size1() == matC().size1());
127  SIZE_CHECK(matB().size2() == matC().size2());
128  SIZE_CHECK(matA().size2()== matB().size1());
129 
130  CBLAS_TRANSPOSE transA = traits::same_orientation(matA,matC)?CblasNoTrans:CblasTrans;
131  CBLAS_TRANSPOSE transB = traits::same_orientation(matB,matC)?CblasNoTrans:CblasTrans;
132  std::size_t m = matC().size1();
133  std::size_t n = matC().size2();
134  std::size_t k = matA().size2();
135  CBLAS_ORDER stor_ord = (CBLAS_ORDER) storage_order<typename MatrC::orientation >::value;
136 
137  gemm(stor_ord, transA, transB, (int)m, (int)n, (int)k, alpha,
138  traits::storage(matA()),
139  traits::leading_dimension(matA()),
140  traits::storage(matB()),
141  traits::leading_dimension(matB()),
142  typename MatrC::value_type(1),
143  traits::storage(matC()),
144  traits::leading_dimension(matC())
145  );
146 }
147 
148 
149 template<class Storage1, class Storage2, class Storage3, class T1, class T2, class T3>
150 struct optimized_gemm_detail{
151  typedef boost::mpl::false_ type;
152 };
153 template<>
154 struct optimized_gemm_detail<
155  dense_tag, dense_tag, dense_tag,
156  double, double, double
157 >{
158  typedef boost::mpl::true_ type;
159 };
160 template<>
161 struct optimized_gemm_detail<
162  dense_tag, dense_tag, dense_tag,
163  float, float, float
164 >{
165  typedef boost::mpl::true_ type;
166 };
167 
168 template<>
169 struct optimized_gemm_detail<
170  dense_tag, dense_tag, dense_tag,
171  std::complex<double>, std::complex<double>, std::complex<double>
172 >{
173  typedef boost::mpl::true_ type;
174 };
175 template<>
176 struct optimized_gemm_detail<
177  dense_tag, dense_tag, dense_tag,
178  std::complex<float>, std::complex<float>, std::complex<float>
179 >{
180  typedef boost::mpl::true_ type;
181 };
182 
183 template<class M1, class M2, class M3>
184 struct has_optimized_gemm
185 : public optimized_gemm_detail<
186  typename M1::storage_category,
187  typename M2::storage_category,
188  typename M3::storage_category,
189  typename M1::value_type,
190  typename M2::value_type,
191  typename M3::value_type
192 >{};
193 
194 }}}
195 
196 #endif