trsv.hpp
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief CBLAS bindings for the triangular solver routines (trsv)
5  *
6  * \author O. Krause
7  * \date 2011
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 
31 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TRSV_HPP
32 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_TRSV_HPP
33 
34 #include "cblas_inc.hpp"
35 
36 /// Solves triangular systems of linear equations by forwarding to the CBLAS trsv routines.
37 
38 namespace shark {namespace blas{ namespace bindings {
39 inline void trsv(
40  CBLAS_ORDER order, CBLAS_UPLO uplo,
41  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
42  int n,
43  float const *A, int lda, float *b, int strideX
44 ){
45  cblas_strsv(order, uplo, transA, unit,n, A, lda, b, strideX);
46 }
47 
48 inline void trsv(
49  CBLAS_ORDER order, CBLAS_UPLO uplo,
50  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
51  int n,
52  double const *A, int lda, double *b, int strideX
53 ){
54  cblas_dtrsv(order, uplo, transA, unit,n, A, lda, b, strideX);
55 }
56 
57 inline void trsv(
58  CBLAS_ORDER order, CBLAS_UPLO uplo,
59  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
60  int n,
61  std::complex<float> const *A, int lda, std::complex<float> *b, int strideX
62 ){
63  cblas_ctrsv(order, uplo, transA, unit,n,
64  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
65  reinterpret_cast<cblas_float_complex_type *>(b), strideX);
66 }
67 inline void trsv(
68  CBLAS_ORDER order, CBLAS_UPLO uplo,
69  CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
70  int n,
71  std::complex<double> const *A, int lda, std::complex<double> *b, int strideX
72 ){
73  cblas_ztrsv(order, uplo, transA, unit,n,
74  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
75  reinterpret_cast<cblas_double_complex_type *>(b), strideX);
76 }
77 
78 // trsv(): solves A system of linear equations A * x = b
79 // when A is A triangular matrix.
80 template <bool Upper,bool Unit,typename TriangularA, typename V>
81 void trsv(
82  matrix_expression<TriangularA> const &A,
83  vector_expression<V> &b,
84  boost::mpl::true_
85 ){
86  SIZE_CHECK(A().size1() == A().size2());
87  SIZE_CHECK(A().size1()== b().size());
88  CBLAS_DIAG cblasUnit = Unit?CblasUnit:CblasNonUnit;
89  CBLAS_ORDER const storOrd= (CBLAS_ORDER)storage_order<typename TriangularA::orientation>::value;
90  CBLAS_UPLO uplo = Upper?CblasUpper:CblasLower;
91 
92 
93  int const n = A().size1();
94 
95  trsv(storOrd, uplo, CblasNoTrans,cblasUnit, n,
96  traits::storage(A),
97  traits::leading_dimension(A),
98  traits::storage(b),
99  traits::stride(b)
100  );
101 }
102 
103 template<class Storage1, class Storage2, class T1, class T2>
104 struct optimized_trsv_detail{
105  typedef boost::mpl::false_ type;
106 };
107 template<>
108 struct optimized_trsv_detail<
109  dense_tag, dense_tag,
110  double, double
111 >{
112  typedef boost::mpl::true_ type;
113 };
114 template<>
115 struct optimized_trsv_detail<
116  dense_tag, dense_tag,
117  float, float
118 >{
119  typedef boost::mpl::true_ type;
120 };
121 
122 template<>
123 struct optimized_trsv_detail<
124  dense_tag, dense_tag,
125  std::complex<double>, std::complex<double>
126 >{
127  typedef boost::mpl::true_ type;
128 };
129 template<>
130 struct optimized_trsv_detail<
131  dense_tag, dense_tag,
132  std::complex<float>, std::complex<float>
133 >{
134  typedef boost::mpl::true_ type;
135 };
136 
137 template<class M, class V>
138 struct has_optimized_trsv
139 : public optimized_trsv_detail<
140  typename M::storage_category,
141  typename V::storage_category,
142  typename M::value_type,
143  typename V::value_type
144 >{};
145 
146 }}}
147 #endif