如何制作一个在2D numpy数组上迭代的生成器?

时间:2019-03-11 04:40:40

标签: python numpy

我有一个庞大的2D numpy数组,我想分批检索。 数组形状为= {60000,3072,我想制作一个生成器,为我提供该数组中的块,例如:1000,3072,然后是下一个1000,3072,依此类推。我该如何使生成器迭代该数组并为我传递给定大小的批处理?

3 个答案:

答案 0 :(得分:2)

考虑数组a

a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9],
              [10, 11, 12]])

选项1
使用发电机

def get_every_n(a, n=2):
    for i in range(a.shape[0] // n):
        yield a[n*i:n*(i+1)]

for sa in get_every_n(a):
    print sa

[[1 2 3]
 [4 5 6]]
[[ 7  8  9]
 [10 11 12]]

选项2
使用reshape//

a.reshape(a.shape[0] // 2, -1, a.shape[1])

array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

选项3
如果您要两组而不是两组

a.reshape(-1, 2, a.shape[1])

array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

由于您明确声明需要生成器,因此可以使用选项1作为适当的参考。

答案 1 :(得分:1)

这是您拥有的数据:

import numpy as np
full_len = 5    # In your case, 60_000
cols = 3        # In your case, 3072

nd1 = np.arange(full_len*cols).reshape(full_len,cols)

您可以按照以下步骤来“生成”切片:

选项1,使用numpy.array_split():

from math import ceil

step_size = 2   # In your case, 1_000
split_list = np.array_split(nd1,ceil(full_len/step_size), axis=0)
print (split_list)

split_list现在是nd1中的切片列表。通过遍历此列表,您可以像split_list[0]split_list[1]等那样访问各个切片,并且这些切片中的每一个都是nd1的视图,并且可以与您完全一样地使用将使用任何其他numpy数组。

选项1的输出

以下是输出,显示最后一个切片比其他常规切片短:

[array([[0, 1, 2],
       [3, 4, 5]]), array([[ 6,  7,  8],
       [ 9, 10, 11]]), array([[12, 13, 14]])]

选项2,通过显式切片:

step_size = 2   # In your case, 1_000
myrange = range(0, full_len, step_size)

for r in myrange:
    my_slice_array = nd1 [r:r+step_size]
    print (my_slice_array.shape)

选项2的输出

(2, 3)
(2, 3)
(1, 3)

请注意,与切片列表不同,切片numpy数组不会复制源数组的数据。它仅在切片边界内在源numpy数组的现有数据上创建一个视图。这适用于选项1 选项2 ,因为它们都涉及切片的创建。

答案 2 :(得分:0)

如果您想要生成器方式的东西,此解决方案有效

import numpy 
bigArray = numpy.random.rand(60000, 3072) # have used this to generate dummy array

def selectArray(m,n):
  yield bigArray[m, n] # I am facing issue with giving proper slices. Please handle it yourselg. 

genObject = selectArray(1000, 3072)

,您可以使用for循环或next()遍历genObject

注意:如果您使用的是next(),请确保您正在处理StopIteration异常。

希望有帮助。