pytorch每日一学40(torch.split())将张量拆分为块(tensor的分割)

torch.split()方法是PyTorch中用于切割张量的工具,它与torch.chunk()类似但更灵活。当split_size_or_sections为整数时,按固定大小切割;为列表时,按列表元素大小切割张量。该方法可在不同维度上操作,返回结果为视图并与原始张量共享内存。注意,列表元素之和应等于切割维度的大小。在修改返回的张量时,需先使用clone()避免影响原始数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

第40个方法

torch.split(tensor, split_size_or_sections, dim=0)

此方法和前面的torch.chunk()一样,都是将tensor进行切割的方法,但是这两有什么区别呢?其实我们split方法功能更多样一些。

看一下此方法的参数:

  • tensor:要被切割的tensor。
  • split_size_or_sections:当此参数为整数时,意思是将tensor按照每块大小为split_size_or_sections来切割,当此参数为列表时,将此tensor切成和列表中元素大小一样的大小的块。
  • dim:指定要切分的维度。

这里主要就是split_size_or_sections参数的使用,当split_size_or_sections为整数时,此方法和torch.chunk()方法一样(此方法这里讲的比较详细,可以点击查看,而本方法这里简短讲述),都是将tensor切割为每块大小都为split_size_or_sections,最后块可能会小一些。如下图:
在这里插入图片描述
而当split_size_or_sections为列表时,方法根据列表中的元素的大小,将tensor分为len(split_size_or_sections)个块,并且每个块的大小等于split_size_or_sections中元素的大小。如下所示:
在这里插入图片描述
当然,此方法可以用在别的维度上,例如在1维上:
在这里插入图片描述
其实对于维度为1也很简单,就是将tensor按照列来切割。对于高维也是如此,就是在指定维度上将里面的元素分割开即可,切割后生成的tensor和原tensor维度相等。

  • 注意,此方法中,列表中元素的和,应该等于dim维度上的元素个数,例如这里是2+2=4。如果不等,会报错。

此方法生成的结果是视图,和原tensor共享内存。
如果想要修改返回的值,请使用clone函数,例如下图所示:
在这里插入图片描述
此时就不会产生修改时对原tensor也进行了修改的错误。

此方法与fTensor.split()一样,例如上图中的b = torch.split(a, (2, 2), dim=1)可以改成b = a.split((2, 2), dim=1)效果一样。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值