trmm.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief CBLAS bindings for the triangular matrix-matrix product (trmm) kernel.
6  *
7  * \author O. Krause
8  * \date 2010
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMM_HPP
33 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMM_HPP
34 
35 #include "cblas_inc.hpp"
36 #include "../../matrix_proxy.hpp"
37 #include "../../vector_expression.hpp"
38 #include <boost/mpl/bool.hpp>
39 
40 namespace shark {namespace blas {namespace bindings {
41 
42 inline void trmm(
43  CBLAS_ORDER const order,
44  CBLAS_SIDE const side,
45  CBLAS_UPLO const uplo,
46  CBLAS_TRANSPOSE const transA,
47  CBLAS_DIAG const unit,
48  int const M,
49  int const N,
50  float const *A, int const lda,
51  float* B, int const incB
52 ) {
53  cblas_strmm(order, side, uplo, transA, unit, M, N,
54  1.0,
55  A, lda,
56  B, incB
57  );
58 }
59 
60 inline void trmm(
61  CBLAS_ORDER const order,
62  CBLAS_SIDE const side,
63  CBLAS_UPLO const uplo,
64  CBLAS_TRANSPOSE const transA,
65  CBLAS_DIAG const unit,
66  int const M,
67  int const N,
68  double const *A, int const lda,
69  double* B, int const incB
70 ) {
71  cblas_dtrmm(order, side, uplo, transA, unit, M, N,
72  1.0,
73  A, lda,
74  B, incB
75  );
76 }
77 
78 
79 inline void trmm(
80  CBLAS_ORDER const order,
81  CBLAS_SIDE const side,
82  CBLAS_UPLO const uplo,
83  CBLAS_TRANSPOSE const transA,
84  CBLAS_DIAG const unit,
85  int const M,
86  int const N,
87  std::complex<float> const *A, int const lda,
88  std::complex<float>* B, int const incB
89 ) {
90  std::complex<float> alpha = 1.0;
91  cblas_ctrmm(order, side, uplo, transA, unit, M, N,
92  reinterpret_cast<cblas_float_complex_type const *>(&alpha),
93  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
94  reinterpret_cast<cblas_float_complex_type *>(B), incB
95  );
96 }
97 
98 inline void trmm(
99  CBLAS_ORDER const order,
100  CBLAS_SIDE const side,
101  CBLAS_UPLO const uplo,
102  CBLAS_TRANSPOSE const transA,
103  CBLAS_DIAG const unit,
104  int const M,
105  int const N,
106  std::complex<double> const *A, int const lda,
107  std::complex<double>* B, int const incB
108 ) {
109  std::complex<double> alpha = 1.0;
110  cblas_ztrmm(order, side, uplo, transA, unit, M, N,
111  reinterpret_cast<cblas_double_complex_type const *>(&alpha),
112  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
113  reinterpret_cast<cblas_double_complex_type *>(B), incB
114  );
115 }
116 
117 template <bool upper, bool unit, typename MatA, typename MatB>
118 void trmm(
119  matrix_expression<MatA> const& A,
120  matrix_expression<MatB>& B,
121  boost::mpl::true_
122 ){
123  SIZE_CHECK(A().size1() == A().size2());
124  SIZE_CHECK(A().size2() == B().size1());
125  std::size_t n = A().size1();
126  std::size_t m = B().size2();
127  CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
128  CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
129  CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
130  CBLAS_TRANSPOSE trans=CblasNoTrans;
131 
132  //special case: MatA and MatB do not have same storage order. in this case compute as
133  //AB->B^TA^T where transpose of B is done implicitely by exchanging storage order
134  CBLAS_ORDER stor_ordB= (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
135  if(stor_ord != stor_ordB){
136  trans = CblasTrans;
137  cblasUplo= upper?CblasLower:CblasUpper;
138  }
139 
140  trmm(stor_ordB, CblasLeft, cblasUplo, trans, cblasUnit,
141  (int)n, int(m),
142  traits::storage(A),
143  traits::leading_dimension(A),
144  traits::storage(B),
145  traits::leading_dimension(B)
146  );
147 }
148 
149 template<class Storage1, class Storage2, class T1, class T2>
150 struct optimized_trmm_detail{
151  typedef boost::mpl::false_ type;
152 };
153 template<>
154 struct optimized_trmm_detail<
155  dense_tag, dense_tag,
156  double, double
157 >{
158  typedef boost::mpl::true_ type;
159 };
160 template<>
161 struct optimized_trmm_detail<
162  dense_tag, dense_tag,
163  float, float
164 >{
165  typedef boost::mpl::true_ type;
166 };
167 
168 template<>
169 struct optimized_trmm_detail<
170  dense_tag, dense_tag,
171  std::complex<double>, std::complex<double>
172 >{
173  typedef boost::mpl::true_ type;
174 };
175 template<>
176 struct optimized_trmm_detail<
177  dense_tag, dense_tag,
178  std::complex<float>, std::complex<float>
179 >{
180  typedef boost::mpl::true_ type;
181 };
182 
183 template<class M1, class M2>
184 struct has_optimized_trmm
185 : public optimized_trmm_detail<
186  typename M1::storage_category,
187  typename M2::storage_category,
188  typename M1::value_type,
189  typename M2::value_type
190 >{};
191 
192 }}}
193 #endif