gemm.hpp
/*!
 *
 *
 * \brief Default implementations and dispatching of the gemm (matrix-matrix product) kernel
 *
 * \author O. Krause
 * \date 2010
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */

#ifndef SHARK_LINALG_BLAS_KERNELS_DEFAULT_GEMM_HPP
#define SHARK_LINALG_BLAS_KERNELS_DEFAULT_GEMM_HPP

#include "../gemv.hpp"
#include "../../matrix_proxy.hpp"
#include "../../vector.hpp"
#include <boost/mpl/bool.hpp>

namespace shark { namespace blas { namespace bindings {

// We dispatch gemm A = B*C on:
// - all orientations of A, B and C,
// - the iterator categories of B and C. We assume A to have a meaningful storage category.

// The basic dispatching towards the kernels works on these categories.
// We explain the scheme here because the implementation is laid out in inverted order: reducing one
// case to another requires that the target case has already been implemented, so the most general
// cases are at the end of the file.
//
// 1. We dispatch on the orientation of the result using the relation A = B*C <=> A^T = C^T B^T,
//    thus all compute kernels can assume that the result is row_major.
// 2. If B is row_major as well, we cast the computation in terms of matrix-vector products,
//    computing A row by row (note that a specialised kernel handles the all-sparse case).
// 3. If B is column_major we dispatch as follows:
//    3.1 If B is sparse, transpose B in memory. This costs some memory but is often fast (and easy).
//    3.2 Otherwise, if C is row_major, cast the computation in terms of outer products.
//    3.3 For B and C column_major there are specialised kernels for every combination.
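
// Example of how a call is routed through these stages (illustrative only): for a column_major
// result A, a row_major B and a column_major C, step (1.) rewrites the call as A^T = C^T B^T.
// The result A^T is then row_major and its first factor C^T is row_major as well, so step (2.)
// applies and the product is computed row by row with matrix-vector products.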


//general case: result and first argument row_major (2.)
//=> compute as a sequence of matrix-vector products over the rows of the first argument
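// For row i this uses the identity row_i(A) += alpha * row_i(B) * C = alpha * C^T * row_i(B)^T,
// i.e. each row of the result is a gemv with the transposed second factor.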
template<class M, class E1, class E2, class Orientation2,class Tag1,class Tag2>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major, row_major, Orientation2,
    Tag1, Tag2
) {
    for (std::size_t i = 0; i != e1().size1(); ++i) {
        matrix_row<M> mat_row(m(),i);
        kernels::gemv(trans(e2),row(e1,i),mat_row,alpha);
    }
}

//case: sparse column_major first argument (3.1)
//=> transpose in memory
template<class M, class E1, class E2, class Orientation, class Tag>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major, column_major, Orientation o,
    sparse_bidirectional_iterator_tag t1, Tag t2
) {
    typename transposed_matrix_temporary<E1>::type e1_trans(e1);
    gemm_impl(e1_trans,e2,m,alpha,row_major(),row_major(),o,t1,t2);
}

//case: result and second argument row_major, first argument dense column_major (3.2)
//=> compute as a sequence of outer products.
// Note that this is likely to be slow if E2 is sparse and the result is sparse as well;
// however, choosing a sparse M is rarely sensible in the first place.
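// Uses the identity B*C = sum_j column_j(B) * row_j(C) (a sum of outer products), so the
// result is accumulated as one rank-one update per column of e1.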
template<class M, class E1, class E2,class Tag>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major, column_major, row_major,
    dense_random_access_iterator_tag, Tag
) {
    for (std::size_t j = 0; j != e1().size2(); ++j) {
        noalias(m) += alpha * outer_prod(column(e1,j),row(e2,j));
    }
}

//special case: all arguments row_major and both factors sparse
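// Each result row is accumulated densely in a temporary vector via gemv; the nonzero entries
// are then scattered into the (possibly sparse) result, resetting the temporary on the fly so
// it can be reused for the next row.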
template<class M, class E1, class E2>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major, row_major, row_major,
    sparse_bidirectional_iterator_tag, sparse_bidirectional_iterator_tag
) {
    typedef typename M::value_type value_type;
    value_type zero = value_type();
    vector<value_type> temporary(e2().size2(), zero);
    for (std::size_t i = 0; i != e1().size1(); ++i) {
        kernels::gemv(trans(e2),row(e1,i),temporary,alpha);
        for (std::size_t j = 0; j != temporary.size(); ++j) {
            if (temporary(j) != zero) {
                m()(i, j) += temporary(j);//fixme: better use something like insert
                temporary(j) = zero;
            }
        }
    }
}


// case 3.3
// Now we only need to handle the case that E1 and E2 are column_major and M is row_major.
// This is handled by specialised kernels for all matrix types (except sparse column_major E1,
// which was already taken care of in 3.1).

//dense-sparse
template<class M, class E1, class E2>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major, column_major, column_major,
    dense_random_access_iterator_tag, sparse_bidirectional_iterator_tag
) {
    //compute the product row-wise
    for (std::size_t i = 0; i != m().size1(); ++i) {
        matrix_row<M> mat_row(m(),i);
        kernels::gemv(trans(e2),row(e1,i),mat_row,alpha);
    }
}

//dense-dense
template<class M, class E1, class E2>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    row_major r, column_major, column_major,
    dense_random_access_iterator_tag t, dense_random_access_iterator_tag
) {
    //compute blockwise and write the transposed block.
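    // Both factors are column_major while the result is row_major, so each block of the result
    // is computed transposed (turning it into a product of row_major arguments, handled by the
    // first kernel in this file) and the transpose is written back afterwards. The fixed block
    // size keeps the temporary small and the accesses local.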
    std::size_t blockSize = 16;
    typedef typename M::value_type value_type;
    typedef typename matrix_temporary<M>::type BlockStorage;
    BlockStorage blockStorage(blockSize,blockSize);

    typedef typename M::size_type size_type;
    size_type size1 = m().size1();
    size_type size2 = m().size2();
    for (size_type i = 0; i < size1; i+= blockSize){
        for (size_type j = 0; j < size2; j+= blockSize){
            std::size_t blockSizei = std::min(blockSize,size1-i);
            std::size_t blockSizej = std::min(blockSize,size2-j);
            matrix_range<matrix<value_type> > transBlock=subrange(blockStorage,0,blockSizej,0,blockSizei);
            transBlock.clear();
            //reduce to the all row-major case using
            //A_ij = B^i C_j <=> A_ij^T = (C_j)^T (B^i)^T,
            //where B^i is the i-th block row of e1 and C_j the j-th block column of e2
            gemm_impl(
                trans(columns(e2,j,j+blockSizej)),
                trans(rows(e1,i,i+blockSizei)),
                transBlock,alpha,
                r,r,r,//all row-major
                t,t //both arguments are dense
            );
            //write the transposed block to the result matrix
            noalias(subrange(m,i,i+blockSizei,j,j+blockSizej))+=trans(transBlock);
        }
    }
}

//general case: column_major result (1.)
//=> transformed to row_major using A=B*C <=> A^T = C^T B^T
template<class M, class E1, class E2, class Orientation1, class Orientation2, class Tag1, class Tag2>
void gemm_impl(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    column_major, Orientation1, Orientation2,
    Tag1, Tag2
){
    matrix_transpose<M> transposedM(m());
    typedef typename Orientation1::transposed_orientation transpO1;
    typedef typename Orientation2::transposed_orientation transpO2;
    gemm_impl(trans(e2),trans(e1),transposedM,alpha,row_major(),transpO2(),transpO1(),Tag2(),Tag1());
}

//dispatcher
template<class M, class E1, class E2>
void gemm(
    matrix_expression<E1> const& e1,
    matrix_expression<E2> const& e2,
    matrix_expression<M>& m,
    typename M::value_type alpha,
    boost::mpl::false_
) {
    SIZE_CHECK(m().size1() == e1().size1());
    SIZE_CHECK(m().size2() == e2().size2());

    typedef typename M::orientation ResultOrientation;
    typedef typename E1::orientation E1Orientation;
    typedef typename E2::orientation E2Orientation;
    typedef typename major_iterator<E1>::type::iterator_category E1Category;
    typedef typename major_iterator<E2>::type::iterator_category E2Category;

    gemm_impl(e1, e2, m, alpha,
        ResultOrientation(),E1Orientation(),E2Orientation(),
        E1Category(),E2Category()
    );
}
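
// Illustrative usage sketch (not part of this file's interface; the concrete matrix types and the
// role of the mpl tag are assumptions for exposition). The boolean tag presumably distinguishes
// whether an optimised BLAS binding can be used; this overload implements the default path.
//
//   shark::blas::matrix<double> B(20, 30), C(30, 10), A(20, 10);
//   // ... fill B and C, initialise A (e.g. with zeros) ...
//   bindings::gemm(B, C, A, 2.0, boost::mpl::false_());  // accumulates A += 2.0 * B * C
//
// Inside gemm() the orientation and iterator-category tags of M, E1 and E2 are default-constructed
// and used to select one of the gemm_impl overloads above at compile time.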

}}}

#endif