pytorch tensor的索引与切片
切片方式与numpy是类似。
* a[:2, :1, :, :],
* 可以用-1索引。
* ::2,表示所有数据,间隔为2,即 start:end:step。
* a.index_select(1,torch.tensor([2]))
# 1表示维度,后面是索引(必须是tensor格式,想连续选取可以用tensor.arange())
* 三个点(…):
表示取最大维度的数据,不用输入很多的(:,:,)
比如下面的数据三个点…可以代替中间的维度,并且两边数据是相等的:
* torch.masked_select:
* torch.take(a, torch.tensor([0, 3, 5])
先将数据打平展开为一维,再选取展开后对应索引[0, 3, 5]的数据