torch.index_select
index_select 只能处理两维矩阵,指定行或者列的索引,按行或者按列取出
# indices 只能是一维
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 3, 4]])
print(x)
indices = torch.tensor([0, 2])
print(torch.index_select(x, 0, indices))
# 按行取
print(torch.index_select(x, 1, indices))
# 按列取
tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[2, 2, 3, 4]])
# 按行取
tensor([[1, 2, 3, 4],
[2, 2, 3, 4]])
# 按列取
tensor([[1, 3],
[5, 7],
[2, 3]])
torch.gather
gather可处理多维张量, 按具体的索引取出对应位置的数(不是整行或者整列取)
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 3, 4]])
print(x)
# 3*4
indices = torch.tensor([[1], [2], [3]])
# 3*1
# 意思是取[0,1],[2,2],[3,3]索引位置的数
# 每行取一个,indices指定列标
print(torch.gather(x, 1, indices))
# 按照第二维度取数,indices第一个维度要和x一样
indices = torch.tensor([[1, 2, 1, 0]])
# 1*4
# 意思是取[1,0],[2,1],[1,2],[0,3]索引位置的数
# 每列取一个,indices指定行标
print(torch.gather(x, 0, indices))
# 按照第一维度取数,indices第二个维度要和x一样
tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[2, 2, 3, 4]])
# 每行取一个,指定列标
tensor([[2],
[7],
[4]])
# 每列取一个,指定行标
tensor([[5, 2, 7, 4]])