Torch7:使用ByteTensor蒙版进行切片张量

时间:2016-03-29 16:18:32

标签: torch

我有两个张量:

  1. 标签是1D Tensor(5000)
  2. 数据集是4D Tensor(5000,1,32,32)
  3. 我想有效地切割与值1的标签相对应的标签和数据集。我成功切割标签而不是数据集。

    切片标签:

    positive_mask = labels:eq(1)
    sliced_labels = labels[positive_mask]
    

    我尝试执行以下操作来切片数据集并失败:

    sliced_dataset = dataset[positive_mask]
    sliced_dataset = dataset[{positive_mask, {}, {}, {}}]
    sliced_dataset = dataset:narrow(1,positive_mask)
    sliced_dataset = dataset:select(1,positive_mask)
    

    是否有优雅方法在Torch7中执行此操作?

1 个答案:

答案 0 :(得分:1)

sliced_dataset = dataset:index(1, positive_mask:nonzero():squeeze())