完成NUTS采样后,PyMC3挂起

时间:2018-02-09 15:50:07

标签: python python-3.x theano python-multiprocessing pymc3

我有一个似乎运行顺畅的模型。我用Metropolis采样的结果初始化我的NUTS采样器,Metropolis采样完成时没有任何问题。然后它进入NUTS采样,完成了所有迭代,但随后就挂起了。

我试过了:

  • 重新启动内核
  • 清除我的theano缓存
  • 重新启动计算机,并重新启动运行它的docker容器
  • 多链的Metropolis抽样确实完成然后终止。
  • 稍微更改种子或模型规范会使模型有时终止

但由于它没有错误,我不确定如何排除故障。当我中断这个过程时,它总是卡在同一个地方。输出粘贴在下面。 任何帮助诊断问题将非常感激。

我的样本代码在这里:

# Stage 1: a Metropolis run whose posterior spread will seed NUTS.
with my_model:
    start_trace = pm.sample(7000, step=pm.Metropolis())

# Per-variable posterior standard deviations from the Metropolis run;
# squared below, these act as a diagonal covariance for NUTS scaling.
start_sds = {name: start_trace[name].std() for name in start_trace.varnames}

# Stage 2: NUTS, scaled by the Metropolis spreads and started from the
# last Metropolis draw, running three chains in parallel.
with my_model:
    step = pm.NUTS(scaling=my_model.dict_to_array(start_sds) ** 2,
                   is_cov=True)
    signal_trace = pm.sample(500, step=step, start=start_trace[-1], njobs=3)

采样已完成。第一个进度条是Metropolis采样,第二个是NUTS:

100%|██████████| 7000/7000 [00:09<00:00, 718.51it/s]
100%|██████████| 500/500 [01:37<00:00,  1.47s/it]

查看top命令的输出,有四个进程,每个进程使用大约相同的内存量,但只有一个在使用CPU。通常,当它正确终止时,会在其他3个进程停止使用CPU后结束。这些其他进程已经停止这一事实表明采样已经完成,问题与多进程的终止环节有关。

在我中断它之前(我已经把它放置了一整夜),它一直挂起;中断后我得到以下错误:

Process ForkPoolWorker-1:
KeyboardInterrupt
Process ForkPoolWorker-3:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
    self.run()
  File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
    self.run()
  File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 362, in get
    return recv()
  File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 360, in get
    racquire()
  File "/usr/lib/python3.4/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
KeyboardInterrupt
  File "/usr/lib/python3.4/multiprocessing/connection.py", line 416, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.4/multiprocessing/connection.py", line 383, in _recv
    chunk = read(handle, remaining)
Process ForkPoolWorker-2:
  File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 360, in get
    racquire()
KeyboardInterrupt
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in retrieve(self)
    698                 if getattr(self._backend, 'supports_timeout', False):
--> 699                     self._output.extend(job.get(timeout=self.timeout))
    700                 else:

/usr/lib/python3.4/multiprocessing/pool.py in get(self, timeout)
    592     def get(self, timeout=None):
--> 593         self.wait(timeout)
    594         if not self.ready():

/usr/lib/python3.4/multiprocessing/pool.py in wait(self, timeout)
    589     def wait(self, timeout=None):
--> 590         self._event.wait(timeout)
    591 

/usr/lib/python3.4/threading.py in wait(self, timeout)
    552             if not signaled:
--> 553                 signaled = self._cond.wait(timeout)
    554             return signaled

/usr/lib/python3.4/threading.py in wait(self, timeout)
    289             if timeout is None:
--> 290                 waiter.acquire()
    291                 gotit = True

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-13-cad95f434ae5> in <module>()
     87         step = pm.NUTS(scaling=my_model.dict_to_array(start_sds)**2,
     88                        is_cov=True)
---> 89         signal_trace = pm.sample(500,step=step,start=start_trace[-1],njobs=3)
     90 
     91     pr = forestplot(signal_trace[-500:],

/opt/ds/lib/python3.4/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain, njobs, tune, progressbar, model, random_seed)
    173         sample_func = _sample
    174 
--> 175     return sample_func(**sample_args)
    176 
    177 

/opt/ds/lib/python3.4/site-packages/pymc3/sampling.py in _mp_sample(**kwargs)
    322                                                      random_seed=rseed[i],
    323                                                      start=start_vals[i],
--> 324                                                      **kwargs) for i in range(njobs))
    325     return merge_traces(traces)
    326 

/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in __call__(self, iterable)
    787                 # consumption.
    788                 self._iterating = False
