32 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMM_HPP 33 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMM_HPP 36 #include "../../matrix_proxy.hpp" 37 #include "../../vector_expression.hpp" 38 #include <boost/mpl/bool.hpp> 40 namespace shark {
namespace blas {
namespace bindings {
// Single-precision (float) triangular matrix-matrix product binding.
// Forwards order, side, uplo, transA and unit unchanged to cblas_strmm.
//   order  - storage order (CBLAS_ORDER) used for the call
//   side   - whether A multiplies B from the left or the right (CBLAS_SIDE)
//   uplo   - which triangle of A is referenced (CBLAS_UPLO)
//   transA - whether op(A) is A or its transpose (CBLAS_TRANSPOSE)
//   unit   - whether A's diagonal is assumed to be all ones (CBLAS_DIAG)
//   A, lda - triangular matrix and its leading dimension (read only)
//   B      - matrix passed by non-const pointer; CBLAS writes the result into it
// NOTE(review): this chunk is an incomplete extraction. The function name,
// the M/N dimension parameters, the alpha argument and the tail of the
// cblas_strmm call are not visible here; confirm against the full file.
// NOTE(review): `incB` is presumably the leading dimension of B (ldb), not a
// vector increment. Confirm, and consider renaming it to `ldb`.
43 CBLAS_ORDER
const order,
44 CBLAS_SIDE
const side,
45 CBLAS_UPLO
const uplo,
46 CBLAS_TRANSPOSE
const transA,
47 CBLAS_DIAG
const unit,
50 float const *A,
int const lda,
51 float* B,
int const incB
// Delegate directly to the single-precision CBLAS routine.
53 cblas_strmm(order, side, uplo, transA, unit, M, N,
// Double-precision (double) triangular matrix-matrix product binding.
// Forwards order, side, uplo, transA and unit unchanged to cblas_dtrmm.
//   order  - storage order (CBLAS_ORDER) used for the call
//   side   - whether A multiplies B from the left or the right (CBLAS_SIDE)
//   uplo   - which triangle of A is referenced (CBLAS_UPLO)
//   transA - whether op(A) is A or its transpose (CBLAS_TRANSPOSE)
//   unit   - whether A's diagonal is assumed to be all ones (CBLAS_DIAG)
//   A, lda - triangular matrix and its leading dimension (read only)
//   B      - matrix passed by non-const pointer; CBLAS writes the result into it
// NOTE(review): this chunk is an incomplete extraction. The function name,
// the M/N dimension parameters, the alpha argument and the tail of the
// cblas_dtrmm call are not visible here; confirm against the full file.
// NOTE(review): `incB` is presumably the leading dimension of B (ldb), not a
// vector increment. Confirm, and consider renaming it to `ldb`.
61 CBLAS_ORDER
const order,
62 CBLAS_SIDE
const side,
63 CBLAS_UPLO
const uplo,
64 CBLAS_TRANSPOSE
const transA,
65 CBLAS_DIAG
const unit,
68 double const *A,
int const lda,
69 double* B,
int const incB
// Delegate directly to the double-precision CBLAS routine.
71 cblas_dtrmm(order, side, uplo, transA, unit, M, N,
// Single-precision complex (std::complex<float>) triangular matrix-matrix
// product binding. Forwards order, side, uplo, transA and unit unchanged to
// cblas_ctrmm, with the scalar factor alpha fixed to 1.
//   order  - storage order (CBLAS_ORDER) used for the call
//   side   - whether A multiplies B from the left or the right (CBLAS_SIDE)
//   uplo   - which triangle of A is referenced (CBLAS_UPLO)
//   transA - whether op(A) is A, its transpose, or its conjugate transpose
//   unit   - whether A's diagonal is assumed to be all ones (CBLAS_DIAG)
//   A, lda - triangular matrix and its leading dimension (read only)
//   B      - matrix passed by non-const pointer; CBLAS writes the result into it
// NOTE(review): `incB` is passed right after `A, lda`, which is the slot the
// CBLAS trmm interface uses for B's leading dimension (ldb). Consider renaming it.
// NOTE(review): the function name and the M/N parameters are not visible in
// this chunk; confirm against the full file.
80 CBLAS_ORDER
const order,
81 CBLAS_SIDE
const side,
82 CBLAS_UPLO
const uplo,
83 CBLAS_TRANSPOSE
const transA,
84 CBLAS_DIAG
const unit,
87 std::complex<float>
const *A,
int const lda,
88 std::complex<float>* B,
int const incB
// The complex CBLAS routines take alpha by pointer, so it needs an addressable
// local. The value 1 means no extra scaling is applied.
90 std::complex<float> alpha = 1.0;
// std::complex pointers are reinterpreted as the CBLAS complex type.
// NOTE(review): this assumes cblas_float_complex_type (declared elsewhere)
// has the same layout as std::complex<float>; confirm.
91 cblas_ctrmm(order, side, uplo, transA, unit, M, N,
92 reinterpret_cast<cblas_float_complex_type const *>(&alpha),
93 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
94 reinterpret_cast<cblas_float_complex_type *>(B), incB
// Double-precision complex (std::complex<double>) triangular matrix-matrix
// product binding. Forwards order, side, uplo, transA and unit unchanged to
// cblas_ztrmm, with the scalar factor alpha fixed to 1.
//   order  - storage order (CBLAS_ORDER) used for the call
//   side   - whether A multiplies B from the left or the right (CBLAS_SIDE)
//   uplo   - which triangle of A is referenced (CBLAS_UPLO)
//   transA - whether op(A) is A, its transpose, or its conjugate transpose
//   unit   - whether A's diagonal is assumed to be all ones (CBLAS_DIAG)
//   A, lda - triangular matrix and its leading dimension (read only)
//   B      - matrix passed by non-const pointer; CBLAS writes the result into it
// NOTE(review): `incB` is passed right after `A, lda`, which is the slot the
// CBLAS trmm interface uses for B's leading dimension (ldb). Consider renaming it.
// NOTE(review): the function name and the M/N parameters are not visible in
// this chunk; confirm against the full file.
99 CBLAS_ORDER
const order,
100 CBLAS_SIDE
const side,
101 CBLAS_UPLO
const uplo,
102 CBLAS_TRANSPOSE
const transA,
103 CBLAS_DIAG
const unit,
106 std::complex<double>
const *A,
int const lda,
107 std::complex<double>* B,
int const incB
// The complex CBLAS routines take alpha by pointer, so it needs an addressable
// local. The value 1 means no extra scaling is applied.
109 std::complex<double> alpha = 1.0;
// std::complex pointers are reinterpreted as the CBLAS complex type.
// NOTE(review): this assumes cblas_double_complex_type (declared elsewhere)
// has the same layout as std::complex<double>; confirm.
110 cblas_ztrmm(order, side, uplo, transA, unit, M, N,
111 reinterpret_cast<cblas_double_complex_type const *>(&alpha),
112 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
113 reinterpret_cast<cblas_double_complex_type *>(B), incB
// Maps a triangular matrix product with A multiplied from the left
// (CblasLeft below), B := op(A) * B, onto the typed CBLAS trmm overloads.
// Template parameters:
//   upper - true if A is upper triangular, false if lower triangular
//   unit  - true if A's diagonal is taken to be all ones (CblasUnit)
//   MatA, MatB - matrix expression types; their `orientation` selects the
//                CBLAS storage order
// Parameters:
//   A - triangular matrix expression (const, read only)
//   B - matrix expression passed by non-const reference and handed to trmm,
//       presumably updated in place
// NOTE(review): several original lines (parameters/checks between the
// signature and the body, and the M/N/pointer arguments of the trmm call)
// are missing from this chunk; confirm against the full file.
117 template <
bool upper,
bool unit,
typename MatA,
typename MatB>
119 matrix_expression<MatA>
const& A,
120 matrix_expression<MatB>& B,
// n: number of rows of A; m: number of columns of B.
// Presumably these are passed as the M/N dimensions on a line not visible here.
125 std::size_t n = A().size1();
126 std::size_t m = B().size2();
// Translate the compile-time flags into CBLAS enums.
127 CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
128 CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
129 CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
130 CBLAS_TRANSPOSE
trans=CblasNoTrans;
// The CBLAS call below uses B's storage order (stor_ordB). If A is stored in
// the opposite order, reading its buffer in B's order yields A's transpose,
// whose upper and lower triangles are swapped. That is why uplo is flipped.
// NOTE(review): the same reinterpretation requires trans to become CblasTrans.
// That assignment is not visible in this chunk; confirm it exists.
134 CBLAS_ORDER stor_ordB= (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
135 if(stor_ord != stor_ordB){
137 cblasUplo= upper?CblasLower:CblasUpper;
// A always multiplies from the left (CblasLeft).
140 trmm(stor_ordB, CblasLeft, cblasUplo, trans, cblasUnit,
143 traits::leading_dimension(A),
145 traits::leading_dimension(B)
// Compile-time trait selecting whether the CBLAS trmm path applies.
// optimized_trmm_detail<Storage1, Storage2, T1, T2>::type is
// boost::mpl::false_ by default (primary template). It is boost::mpl::true_
// for the specializations below: both operands dense_tag with matching value
// types. These mirror the value types handled by the CBLAS overloads in this file.
149 template<
class Storage1,
class Storage2,
class T1,
class T2>
150 struct optimized_trmm_detail{
151 typedef boost::mpl::false_ type;
// dense x dense specialization. The value-type arguments are not visible in this
// chunk (presumably double, double; matches cblas_dtrmm). Confirm.
154 struct optimized_trmm_detail<
155 dense_tag, dense_tag,
158 typedef boost::mpl::true_ type;
// dense x dense specialization. The value-type arguments are not visible in this
// chunk (presumably float, float; matches cblas_strmm). Confirm.
161 struct optimized_trmm_detail<
162 dense_tag, dense_tag,
165 typedef boost::mpl::true_ type;
// dense x dense, std::complex<double> (handled by the cblas_ztrmm overload).
169 struct optimized_trmm_detail<
170 dense_tag, dense_tag,
171 std::complex<double>, std::complex<double>
173 typedef boost::mpl::true_ type;
// dense x dense, std::complex<float> (handled by the cblas_ctrmm overload).
176 struct optimized_trmm_detail<
177 dense_tag, dense_tag,
178 std::complex<float>, std::complex<float>
180 typedef boost::mpl::true_ type;
183 template<
class M1,
class M2>
184 struct has_optimized_trmm
185 :
public optimized_trmm_detail<
186 typename M1::storage_category,
187 typename M2::storage_category,
188 typename M1::value_type,
189 typename M2::value_type