第40个方法
torch.split(tensor, split_size_or_sections, dim=0)
此方法和前面的torch.chunk()一样,都是将tensor进行切割的方法,但是这两有什么区别呢?其实我们split方法功能更多样一些。
看一下此方法的参数:
tensor
:要被切割的tensor。split_size_or_sections
:当此参数为整数时,意思是将tensor按照每块大小为split_size_or_sections来切割,当此参数为列表时,将此tensor切成和列表中元素大小一样的大小的块。dim
:指定要切分的维度。
这里主要就是split_size_or_sections参数的使用,当split_size_or_sections为整数时,此方法和torch.chunk()方法一样(此方法这里讲的比较详细,可以点击查看,而本方法这里简短讲述),都是将tensor切割为每块大小都为split_size_or_sections,最后块可能会小一些。如下图:
而当split_size_or_sections为列表时,方法根据列表中的元素的大小,将tensor分为len(split_size_or_sections)个块,并且每个块的大小等于split_size_or_sections中元素的大小。如下所示:
当然,此方法可以用在别的维度上,例如在1维上:
其实对于维度为1也很简单,就是将tensor按照列来切割。对于高维也是如此,就是在指定维度上将里面的元素分割开即可,切割后生成的tensor和原tensor维度相等。
- 注意,此方法中,列表中元素的和,应该等于dim维度上的元素个数,例如这里是2+2=4。如果不等,会报错。
此方法生成的结果是视图,和原tensor共享内存。
如果想要修改返回的值,请使用clone函数,例如下图所示:
此时就不会产生修改时对原tensor也进行了修改的错误。
此方法与fTensor.split()
一样,例如上图中的b = torch.split(a, (2, 2), dim=1)
可以改成b = a.split((2, 2), dim=1)
效果一样。