operation.hpp
Go to the documentation of this file.
1 /*!
2  * \brief Some special matrix-products
3  *
4  * \author O. Krause
5  * \date 2013
6  *
7  *
8  * \par Copyright 1995-2015 Shark Development Team
9  *
10  * <BR><HR>
11  * This file is part of Shark.
12  * <http://image.diku.dk/shark/>
13  *
14  * Shark is free software: you can redistribute it and/or modify
15  * it under the terms of the GNU Lesser General Public License as published
16  * by the Free Software Foundation, either version 3 of the License, or
17  * (at your option) any later version.
18  *
19  * Shark is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22  * GNU Lesser General Public License for more details.
23  *
24  * You should have received a copy of the GNU Lesser General Public License
25  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
26  *
27  */
28 #ifndef SHARK_LINALG_BLAS_OPERATION_HPP
29 #define SHARK_LINALG_BLAS_OPERATION_HPP
30 
31 #include "kernels/gemv.hpp"
32 #include "kernels/gemm.hpp"
33 #include "kernels/tpmv.hpp"
34 #include "kernels/trmv.hpp"
35 #include "kernels/trmm.hpp"
36 
37 namespace shark {
38 namespace blas {
39 
40 namespace detail{
41 
42 ///\brief Computes y=alpha*Ax or y += alpha*Ax
43 template<class ResultV, class M, class V>
44 void axpy_prod_impl(
45  matrix_expression<M> const& matrix,
46  vector_expression<V> const& vector,
47  vector_expression<ResultV>& result,
48  bool init,
49  typename ResultV::value_type alpha,
50  linear_structure
51 ) {
52 
53  if (init)
54  result().clear();
55 
56  kernels::gemv(matrix, vector, result,alpha);
57 }
58 ///\brief Computes y=alpha*Ax or y += alpha*Ax
59 template<class ResultV, class M, class V>
60 void axpy_prod_impl(
61  matrix_expression<M> const& matrix,
62  vector_expression<V> const& vector,
63  vector_expression<ResultV>& result,
64  bool init,
65  typename ResultV::value_type alpha,
66  packed_structure
67 ) {
68  if(init){
69  noalias(result) = vector;
70  kernels::tpmv(matrix, result);
71  result() *= alpha;
72  }else{
73  typename vector_temporary<V>::type temp(result);
74  noalias(result) = vector;
75  kernels::tpmv(matrix, result);
76  result() *= alpha;
77  noalias(result) += temp;
78  }
79 }
80 
81 }
82 
83 
84 ///\brief Computes y=alpha*Ax or y += alpha*Ax
85 template<class ResultV, class M, class V>
86 void axpy_prod(
87  matrix_expression<M> const& matrix,
88  vector_expression<V> const& vector,
90  bool init = true,
91  typename ResultV::value_type alpha = 1.0
92 ) {
93  SIZE_CHECK(matrix().size1()==result().size());
94  SIZE_CHECK(matrix().size2()==vector().size());
95 
96 
97  detail::axpy_prod_impl(matrix, vector, result,init, alpha,typename M::orientation());
98 }
99 
100 ////\brief Computes C=alpha*Ax or C += alpha*Ax
101 ///
102 ///This the dispatcher for temporary result proxies
103 template<class ResultV, class M, class V>
105  matrix_expression<M> const& matrix,
106  vector_expression<V> const& vector,
107  temporary_proxy<ResultV> result,
108  bool init = true,
109  typename ResultV::value_type alpha = 1.0
110 ) {
111  SIZE_CHECK(matrix().size1()==result.size());
112  SIZE_CHECK(matrix().size2()==vector().size());
113  axpy_prod(matrix,vector,static_cast<ResultV&>(result),init,alpha);
114 }
115 
116 ///\brief Computes y=alpha*xA or y += alpha*xA
117 template<class ResultV, class V, class M>
119  vector_expression<V> const& vector,
120  matrix_expression<M> const& matrix,
122  bool init = true,
123  typename ResultV::value_type alpha = 1.0
124 ) {
125  SIZE_CHECK(matrix().size2()==result().size());
126  SIZE_CHECK(matrix().size1()==vector().size());
127  axpy_prod(trans(matrix), vector, result,init,alpha);
128 }
129 
130 ////\brief Computes C=alpha*xA or C += alpha*xA
131 ///
132 ///This the dispatcher for temporary result proxies
133 template<class ResultV, class M, class V>
135  vector_expression<V> const& vector,
136  matrix_expression<M> const& matrix,
137  temporary_proxy<ResultV> result,
138  bool init = true,
139  typename ResultV::value_type alpha = 1.0
140 ) {
141  SIZE_CHECK(matrix().size2()==result.size());
142  SIZE_CHECK(matrix().size1()==vector().size());
143  axpy_prod(trans(matrix), vector, static_cast<ResultV&>(result),init,alpha);
144 }
145 
146 /// \brief Implements the matrix products m+=alpha * e1*e2 or m = alpha*e1*e2.
147 template<class M, class E1, class E2>
149  matrix_expression<E1> const& e1,
150  matrix_expression<E2> const& e2,
152  bool init = true,
153  typename M::value_type alpha = 1.0
154 ) {
155  SIZE_CHECK(m().size1() == e1().size1());
156  SIZE_CHECK(m().size2() == e2().size2());
157  SIZE_CHECK(e1().size2() == e2().size1());
158 
159  if (init)
160  m().clear();
161 
162  kernels::gemm(e1,e2,m,alpha);
163 }
164 
165 template<class M, class E1, class E2>
167  matrix_expression<E1> const& e1,
168  matrix_expression<E2> const& e2,
170  bool init = true,
171  typename M::value_type alpha = 1.0
172 ) {
173  axpy_prod(e1,e2,static_cast<M&>(m),init,alpha);
174 }
175 
176 /// \brief computes C= alpha*AA^T or C+=alpha* AA^T
177 template<class M, class E>
179  matrix_expression<E> const& A,
181  bool init = true,
182  typename M::value_type alpha = 1.0
183 ) {
184  SIZE_CHECK(m().size1() == A().size1());
185  SIZE_CHECK(m().size2() == m().size1());
186 
187  axpy_prod(A, trans(A), m,init, alpha);
188 }
189 
190 /// \brief computes C= alpha*AA^T or C+=alpha* AA^T
191 template<class M, class E>
193  matrix_expression<E> const& A,
195  bool init = 1.0,
196  typename M::value_type alpha = 1.0
197 ) {
198  symm_prod(A, static_cast<M&>(m),init, alpha);
199 }
200 
201 /// \brief Computes x=Ax for a triangular matrix A
202 ///
203 /// The first template argument governs the type
204 /// of triangular matrix: lower, upper, unit_lower and unit_upper.
205 ///
206 ///Example: triangular_prod<lower>(A,x);
207 template<class TriangularType, class MatrixA, class V>
211 ) {
212  kernels::trmv<TriangularType::is_upper, TriangularType::is_unit>(A, x);
213 }
214 
215 /// \brief Computes B=AB for a triangular matrix A and a dense matrix B in place
216 ///
217 /// The first template argument governs the type
218 /// of triangular matrix: lower, upper, unit_lower and unit_upper.
219 ///
220 ///Example: triangular_prod<lower>(A,B);
221 template<class TriangularType, class MatrixA, class MatB>
225 ) {
226  kernels::trmm<TriangularType::is_upper, TriangularType::is_unit>(A, B);
227 }
228 
229 /// \brief triangular prod for temporary left-hand side arguments
230 ///
231 /// Dispatches to the other versions of triangular_prod, see their documentation
232 template<class TriangularType, class MatrixA, class E>
236 ) {
237  triangular_prod<TriangularType>(A, static_cast<E&>(e));
238 }
239 
240 
241 }
242 }
243 
244 #endif