在numpy数组上进行迭代以保持索引

时间:2018-11-24 21:47:34

标签: python numpy iteration

请考虑以下代码段:

n_samples, n_rows, n_cols, n_boxes_per_cell, _ = Y_pred.shape
    for example in range(n_samples):
        for x_grid in range(n_rows):
            for y_grid in range(n_cols):
                for bnd_box in range(n_boxes_per_cell):
                    bnd_box_label = Y_pred[example, x_grid, y_grid, bnd_box]
                    do_some_stuff(bnd_box_label, x_grid, y_grid)

如何通过最多一次显式迭代获得功能等效的代码?请注意,我需要索引x_gridy_grid

2 个答案:

答案 0 :(得分:3)

您可以使用np.ndindex

for example, x_grid, y_grid, bnd_box in np.ndindex(Y_pred.shape[:4]):
    etc.

答案 1 :(得分:1)

我不确定这是否是您要寻找的,但是您始终可以从多个可迭代对象中构建生成器:

all_combinations = ((a, b, c, d) for a in range(n_samples)
                                 for b in range(n_rows) 
                                 for c in range(n_cols) 
                                 for d in range(n_boxes_per_cell))

for examples, x_grid, y_grid, bnd_box in all_combinations:
    do stuff

这与使用itertools.product(* iterables)相同,并且对任何可迭代的对象均有效,而不仅仅是对索引/整数的迭代。