使用Pool的Tensorflow错误:无法腌制_thread.RLock对象

时间:2018-12-25 11:32:34

标签: python tensorflow multiprocessing

我试图在Tensorflow CPU上实现大规模并行微分方程求解器(30k DE),但是内存不足(大约30GB矩阵)。因此,我实现了一个基于批处理的求解器(求解时间短,并保存数据->设置新的初始->再次求解)。但是问题仍然存在。我了解到,在关闭python解释器之前,Tensorflow不会清除内存。因此,基于有关github问题的信息,我尝试使用池来实现多处理解决方案,但在池化步骤中,我始终收到“无法腌制_thread.RLock对象”的信息。有人可以帮忙!

def dAdt(X,t):
  dX = // vector of differential
  return dX

global state_vector
global state

state_vector =  [0]*n // initial state

def tensor_process():
    with tf.Session() as sess:
        print("Session started...",end="")
        tf.global_variables_initializer().run()
        state = sess.run(tensor_state)
        sess.close()


n_batch = 3
t_batch = np.array_split(t,n_batch)


for n,i in enumerate(t_batch):
    print("Batch",(n+1),"Running...",end="")
    if n>0:
        i = np.append(i[0]-0.01,i)
    print("Session started...",end="")
    init_state = tf.constant(state_vector, dtype=tf.float64)
    tensor_state = tf.contrib.odeint_fixed(dAdt, init_state, i)
    with Pool(1) as p:
        p.apply_async(tensor_process).get()
    state_vector = state[-1,:]
    np.save("state.batch"+str(n+1),state)
    state=None

2 个答案:

答案 0 :(得分:1)

Tensorflow不支持多重处理,原因有很多,例如它无法派生TensorFlow会话本身。如果您仍然想使用某种“多”东西,请尝试以下对我有用的(multiprocessing.pool.ThreadPool):

https://stackoverflow.com/a/46049195/5276428

注意:我是通过在线程上创建多个会话,然后依次调用属于每个线程的每个会话变量来实现的。如果您的问题是内存,我认为可以通过减少输入批处理大小来解决。

答案 1 :(得分:0)

不要使用N个工作池,而是尝试创建N个不同的multiprocessing.Process对象实例,并将tensor_process()函数作为目标参数,并将每个数据子集作为args参数。在for循环内启动进程,然后将其加入循环下。您可以使用共享的multiprocessing.Queue对象将结果返回到主流程。

我个人已经成功地将TensorFlow与Python的多处理模块by sub-classing Process and overriding its run() method相结合。

def run(self):
  logging.info('started inference.')
  logging.debug('TF input frame shape == {}'.format(self.tensor_shape))

  count = 0

  with tf.device('/cpu:0') if self.device_type == 'cpu' else \
      tf.device(None):
    with tf.Session(config=self.session_config) as session:
      frame_dataset = tf.data.Dataset.from_generator(
        self.generate_frames, tf.uint8, tf.TensorShape(self.tensor_shape))
      frame_dataset = frame_dataset.map(self._preprocess_frames,
                                        self._get_num_parallel_calls())
      frame_dataset = frame_dataset.batch(self.batch_size)
      frame_dataset = frame_dataset.prefetch(self.batch_size)
      next_batch = frame_dataset.make_one_shot_iterator().get_next()

      while True:
        try:
          frame_batch = session.run(next_batch)
          probs = session.run(self.output_node,
                              {self.input_node: frame_batch})
          self.prob_array[count:count + probs.shape[0]] = probs
          count += probs.shape[0]
        except tf.errors.OutOfRangeError:
          logging.info('completed inference.')
          break

  self.result_queue.put((count, self.prob_array, self.timestamp_array))
  self.result_queue.close()

我会根据您的代码编写一个示例,但我不太理解。