--> 789             self.retrieve()
    790             # Make sure that we get a last message telling us we are done
    791             elapsed_time = time.time() - self._start_time

/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in retrieve(self)
    719                     # scheduling.
    720                     ensure_ready = self._managed_backend
--> 721                     backend.abort_everything(ensure_ready=ensure_ready)
    722 
    723                 if not isinstance(exception, TransportableException):

/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in abort_everything(self, ensure_ready)
    143     def abort_everything(self, ensure_ready=True):
    144         """Shutdown the pool and restart a new one with the same parameters"""
--> 145         self.terminate()
    146         if ensure_ready:
    147             self.configure(n_jobs=self.parallel.n_jobs, parallel=self.parallel,

/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in terminate(self)
    321     def terminate(self):
    322         """Shutdown the process or thread pool"""
--> 323         super(MultiprocessingBackend, self).terminate()
    324         if self.JOBLIB_SPAWNED_PROCESS in os.environ:
    325             del os.environ[self.JOBLIB_SPAWNED_PROCESS]

/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in terminate(self)
    134         if self._pool is not None:
    135             self._pool.close()
--> 136             self._pool.terminate()  # terminate does a join()
    137             self._pool = None
    138 

/opt/ds/lib/python3.4/site-packages/joblib/pool.py in terminate(self)
    604         for i in range(n_retries):
    605             try:
--> 606                 super(MemmapingPool, self).terminate()
    607                 break
    608             except OSError as e:

/usr/lib/python3.4/multiprocessing/pool.py in terminate(self)
    494         self._state = TERMINATE
    495         self._worker_handler._state = TERMINATE
--> 496         self._terminate()
    497 
    498     def join(self):

/usr/lib/python3.4/multiprocessing/util.py in __call__(self, wr, _finalizer_registry, sub_debug, getpid)
    183                 sub_debug('finalizer calling %s with args %s and kwargs %s',
    184                           self._callback, self._args, self._kwargs)
--> 185                 res = self._callback(*self._args, **self._kwargs)
    186             self._weakref = self._callback = self._args = \
    187                             self._kwargs = self._key = None

/usr/lib/python3.4/multiprocessing/pool.py in _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache)
    524 
    525         util.debug('helping task handler/workers to finish')
--> 526         cls._help_stuff_finish(inqueue, task_handler, len(pool))
    527 
    528         assert result_handler.is_alive() or len(cache) == 0

/usr/lib/python3.4/multiprocessing/pool.py in _help_stuff_finish(inqueue, task_handler, size)
    509         # task_handler may be blocked trying to put items on inqueue
    510         util.debug('removing tasks from inqueue until task handler finished')
--> 511         inqueue._rlock.acquire()
    512         while task_handler.is_alive() and inqueue._reader.poll():
    513             inqueue._reader.recv()

KeyboardInterrupt: 

这是最常失败的模型:

with pm.Model() as twitter_signal:
    # Location fixed effects, non-centered parameterization:
    # c = mu_c + sig_c * c_raw keeps the sampler out of the funnel.
    mu_c = pm.Flat('mu_c')
    sig_c = pm.HalfCauchy('sig_c', beta=2.5)
    c_raw = pm.Normal('c_raw', mu=0, sd=1, shape=n_location)
    c = pm.Deterministic('c', mu_c + sig_c * c_raw)

    # Time fixed effects, same non-centered trick.
    mu_t = pm.Flat('mu_t')
    sig_t = pm.HalfCauchy('sig_t', beta=2.5)
    t_raw = pm.Normal('t_raw', mu=0, sd=1, shape=n_time)
    t = pm.Deterministic('t', mu_t + sig_t * t_raw)

    # Signal effect (very diffuse prior: sd = 100**2).
    b_sig = pm.Normal('b_sig', 0, sd=100**2, shape=1)

    # Control effect.
    b_control = pm.Normal('b_control', mu=0, sd=100**2, shape=1)

    # Linear predictor on the log scale (exp link applied in the likelihood).
    # BUG FIX: the original referenced an undefined name `b_death`; the
    # control coefficient defined above is `b_control`, which is the one
    # multiplying df.control.values.
    theta = c[df.location.values] + \
        t[df.dates.values] + \
        (b_sig[df.c.values] * df.sig.values) + \
        (b_control[df.c.values] * df.control.values)

    # Over-dispersion parameter of the negative binomial likelihood.
    disp = pm.HalfCauchy('disp', beta=2.5)

    # Likelihood: counts with log-linear mean and shared dispersion.
    y = pm.NegativeBinomial('y', mu=np.exp(theta),
                            alpha=disp,
                            observed=df.loc[:, yvar])

0 个答案:

没有答案