32 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_GEMM_HPP 33 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_GEMM_HPP 37 namespace shark {
namespace blas {
namespace bindings {
// Fragment of the single-precision (float) gemm overload.
// Visible parameters: the CBLAS storage order, the two transpose flags, the
// scalars alpha and beta, read-only operand pointers A and B with lda/ldb,
// and the writable result pointer C with ldc.
// NOTE(review): the leading numbers look like line numbers from a source
// listing, and the gaps between them (e.g. 41, 45-46) suggest that the
// function head, further parameters and the callee name are missing from
// this excerpt — restore them from the original header before building.
40 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
42 float alpha,
float const *A,
int lda,
43 float const *B,
int ldb,
44 float beta,
float *C,
int ldc
// Start of the forwarded argument list; the function being called is not
// visible in this excerpt.
47 Order, TransA, TransB,
// Fragment of the double-precision gemm overload; it has the same parameter
// layout as the float fragment above it, with double scalars and pointers.
// A and B are read-only, C is writable.
// NOTE(review): the function head, any parameters between the transpose
// flags and alpha, and the callee name are not visible here (the embedded
// listing numbers skip 57 and 61-62) — restore from the original header.
56 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
58 double alpha,
double const *A,
int lda,
59 double const *B,
int ldb,
60 double beta,
double *C,
int ldc
// Start of the forwarded argument list; the callee is not visible here.
63 Order, TransA, TransB,
// Fragment of the std::complex<float> gemm overload.
// A and B are read-only complex arrays, C is the writable complex result.
// NOTE(review): the declarations of alpha and beta, the function head and
// the callee name are not visible in this excerpt (the embedded listing
// numbers skip 75-76, 79, 81, 84 and 86) — restore from the original header.
74 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
77 std::complex<float>
const *A,
int lda,
78 std::complex<float>
const *B,
int ldb,
80 std::complex<float>* C,
int ldc
// alpha and beta are promoted to complex values with a zero imaginary part
// so that they can be handed over by address below.
82 std::complex<float> alphaArg(alpha,0);
83 std::complex<float> betaArg(beta,0);
85 Order, TransA, TransB,
// The scalars and the matrix storage are passed as pointers reinterpreted
// to cblas_float_complex_type.
// NOTE(review): this presumes std::complex<float> and
// cblas_float_complex_type are layout-compatible — confirm against the
// typedef, which is not visible here.
87 reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
88 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
89 reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
90 reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
91 reinterpret_cast<cblas_float_complex_type *>(C), ldc
// Fragment of the std::complex<double> gemm overload.
// A and B are read-only complex arrays, C is the writable complex result.
// NOTE(review): the declarations of alpha and beta, the function head and
// the callee name are not visible in this excerpt (the embedded listing
// numbers skip 97-98, 101, 103, 106 and 108) — restore from the original
// header.
96 CBLAS_ORDER
const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
99 std::complex<double>
const *A,
int lda,
100 std::complex<double>
const *B,
int ldb,
102 std::complex<double>* C,
int ldc
// alpha and beta are promoted to complex values with a zero imaginary part
// so that they can be handed over by address below.
104 std::complex<double> alphaArg(alpha,0);
105 std::complex<double> betaArg(beta,0);
107 Order, TransA, TransB,
// The scalars and the matrix storage are passed as pointers reinterpreted
// to cblas_double_complex_type.
// NOTE(review): this presumes std::complex<double> and
// cblas_double_complex_type are layout-compatible — confirm against the
// typedef, which is not visible here.
109 reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
110 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
111 reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
112 reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
113 reinterpret_cast<cblas_double_complex_type *>(C), ldc
// Fragment of the generic wrapper that takes matrix expressions (matA, matB
// read-only; matC writable) plus a scalar alpha of matC's value type, derives
// the raw gemm arguments from them and forwards to a gemm overload.
// NOTE(review): the function name/return type, any parameters after alpha
// and the opening of the body are not visible in this excerpt (the embedded
// listing numbers skip 119 and 124-129) — restore from the original header.
118 template <
typename MatrA,
typename MatrB,
typename MatrC>
120 matrix_expression<MatrA>
const &matA,
121 matrix_expression<MatrB>
const &matB,
122 matrix_expression<MatrC>& matC,
123 typename MatrC::value_type alpha,
// An operand is passed as "no transpose" when traits::same_orientation
// reports that it matches matC, and as "transpose" otherwise.
130 CBLAS_TRANSPOSE transA = traits::same_orientation(matA,matC)?CblasNoTrans:CblasTrans;
131 CBLAS_TRANSPOSE transB = traits::same_orientation(matB,matC)?CblasNoTrans:CblasTrans;
// m and n are read from the result matrix; k is read from matA's second
// dimension.
132 std::size_t m = matC().size1();
133 std::size_t n = matC().size2();
134 std::size_t k = matA().size2();
// The storage order handed to gemm is taken from matC's orientation tag.
135 CBLAS_ORDER stor_ord = (CBLAS_ORDER) storage_order<typename MatrC::orientation >::value;
// The size_t dimensions are narrowed to int for the call.
// NOTE(review): values above INT_MAX would not survive these casts.
137 gemm(stor_ord, transA, transB, (
int)m, (
int)n, (
int)k, alpha,
138 traits::storage(matA()),
139 traits::leading_dimension(matA()),
140 traits::storage(matB()),
141 traits::leading_dimension(matB()),
// The constant 1 sits in the slot that the scalar overloads above name
// `beta`; presumably this makes the call accumulate into matC rather than
// overwrite it — confirm against the CBLAS gemm contract.
142 typename MatrC::value_type(1),
143 traits::storage(matC()),
144 traits::leading_dimension(matC())
// Primary template of the trait keyed on three storage tags and three value
// types. Its nested `type` is boost::mpl::false_ for every combination that
// no specialization below overrides.
// NOTE(review): the closing of the struct is not visible in this excerpt.
149 template<
class Storage1,
class Storage2,
class Storage3,
class T1,
class T2,
class T3>
150 struct optimized_gemm_detail{
151 typedef boost::mpl::false_ type;
// Specialization for three dense operands that all have value type double:
// `type` is boost::mpl::true_.
// NOTE(review): the `template<>` head and the closing tokens are not visible
// in this excerpt.
154 struct optimized_gemm_detail<
155 dense_tag, dense_tag, dense_tag,
156 double, double, double
158 typedef boost::mpl::true_ type;
// Specialization for three dense operands: `type` is boost::mpl::true_.
// NOTE(review): the value-type arguments are missing from this excerpt (the
// embedded listing numbers skip 163-164). Presumably this is the
// float/float/float case, since a float gemm overload exists above and the
// other three element types have their own specializations — confirm
// against the original header.
161 struct optimized_gemm_detail<
162 dense_tag, dense_tag, dense_tag,
165 typedef boost::mpl::true_ type;
// Specialization for three dense operands that all have value type
// std::complex<double>: `type` is boost::mpl::true_.
// NOTE(review): the `template<>` head and the closing tokens are not visible
// in this excerpt.
169 struct optimized_gemm_detail<
170 dense_tag, dense_tag, dense_tag,
171 std::complex<double>, std::complex<double>, std::complex<double>
173 typedef boost::mpl::true_ type;
// Specialization for three dense operands that all have value type
// std::complex<float>: `type` is boost::mpl::true_.
// NOTE(review): the `template<>` head and the closing tokens are not visible
// in this excerpt.
176 struct optimized_gemm_detail<
177 dense_tag, dense_tag, dense_tag,
178 std::complex<float>, std::complex<float>, std::complex<float>
180 typedef boost::mpl::true_ type;
183 template<
class M1,
class M2,
class M3>
184 struct has_optimized_gemm
185 :
public optimized_gemm_detail<
186 typename M1::storage_category,
187 typename M2::storage_category,
188 typename M3::storage_category,
189 typename M1::value_type,
190 typename M2::value_type,
191 typename M3::value_type