TensorDeviceThreadPool.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #if defined(EIGEN_USE_THREADS) && !defined(EIGEN_CXX11_TENSOR_TENSOR_DEVICE_THREAD_POOL_H)
11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_THREAD_POOL_H
12 
13 namespace Eigen {
14 
15 // This defines an interface that ThreadPoolDevice can take to use
16 // custom thread pools underneath.
class ThreadPoolInterface {
 public:
  // Hands "fn" to the underlying pool for execution. When and on which
  // thread "fn" runs is decided by the concrete implementation.
  virtual void Schedule(std::function<void()> fn) = 0;

  // Virtual so that implementations can be destroyed through a pointer to
  // this interface.
  virtual ~ThreadPoolInterface() = default;
};
23 
24 // The implementation of the ThreadPool type ensures that the Schedule method
25 // runs the functions it is provided in FIFO order when the scheduling is done
26 // by a single thread.
27 class ThreadPool : public ThreadPoolInterface {
28  public:
29  // Construct a pool that contains "num_threads" threads.
30  explicit ThreadPool(int num_threads) {
31  for (int i = 0; i < num_threads; i++) {
32  threads_.push_back(new std::thread([this]() { WorkerLoop(); }));
33  }
34  }
35 
36  // Wait until all scheduled work has finished and then destroy the
37  // set of threads.
38  ~ThreadPool()
39  {
40  {
41  // Wait for all work to get done.
42  std::unique_lock<std::mutex> l(mu_);
43  empty_.wait(l, [this]() { return pending_.empty(); });
44  exiting_ = true;
45 
46  // Wakeup all waiters.
47  for (auto w : waiters_) {
48  w->ready = true;
49  w->work = nullptr;
50  w->cv.notify_one();
51  }
52  }
53 
54  // Wait for threads to finish.
55  for (auto t : threads_) {
56  t->join();
57  delete t;
58  }
59  }
60 
61  // Schedule fn() for execution in the pool of threads. The functions are
62  // executed in the order in which they are scheduled.
63  void Schedule(std::function<void()> fn) {
64  std::unique_lock<std::mutex> l(mu_);
65  if (waiters_.empty()) {
66  pending_.push_back(fn);
67  } else {
68  Waiter* w = waiters_.back();
69  waiters_.pop_back();
70  w->ready = true;
71  w->work = fn;
72  w->cv.notify_one();
73  }
74  }
75 
76  protected:
77  void WorkerLoop() {
78  std::unique_lock<std::mutex> l(mu_);
79  Waiter w;
80  while (!exiting_) {
81  std::function<void()> fn;
82  if (pending_.empty()) {
83  // Wait for work to be assigned to me
84  w.ready = false;
85  waiters_.push_back(&w);
86  w.cv.wait(l, [&w]() { return w.ready; });
87  fn = w.work;
88  w.work = nullptr;
89  } else {
90  // Pick up pending work
91  fn = pending_.front();
92  pending_.pop_front();
93  if (pending_.empty()) {
94  empty_.notify_all();
95  }
96  }
97  if (fn) {
98  mu_.unlock();
99  fn();
100  mu_.lock();
101  }
102  }
103  }
104 
105  private:
106  struct Waiter {
107  std::condition_variable cv;
108  std::function<void()> work;
109  bool ready;
110  };
111 
112  std::mutex mu_;
113  std::vector<std::thread*> threads_; // All threads
114  std::vector<Waiter*> waiters_; // Stack of waiting threads.
115  std::deque<std::function<void()>> pending_; // Queue of pending work
116  std::condition_variable empty_; // Signaled on pending_.empty()
117  bool exiting_ = false;
118 };
119 
120 
// Notification is an object that allows a user to wait for another
// thread to signal a notification that an event has occurred.
//
// Multiple threads can wait on the same Notification object,
// but Notify() must be called at most once on the object.
126 class Notification {
127  public:
128  Notification() : notified_(false) {}
129  ~Notification() {}
130 
131  void Notify() {
132  std::unique_lock<std::mutex> l(mu_);
133  eigen_assert(!notified_);
134  notified_ = true;
135  cv_.notify_all();
136  }
137 
138  void WaitForNotification() {
139  std::unique_lock<std::mutex> l(mu_);
140  cv_.wait(l, [this]() { return notified_; } );
141  }
142 
143  private:
144  std::mutex mu_;
145  std::condition_variable cv_;
146  bool notified_;
147 };
148 
149 // Runs an arbitrary function and then calls Notify() on the passed in
150 // Notification.
151 template <typename Function, typename... Args> struct FunctionWrapper
152 {
153  static void run(Notification* n, Function f, Args... args) {
154  f(args...);
155  n->Notify();
156  }
157 };
158 
159 static EIGEN_STRONG_INLINE void wait_until_ready(Notification* n) {
160  if (n) {
161  n->WaitForNotification();
162  }
163 }
164 
165 
// Build a thread pool device on top of an existing pool of threads.
167 struct ThreadPoolDevice {
168  // The ownership of the thread pool remains with the caller.
169  ThreadPoolDevice(ThreadPoolInterface* pool, size_t num_cores) : pool_(pool), num_threads_(num_cores) { }
170 
171  EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
172  return internal::aligned_malloc(num_bytes);
173  }
174 
175  EIGEN_STRONG_INLINE void deallocate(void* buffer) const {
176  internal::aligned_free(buffer);
177  }
178 
179  EIGEN_STRONG_INLINE void memcpy(void* dst, const void* src, size_t n) const {
180  ::memcpy(dst, src, n);
181  }
182  EIGEN_STRONG_INLINE void memcpyHostToDevice(void* dst, const void* src, size_t n) const {
183  memcpy(dst, src, n);
184  }
185  EIGEN_STRONG_INLINE void memcpyDeviceToHost(void* dst, const void* src, size_t n) const {
186  memcpy(dst, src, n);
187  }
188 
189  EIGEN_STRONG_INLINE void memset(void* buffer, int c, size_t n) const {
190  ::memset(buffer, c, n);
191  }
192 
193  EIGEN_STRONG_INLINE size_t numThreads() const {
194  return num_threads_;
195  }
196 
197  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int majorDeviceVersion() const {
198  // Should return an enum that encodes the ISA supported by the CPU
199  return 1;
200  }
201 
202  template <class Function, class... Args>
203  EIGEN_STRONG_INLINE Notification* enqueue(Function&& f, Args&&... args) const {
204  Notification* n = new Notification();
205  std::function<void()> func =
206  std::bind(&FunctionWrapper<Function, Args...>::run, n, f, args...);
207  pool_->Schedule(func);
208  return n;
209  }
210  template <class Function, class... Args>
211  EIGEN_STRONG_INLINE void enqueueNoNotification(Function&& f, Args&&... args) const {
212  std::function<void()> func = std::bind(f, args...);
213  pool_->Schedule(func);
214  }
215 
216  private:
217  ThreadPoolInterface* pool_;
218  size_t num_threads_;
219 };
220 
221 
222 } // end namespace Eigen
223 
224 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_THREAD_POOL_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13