Load several npz files in a multi-threaded way

Time: 2016-02-10 23:33:41

Tags: python multithreading numpy

I have several .npz files. All of the .npz files share the same structure: each one contains exactly two variables, always with the same variable names. For now, I simply loop over all the .npz files, retrieve the two variable values and append them to global lists:

import numpy as np

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
    data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
    x_train.append(data['x'])
    y_train.append(data['y'])

This takes a while, and the bottleneck is the CPU. The order in which the x and y values are appended to the x_train and y_train variables does not matter.

Is there a way to load several .npz files using multiple threads?

1 Answer:

Answer 0 (score: 2)

I was surprised by @Brent Washburne's comment and decided to try it out myself. I think the general problem is twofold:

First, reading the data is often IO-bound, so writing multi-threaded code often does not yield much of a performance gain. Second, shared-memory parallelization is inherently difficult in Python because of the design of the language itself (the global interpreter lock); compared to native C there is a lot more overhead.

But let's see what we can do.

# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    with np.load(path) as data:
        return data["x"]

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    with Pool() as pool:
        x_list = pool.map(read_x, files)
    return x_list
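
For comparison, the same read could also be dispatched to threads instead of processes. This is only a sketch (it was not part of the measurements below), and how much it helps depends on how much of np.load's file IO and decompression runs outside the GIL:

from concurrent.futures import ThreadPoolExecutor

def threaded_read(files, max_workers=4):
    # Threads share memory, so the resulting arrays do not have to be
    # pickled back to the parent as they do with multiprocessing.Pool.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        x_list = list(executor.map(read_x, files))
    return x_list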

Okay, that's the setup done. Let's look at the timings.

files = glob.glob(os.path.join(tmp_dir, '*.npz'))

%timeit x_serial = serial_read(files)
# 1 loops, best of 3: 7.04 s per loop

%timeit x_parallel = parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop

np.allclose(x_serial, x_parallel)
# True

That actually looks like a decent speedup. I am running this on two physical and two hyper-threaded cores.
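
For reference, Pool() with no arguments starts os.cpu_count() worker processes, which on a machine like this should match the processes=4 used in the script below; you can check the count on your own machine:

import os
print(os.cpu_count())  # e.g. 4 logical cores (2 physical cores with hyper-threading)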

To run and time everything in one go, you can execute the following script:

from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    # dict() forces every array in the archive to be read into memory immediately
    data = dict(np.load(path))
    return data['x']

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    pool = multiprocessing.Pool(processes=4)
    x_list = pool.map(read_x, files)
    # Shut the worker processes down cleanly before returning
    pool.close()
    pool.join()
    return x_list


files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# roughly 7.04 s per run (see the %timeit result above)

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# roughly 3.56 s per run (see the %timeit result above)

# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x))) # len(x): 100
print('len(x[0]): {0}'.format(len(x[0]))) # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0]))) # len(x[0][0]): 50
print('x[0][0]: {0}'.format(x[0][0])) # the first row of 50 random values
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6)) # 4.0 MB
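
To tie this back to the original question, the same pattern could load both variables from each file. A minimal sketch, assuming each .npz file really contains 'x' and 'y' as described (read_xy is a hypothetical helper, and files is the list built above):

def read_xy(path):
    with np.load(path) as data:
        return data['x'], data['y']

pool = multiprocessing.Pool(processes=4)
pairs = pool.map(read_xy, files)
pool.close()
pool.join()

# The order does not matter for the question, so just unzip the
# (x, y) pairs into the two training lists:
x_train, y_train = map(list, zip(*pairs))

pool.map preserves the input order anyway, but since order is irrelevant here, pool.imap_unordered would work just as well.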