TensorContraction.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                  typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below.
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;

  enum {
    Flags = 0
  };
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};

}  // end namespace internal

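// Illustrative usage sketch (not part of the original header): a TensorContractionOp is normally
// created via TensorBase::contract(), passing an array of IndexPair that lists the dimensions to
// contract. For example, contracting the second dimension of a 4x3 tensor with the first
// dimension of a 3x5 tensor yields a 4x5 result, i.e. an ordinary matrix product:
//
//   Eigen::Tensor<float, 2> a(4, 3), b(3, 5);
//   a.setRandom(); b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // c(i, j) = sum_k a(i, k) * b(k, j)
//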
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
    typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                    typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

    EIGEN_DEVICE_FUNC
    const Indices& indices() const { return m_indices; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
};


template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    IsAligned = true,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = true
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
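  // (Added note: the swap relies on the identity C = A * B  <=>  C^T = B^T * A^T. A RowMajor
  // tensor read through ColMajor index arithmetic looks like the same data with its dimensions
  // reversed, i.e. transposed, so swapping the operands and reversing every dimension and index
  // pair below yields the correct, dimension-reversed result.)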
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);


    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }
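    // (Added worked example: for a RowMajor contraction of a rank-3 lhs A with a rank-2 rhs B
    // over the single pair (2, 0), the operands are swapped so LDims == 2 and RDims == 3. The
    // pair becomes (LDims - 1 - 0, RDims - 1 - 2) == (1, 0): dimension 1 of the reversed B is
    // B's original dimension 0, and dimension 0 of the reversed A is A's original dimension 2,
    // so the same axes are contracted.)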

    // Check for duplicate axes and make sure the first index in eval_op_indices
    // is increasing. Using O(n^2) sorting is OK since ContractDims is small
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the dimensions, we simply concatenate the non-contracting
    // dimensions of the left and then the right tensor. Additionally, we also
    // compute the strides corresponding to the left non-contracting
    // dimensions and right non-contracting dimensions.
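    // (Added example: contracting a 2x3x4 tensor with a 4x5 tensor over the pair (2, 0) yields
    // output dimensions (2, 3, 5). The equivalent matrix product then has m_i_size = 2*3 = 6 rows,
    // m_j_size = 5 columns, and a contracted depth of m_k_size = 4.)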
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add dimension size to output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // find if we are contracting on index i of right tensor
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // If the layout is RowMajor, we need to reverse the m_dimensions
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  // Evaluates the contraction. If a destination buffer is provided, the result is written
  // directly into it and false is returned (no further per-coefficient assignment is needed).
  // Otherwise the result is materialized into an internally allocated buffer that coeff() and
  // packet() read from, and true is returned.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // define mr, nr, and all of my data mapper types
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
    for(Index i2=0; i2<m; i2+=mc)
    {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }

  protected:
  // Prevent assignment
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
};
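
// (Added note: evalProduct falls back to the matrix-vector path whenever the right operand has no
// surviving non-contracted extent, e.g. contracting a 4x3 tensor with a rank-1 tensor of size 3
// over the pair (1, 0) gives m_j_size == 1 and takes evalGemv; every other shape is handled by
// the blocked evalGemm path above.)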

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H