pytorch每日一学32(torch.gather())沿指定维度收集值

本文详细介绍了PyTorch中的torch.gather函数,通过实例展示如何在不同维度上根据索引index获取tensor元素,并适用于一维和二维情况。理解这个方法有助于在深度学习中高效处理数据索引操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

第32个方法

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

此方法的作用是沿着指定的维度收集值,首先看参数:

  • input(Tenor):源tensor。
  • dim(int):要收集的维度,不可超过input的维度大小。
  • index(LongTensor):要收集的元素的索引,形状应该和input一样。
  • sparse_grad(bool):如果为True,那么input应该是一个稀疏tensor。
  • out(Tensor):输出结果的tensor。

首先来看一个简单的例子:

import torch

a = torch.rand(6)
index = torch.tensor([0])
b = torch.gather(a, 0, index)
print(a)
print(b)

结果为:
在这里插入图片描述

  • 可以看到,找到了index所在的元素,这是比较简单的用法,接下来看二维的情况。
import torch

a = torch.rand(3, 3)
index = torch.tensor([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
b = torch.gather(a, 0, index)
print(a)
print(b)

在这里插入图片描述

  • 先给大家看一个官网的公式:
    在这里插入图片描述
  • 当然官网给的是三维的公式,我们这里是二维的,所以我们忽略掉k就行了。想要生成b的话,例如生成b[0]的[0.8029, 0.6605, 0.6150],我们指定的维度是零,那么每次使用index中的元素替换a中的第一维就好了,例如对于b中的b[0][0],index[0][0]是0,所以b[0][0]=a[0][0],而b[0][1]=a[1][0]因为index[0][1]=1,依次类推b[0][2]=a[2][0],依次下去便可以生成b。
  • 而对于dim=1或者其他的情况,也就是使用index中的元素替换相应维度上的数字即可。

例如:在这里插入图片描述
b[1][0]=a[1][1]=0.1893,因为index[1][0]=1,而b[1][1]=a[1][2]=0.6783,因为index[1][1]=2,同理b[1][2]=a[1][0]=0.1685,因为index[1][2]=0。依次类推

其实对于更高维的思想是一样的,b[i][j][k]…就是用index[i][j][k]…来替换a中对应维度的下标,然后得到的元素。
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值