33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
// Primary template for the MKL-backed triangular matrix * vector kernel.
// By default it simply inherits the BuiltIn (pure Eigen)
// triangular_matrix_vector_product, so every parameter combination has a
// working run(). Partial specializations for the MKL-supported scalar types
// (one for ColMajor, one for RowMajor) override this with BLAS ?trmv/?gemv calls.
46 template<
typename Index,        // integer type used for sizes, strides and increments
int Mode,              // triangular-shape flags: Lower/Upper, optionally UnitDiag or ZeroDiag
typename LhsScalar,    // scalar type of the triangular matrix
bool ConjLhs,          // whether the triangular matrix is used conjugated
typename RhsScalar,    // scalar type of the right-hand-side vector
bool ConjRhs,          // whether the right-hand-side vector is used conjugated
int StorageOrder>      // ColMajor or RowMajor storage of the triangular matrix
47 struct triangular_matrix_vector_product_trmv :
48 triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
50 #define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
51 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
52 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
53 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
54 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
55 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
56 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
59 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
60 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
61 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
62 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
63 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
64 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
// Route the "Specialized" triangular_matrix_vector_product for each MKL-capable
// scalar type (real double/float and the complex dcomplex/scomplex) to
// triangular_matrix_vector_product_trmv. Each expansion covers both the
// ColMajor and the RowMajor storage orders.
68 EIGEN_MKL_TRMV_SPECIALIZE(
double)
69 EIGEN_MKL_TRMV_SPECIALIZE(
float)
70 EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
71 EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
74 #define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
75 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
76 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
78 IsLower = (Mode&Lower) == Lower, \
79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
82 LowUp = IsLower ? Lower : Upper \
84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
87 if (ConjLhs || IsZeroDiag) { \
88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
92 Index size = (std::min)(_rows,_cols); \
93 Index rows = IsLower ? _rows : size; \
94 Index cols = IsLower ? size : _cols; \
96 typedef VectorX##EIGPREFIX VectorRhs; \
100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
107 char trans, uplo, diag; \
108 MKL_INT m, n, lda, incx, incy; \
110 MKLTYPE alpha_, beta_; \
111 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
112 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
122 uplo = IsLower ? 'L' : 'U'; \
123 diag = IsUnitDiag ? 'U' : 'N'; \
126 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
129 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
131 if (size<(std::max)(rows,cols)) { \
132 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
135 y = _res + size*resIncr; \
143 a = _lhs + size*lda; \
147 MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
// Instantiate the column-major MKL kernels. Each line supplies:
//   - the Eigen scalar type and the matching MKL scalar type;
//   - the Eigen VectorX suffix (d, cd, f, cf) used for the temporary rhs copy;
//   - the BLAS routine prefix (d, z, s, c) that forms ?trmv, ?axpy and ?gemv.
152 EIGEN_MKL_TRMV_CM(
double,
double, d, d)
153 EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
154 EIGEN_MKL_TRMV_CM(
float,
float, f, s)
155 EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
158 #define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
159 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
160 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
162 IsLower = (Mode&Lower) == Lower, \
163 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
164 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
165 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
166 LowUp = IsLower ? Lower : Upper \
168 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
169 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
172 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
173 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
176 Index size = (std::min)(_rows,_cols); \
177 Index rows = IsLower ? _rows : size; \
178 Index cols = IsLower ? size : _cols; \
180 typedef VectorX##EIGPREFIX VectorRhs; \
184 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
186 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
191 char trans, uplo, diag; \
192 MKL_INT m, n, lda, incx, incy; \
194 MKLTYPE alpha_, beta_; \
195 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
196 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
205 trans = ConjLhs ? 'C' : 'T'; \
206 uplo = IsLower ? 'U' : 'L'; \
207 diag = IsUnitDiag ? 'U' : 'N'; \
210 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
213 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
215 if (size<(std::max)(rows,cols)) { \
216 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
219 y = _res + size*resIncr; \
220 a = _lhs + size*lda; \
231 MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
// Instantiate the row-major MKL kernels. The arguments have the same meaning
// as for the column-major instantiations:
//   - the Eigen scalar type and the matching MKL scalar type;
//   - the Eigen VectorX suffix;
//   - the BLAS prefix.
// Inside the macro, the row-major matrix is passed to BLAS as its transpose
// ('T', or 'C' when the lhs is conjugated) with the Lower/Upper flag swapped.
236 EIGEN_MKL_TRMV_RM(
double,
double, d, d)
237 EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
238 EIGEN_MKL_TRMV_RM(
float,
float, f, s)
239 EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
245 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
Definition: Eigen_Colamd.h:54