Tensor运算
在自己写模型时,需要熟练掌握Tensor的相关运算,尤其需要注意一些广播的细节,在运算之后要对维度有准确的把握。
基础运算
形状相同的张量运算
对于任意具有相同形状的张量,常⻅的标准算术运算符(+、-、*、/、**)都可以被升级为按元素运算。
我们可以在同⼀形状的任意两个张量上调⽤按元素操作。
1 2 3 4 5 6 7 8 9 10 11 12 13
| a = torch.Tensor([1, 2, 3]) b = torch.Tensor([2, 2, 3]) print(a * b) print(a + b) print(a - b) print(a / b) print(a ** b)
|
广播机制
如果张量维度不同,则会默认使用广播机制进行按元素的操作。
⾸先,通过适当复制元素来扩展⼀个或两个数组,以便在转换之后,两个张量具有相同的形状。其次,对⽣成的数组执⾏按元素操作。这里和numpy的广播机制相同。
numpy中数组或者高维数组之间的点乘(以及元素与元素的其他运算)不要求拥有相同的形状,通过广播机制可以完成正确运算,但是必须遵循一些规则:
1、两个输入数组能不能运算要看每个维度的长度,如果每个维度长度相同或者某个维度长度为1,则可以进行运算;
2、如果输入数组的维数不同,在维数小的数组的形状前面补充1
3、如果某个输入数组在某一维度的长度为1,则将该维度的第一个数与其他输入数组在该维度上依次进行运算
示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| tensor_1 = torch.arange(4).reshape((2, 2)) tensor_2 = torch.arange(2).reshape((1, 2)) tensor_3 = tensor_1 + tensor_2 print(tensor_3.shape)
tensor_1 = torch.arange(2).reshape((2, 1)) tensor_2 = torch.arange(4).reshape((1, 4)) tensor_3 = tensor_1 + tensor_2 print(tensor_3.shape)
tensor_1 = torch.arange(8).reshape((2, 4)) tensor_2 = torch.arange(16).reshape((2, 2, 4)) tensor_3 = tensor_1 + tensor_2 print(tensor_3.shape)
tensor_1 = torch.Tensor([5]) print(tensor_1.shape) tensor_2 = torch.arange(4).reshape((1, 4)) tensor_3 = tensor_1 + tensor_2 print(tensor_3.shape) print(tensor_3)
tensor_1 = torch.arange(6).reshape((2, 3)) tensor_2 = torch.arange(4).reshape((1, 4)) tensor_3 = tensor_1 + tensor_2 print(tensor_3.shape)
|
计算函数
相加torch.sum()
torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor
dim (int) – 缩减的维度
从下列示例可以发现:dim参数不为空时,缩减该参数的维度。
当传入的张量为2维时,dim=0则缩减0维,即按行相加
keepdim=True时,维度不会改变,如二维按行相加后不会缩减为1维。如最后一个示例所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| a = torch.arange(24).reshape((2, 3, 4)) torch.sum(a) b = torch.sum(a, dim=1) a = torch.arange(6).reshape((2, 3)) print(a)
b = torch.sum(a, dim=1) b = torch.sum(a, dim=0) print(b.shape)
b = torch.sum(a, dim=0, keepdim=True) print(b) print(b.shape)
|
相乘torch.prod()
torch.prod(input, dim, keepdim=False, *, dtype=None) → Tensor
参数与用法和相加完全一致
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| a = torch.arange(1,7).reshape((2, 3)) print(a)
b = torch.prod(a) print(b) print(b.shape) print(float(b))
b = torch.prod(a, dim=1) print(b)
b = torch.prod(a, dim=1, keepdim=True) print(b.shape)
|
引用
- Torch官方文档
- PyTorch精品教程与源码算法精讲