31 #ifndef SHARK_LINALG_BLAS_KERNELS_DEFAULT_GEMM_HPP 32 #define SHARK_LINALG_BLAS_KERNELS_DEFAULT_GEMM_HPP 34 #include "../gemv.hpp" 35 #include "../../matrix_proxy.hpp" 36 #include "../../vector.hpp" 37 #include <boost/mpl/bool.hpp> 39 namespace shark {
namespace blas {
namespace bindings {
// Kernel overload selected by the orientation tag triple
// (row_major, row_major, Orientation2); the third tag is left generic.
// The tag order matches the dispatch call at the end of this file:
// result orientation, e1 orientation, e2 orientation.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type
//
// NOTE(review): presumably an overload of gemm_impl - the return-type/name
// line, the trailing Tag1/Tag2 parameters and the loop body are not present
// in this view, so what is written into mat_row cannot be confirmed here.
61 template<
class M,
class E1,
class E2,
class Orientation2,
class Tag1,
class Tag2>
63 matrix_expression<E1>
const& e1,
64 matrix_expression<E2>
const& e2,
65 matrix_expression<M>& m,
66 typename M::value_type alpha,
67 row_major, row_major, Orientation2,
// One iteration per row of e1; each iteration works on row i of the result.
70 for (std::size_t i = 0; i != e1().size1(); ++i) {
// Row proxy onto m, so the iteration addresses the result one row at a time.
71 matrix_row<M> mat_row(m(),i);
// Kernel overload for a row_major result with a column_major first operand
// whose iterator category is sparse_bidirectional_iterator_tag. The second
// operand's orientation and iterator category stay generic.
//
// Rather than implementing this combination directly, e1 is copied into a
// temporary of type transposed_matrix_temporary<E1>::type and the call is
// forwarded to gemm_impl with row_major as the first operand's orientation
// tag. The remaining tags (o, t1, t2) are passed through unchanged.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type, forwarded as-is
//
// NOTE(review): the return-type/name line and the closing of the parameter
// list and body are not present in this view.
78 template<
class M,
class E1,
class E2,
class Orientation,
class Tag>
80 matrix_expression<E1>
const& e1,
81 matrix_expression<E2>
const& e2,
82 matrix_expression<M>& m,
83 typename M::value_type alpha,
84 row_major, column_major, Orientation o,
85 sparse_bidirectional_iterator_tag t1, Tag t2
// Presumably a copy of e1 in the opposite (row-major) storage order; the
// row_major tag passed for it below relies on that - TODO confirm against
// the definition of transposed_matrix_temporary.
87 typename transposed_matrix_temporary<E1>::type e1_trans(e1);
88 gemm_impl(e1_trans,e2,m,alpha,row_major(),row_major(),o,t1,t2);
// Kernel overload for the tag combination
// (row_major result, column_major e1, row_major e2) where e1's iterator
// category is dense_random_access_iterator_tag; e2's category is generic.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type
//
// NOTE(review): the return-type/name line and the loop body are not present
// in this view; only the iteration range can be documented.
95 template<
class M,
class E1,
class E2,
class Tag>
97 matrix_expression<E1>
const& e1,
98 matrix_expression<E2>
const& e2,
99 matrix_expression<M>& m,
100 typename M::value_type alpha,
101 row_major, column_major, row_major,
102 dense_random_access_iterator_tag, Tag
// One iteration per column of e1 (size2 is its column count).
104 for (std::size_t j = 0; j != e1().size2(); ++j) {
// Kernel overload for the case where result, e1 and e2 are all row_major and
// both operands have sparse_bidirectional_iterator_tag iterators.
//
// Works row by row through a dense scratch vector with one slot per column
// of e2; the non-zero slots are then added onto row i of the result.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression, updated with +=
//   alpha  - scalar of type M::value_type
//
// NOTE(review): the return-type/name line, the statement that fills
// `temporary` for row i and whatever resets it between rows are not present
// in this view - confirm against the full file.
110 template<
class M,
class E1,
class E2>
112 matrix_expression<E1>
const& e1,
113 matrix_expression<E2>
const& e2,
114 matrix_expression<M>& m,
115 typename M::value_type alpha,
116 row_major, row_major, row_major,
117 sparse_bidirectional_iterator_tag, sparse_bidirectional_iterator_tag
119 typedef typename M::value_type value_type;
// Value-initialised element, used both as fill value and as the "empty" marker.
120 value_type zero = value_type();
// Scratch vector sized to the column count of e2, initially all zero.
121 vector<value_type> temporary(e2().size2(), zero);
122 for (std::size_t i = 0; i != e1().size1(); ++i) {
124 for (std::size_t j = 0; j != temporary.size(); ++ j) {
// Skip slots equal to zero so only non-zero entries touch m(i, j).
125 if (temporary(j) != zero) {
126 m()(i, j) += temporary(j);
// Kernel overload for the tag combination
// (row_major result, column_major e1, column_major e2) where e1's iterators
// are dense_random_access_iterator_tag and e2's are
// sparse_bidirectional_iterator_tag.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type
//
// NOTE(review): the return-type/name line and the loop body are not present
// in this view, so what is computed into mat_row cannot be confirmed here.
141 template<
class M,
class E1,
class E2>
143 matrix_expression<E1>
const& e1,
144 matrix_expression<E2>
const& e2,
145 matrix_expression<M>& m,
146 typename M::value_type alpha,
147 row_major, column_major, column_major,
148 dense_random_access_iterator_tag, sparse_bidirectional_iterator_tag
// One iteration per row of the result.
151 for (std::size_t i = 0; i != m().size1(); ++i) {
// Row proxy onto m for the current row.
152 matrix_row<M> mat_row(m(),i);
// Kernel overload for the tag combination
// (row_major result, column_major e1, column_major e2) where both operands
// have dense_random_access_iterator_tag iterators.
//
// Walks the result in tiles of at most blockSize x blockSize and keeps one
// blockSize x blockSize scratch matrix that is reused for every tile.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type
//
// NOTE(review): the return-type/name line and everything after the
// construction of transBlock (the per-tile computation and the write-back
// into m) are not present in this view.
158 template<
class M,
class E1,
class E2>
160 matrix_expression<E1>
const& e1,
161 matrix_expression<E2>
const& e2,
162 matrix_expression<M>& m,
163 typename M::value_type alpha,
164 row_major r, column_major, column_major,
165 dense_random_access_iterator_tag t, dense_random_access_iterator_tag
// Tile edge length; hard-coded.
168 std::size_t blockSize = 16;
169 typedef typename M::value_type value_type;
170 typedef typename matrix_temporary<M>::type BlockStorage;
// Scratch storage allocated once, outside the tile loops.
171 BlockStorage blockStorage(blockSize,blockSize);
173 typedef typename M::size_type size_type;
174 size_type size1 = m().size1();
175 size_type size2 = m().size2();
176 for (size_type i = 0; i < size1; i+= blockSize){
177 for (size_type j = 0; j < size2; j+= blockSize){
// Clamp the tile extents so the last tile in each direction does not run
// past the edge of m.
// NOTE(review): std::min deduces a single type, so these calls only compile
// when M::size_type is std::size_t; std::min also needs <algorithm>, which is
// not among the includes visible at the top of this file - confirm.
178 std::size_t blockSizei =
std::min(blockSize,size1-i);
179 std::size_t blockSizej =
std::min(blockSize,size2-j);
// View onto the top-left blockSizej x blockSizei corner of the scratch
// storage; rows are indexed by the j extent and columns by the i extent,
// i.e. the shape of the transposed tile.
// NOTE(review): the view type is spelled matrix_range<matrix<value_type> >
// while blockStorage has type BlockStorage; this presumably relies on
// matrix_temporary<M>::type being matrix<value_type> - confirm.
180 matrix_range<matrix<value_type> > transBlock=
subrange(blockStorage,0,blockSizej,0,blockSizei);
// Kernel overload for a column_major result, with generic operand
// orientations and iterator categories.
//
// Reduces the problem to the row_major-result case via the identity
// (A * B)^T = B^T * A^T:
//   - m is viewed through matrix_transpose<M>,
//   - the operands are swapped and each wrapped in trans(),
//   - each operand's orientation tag is replaced by its
//     transposed_orientation and the two are swapped along with the operands,
//   - the iterator category tags Tag1/Tag2 are swapped the same way.
// The call then goes to gemm_impl with row_major as the result tag.
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression, written through the transposed view
//   alpha  - scalar of type M::value_type, forwarded unchanged
//
// NOTE(review): the return-type/name line, the trailing tag parameters and
// the closing of the body are not present in this view.
199 template<
class M,
class E1,
class E2,
class Orientation1,
class Orientation2,
class Tag1,
class Tag2>
201 matrix_expression<E1>
const& e1,
202 matrix_expression<E2>
const& e2,
203 matrix_expression<M>& m,
204 typename M::value_type alpha,
205 column_major, Orientation1, Orientation2,
// Transposed view of the result, passed on as the target of the forwarded call.
208 matrix_transpose<M> transposedM(m());
209 typedef typename Orientation1::transposed_orientation transpO1;
210 typedef typename Orientation2::transposed_orientation transpO2;
// Operands, orientation tags and iterator tags are all passed in swapped order.
211 gemm_impl(
trans(e2),
trans(e1),transposedM,alpha,row_major(),transpO2(),transpO1(), Tag2(),Tag1());
// Dispatch front end: derives the compile-time tags that select a gemm_impl
// overload and forwards the call.
//
// The tags are passed in this order:
//   1. orientation of the result (M::orientation)
//   2. orientation of e1
//   3. orientation of e2
//   4. iterator category of e1's major iterator
//   5. iterator category of e2's major iterator
//
// Visible parameters:
//   e1, e2 - read-only operand expressions
//   m      - mutable result expression
//   alpha  - scalar of type M::value_type, forwarded unchanged
//
// NOTE(review): the return-type/name line, any parameters after alpha and
// the closing of the call and body are not present in this view; presumably
// this is the public gemm entry point - confirm against the full file.
215 template<
class M,
class E1,
class E2>
217 matrix_expression<E1>
const& e1,
218 matrix_expression<E2>
const& e2,
219 matrix_expression<M>& m,
220 typename M::value_type alpha,
226 typedef typename M::orientation ResultOrientation;
227 typedef typename E1::orientation E1Orientation;
228 typedef typename E2::orientation E2Orientation;
229 typedef typename major_iterator<E1>::type::iterator_category E1Category;
230 typedef typename major_iterator<E2>::type::iterator_category E2Category;
// Tag objects are default-constructed; they exist only for overload selection.
232 gemm_impl(e1, e2, m,alpha,
233 ResultOrientation(),E1Orientation(),E2Orientation(),
234 E1Category(),E2Category()