在火炬中分开一个张量

时间:2017-03-14 12:44:22

标签: torch

给定大小为n x 2A x B x C的输入张量,如何将其拆分为两个张量,每个张量为n x A x B x C?基本上,n是批量大小。

2 个答案:

答案 0 :(得分:1)

您可以使用torch.split

torch.split(input_tensor, split_size_or_sections=A, dim=1)

答案 1 :(得分:0)

我认为你可以这样做:

tensor_a = torch.Tensor(n, 2A, B,C)
-- Initialize tensor_a with the data

tensor_b = torch.Tensor(n, A, B, C)
tensor_b = tensor_a[{{},1,{},{}}]
tensor_c = torch.Tensor(n, A, B, C)
tensor_c = tensor_a[{{},2,{},{}}]