trmv.hpp
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief -
6  *
7  * \author O. Krause
8  * \date 2010
9  *
10  *
11  * \par Copyright 1995-2015 Shark Development Team
12  *
13  * <BR><HR>
14  * This file is part of Shark.
15  * <http://image.diku.dk/shark/>
16  *
17  * Shark is free software: you can redistribute it and/or modify
18  * it under the terms of the GNU Lesser General Public License as published
19  * by the Free Software Foundation, either version 3 of the License, or
20  * (at your option) any later version.
21  *
22  * Shark is distributed in the hope that it will be useful,
23  * but WITHOUT ANY WARRANTY; without even the implied warranty of
24  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25  * GNU Lesser General Public License for more details.
26  *
27  * You should have received a copy of the GNU Lesser General Public License
28  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29  *
30  */
31 //===========================================================================
32 #ifndef SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMV_HPP
33 #define SHARK_LINALG_BLAS_KERNELS_CBLAS_TRMV_HPP
34 
35 #include "cblas_inc.hpp"
36 #include "../../matrix_proxy.hpp"
37 #include "../../vector_expression.hpp"
38 #include <boost/mpl/bool.hpp>
39 
40 namespace shark {namespace blas {namespace bindings {
41 
42 inline void trmv(
43  CBLAS_ORDER const Order,
44  CBLAS_UPLO const uplo,
45  CBLAS_TRANSPOSE const transA,
46  CBLAS_DIAG const unit,
47  int const N,
48  float const *A, int const lda,
49  float* X, int const incX
50 ) {
51  cblas_strmv(Order, uplo, transA, unit, N,
52  A, lda,
53  X, incX
54  );
55 }
56 
57 inline void trmv(
58  CBLAS_ORDER const Order,
59  CBLAS_UPLO const uplo,
60  CBLAS_TRANSPOSE const transA,
61  CBLAS_DIAG const unit,
62  int const N,
63  double const *A, int const lda,
64  double* X, int const incX
65 ) {
66  cblas_dtrmv(Order, uplo, transA, unit, N,
67  A, lda,
68  X, incX
69  );
70 }
71 
72 
73 inline void trmv(
74  CBLAS_ORDER const Order,
75  CBLAS_UPLO const uplo,
76  CBLAS_TRANSPOSE const transA,
77  CBLAS_DIAG const unit,
78  int const N,
79  std::complex<float> const *A, int const lda,
80  std::complex<float>* X, int const incX
81 ) {
82  cblas_ctrmv(Order, uplo, transA, unit, N,
83  reinterpret_cast<cblas_float_complex_type const *>(A), lda,
84  reinterpret_cast<cblas_float_complex_type *>(X), incX
85  );
86 }
87 
88 inline void trmv(
89  CBLAS_ORDER const Order,
90  CBLAS_UPLO const uplo,
91  CBLAS_TRANSPOSE const transA,
92  CBLAS_DIAG const unit,
93  int const N,
94  std::complex<double> const *A, int const lda,
95  std::complex<double>* X, int const incX
96 ) {
97  cblas_ztrmv(Order, uplo, transA, unit, N,
98  reinterpret_cast<cblas_double_complex_type const *>(A), lda,
99  reinterpret_cast<cblas_double_complex_type *>(X), incX
100  );
101 }
102 
103 template <bool upper, bool unit, typename MatrA, typename VectorX>
104 void trmv(
105  matrix_expression<MatrA> const& A,
106  vector_expression<VectorX> &x,
107  boost::mpl::true_
108 ){
109  SIZE_CHECK(x().size() == A().size2());
110  SIZE_CHECK(A().size2() == A().size1());
111  std::size_t n = A().size1();
112  CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
113  CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
114  CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatrA::orientation>::value;
115 
116  trmv(stor_ord, cblasUplo, CblasNoTrans, cblasUnit, (int)n,
117  traits::storage(A),
118  traits::leading_dimension(A),
119  traits::storage(x),
120  traits::stride(x)
121  );
122 }
123 
124 template<class Storage1, class Storage2, class T1, class T2>
125 struct optimized_trmv_detail{
126  typedef boost::mpl::false_ type;
127 };
128 template<>
129 struct optimized_trmv_detail<
130  dense_tag, dense_tag,
131  double, double
132 >{
133  typedef boost::mpl::true_ type;
134 };
135 template<>
136 struct optimized_trmv_detail<
137  dense_tag, dense_tag,
138  float, float
139 >{
140  typedef boost::mpl::true_ type;
141 };
142 
143 template<>
144 struct optimized_trmv_detail<
145  dense_tag, dense_tag,
146  std::complex<double>, std::complex<double>
147 >{
148  typedef boost::mpl::true_ type;
149 };
150 template<>
151 struct optimized_trmv_detail<
152  dense_tag, dense_tag,
153  std::complex<float>, std::complex<float>
154 >{
155  typedef boost::mpl::true_ type;
156 };
157 
158 template<class M1, class M2>
159 struct has_optimized_trmv
160 : public optimized_trmv_detail<
161  typename M1::storage_category,
162  typename M2::storage_category,
163  typename M1::value_type,
164  typename M2::value_type
165 >{};
166 
167 }}}
168 #endif