trsm.hpp
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Bindings that forward triangular solves (trsm) to the CBLAS routines.
5  *
6  * \author O. Krause
7  * \date 2011
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 
31 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TRSM_HPP
32 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_TRSM_HPP
33 
34 #include "cblas_inc.hpp"
35 
36 ///solves systems of triangular matrices
37 
38 namespace shark {namespace blas {namespace bindings {
39 inline void trsm(
40  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
41  CBLAS_SIDE side, CBLAS_DIAG unit,
42  int n, int nRHS,
43  float const *A, int lda, float *B, int ldb
44 ) {
45  cblas_strsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
46 }
47 
48 inline void trsm(
49  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
50  CBLAS_SIDE side, CBLAS_DIAG unit,
51  int n, int nRHS,
52  double const *A, int lda, double *B, int ldb
53 ) {
54  cblas_dtrsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
55 }
56 
57 inline void trsm(
58  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
59  CBLAS_SIDE side, CBLAS_DIAG unit,
60  int n, int nRHS,
61  std::complex<float> const *A, int lda, std::complex<float> *B, int ldb
62 ) {
63  std::complex<float> alpha(1.0,0);
64  cblas_ctrsm(order, side, uplo, transA, unit,n, nRHS,
65  reinterpret_cast<cblas_float_complex_type const *>(&alpha),
66  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
67  reinterpret_cast<cblas_float_complex_type *>(B), ldb);
68 }
69 inline void trsm(
70  CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
71  CBLAS_SIDE side, CBLAS_DIAG unit,
72  int n, int nRHS,
73  std::complex<double> const *A, int lda, std::complex<double> *B, int ldb
74 ) {
75  std::complex<double> alpha(1.0,0);
76  cblas_ztrsm(order, side, uplo, transA, unit,n, nRHS,
77  reinterpret_cast<cblas_double_complex_type const *>(&alpha),
78  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
79  reinterpret_cast<cblas_double_complex_type *>(B), ldb);
80 }
81 
82 // trsm(): solves A system of linear equations A * X = B
83 // when A is A triangular matrix
84 template <bool upper, bool unit,typename TriangularA, typename MatB>
85 void trsm(
86  matrix_expression<TriangularA> const &A,
87  matrix_expression<MatB> &B,
88  boost::mpl::true_
89 ){
90  SIZE_CHECK(A().size1() == A().size2());
91  SIZE_CHECK(A().size1() == B().size1());
92 
93  //orientation is defined by the second argument
94  CBLAS_ORDER const storOrd = (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
95  //if orientations do not match, wecan interpret this as transposing A
96  bool transposeA = !traits::same_orientation(A,B);
97 
98  CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
99  CBLAS_UPLO cblasUplo = (upper != transposeA)?CblasUpper:CblasLower;
100 
101 
102  CBLAS_TRANSPOSE transA = transposeA?CblasTrans:CblasNoTrans;
103 
104  int m = B().size1();
105  int nrhs = B().size2();
106 
107  trsm(storOrd, cblasUplo, transA, CblasLeft,cblasUnit, m, nrhs,
108  traits::storage(A),
109  traits::leading_dimension(A),
110  traits::storage(B),
111  traits::leading_dimension(B)
112  );
113 }
114 
115 template<class Storage1, class Storage2, class T1, class T2>
116 struct optimized_trsm_detail{
117  typedef boost::mpl::false_ type;
118 };
119 template<>
120 struct optimized_trsm_detail<
121  dense_tag, dense_tag,
122  double, double
123 >{
124  typedef boost::mpl::true_ type;
125 };
126 template<>
127 struct optimized_trsm_detail<
128  dense_tag, dense_tag,
129  float, float
130 >{
131  typedef boost::mpl::true_ type;
132 };
133 
134 template<>
135 struct optimized_trsm_detail<
136  dense_tag, dense_tag,
137  std::complex<double>, std::complex<double>
138 >{
139  typedef boost::mpl::true_ type;
140 };
141 template<>
142 struct optimized_trsm_detail<
143  dense_tag, dense_tag,
144  std::complex<float>, std::complex<float>
145 >{
146  typedef boost::mpl::true_ type;
147 };
148 
149 template<class M1, class M2>
150 struct has_optimized_trsm
151 : public optimized_trsm_detail<
152  typename M1::storage_category,
153  typename M2::storage_category,
154  typename M1::value_type,
155  typename M2::value_type
156 >{};
157 
158 }}}
159 #endif