Tensor运算

在自己写模型时,需要熟练掌握Tensor的相关运算,尤其需要注意一些广播的细节,在运算之后要对维度有准确的把握。

基础运算

形状相同的张量运算

对于任意具有相同形状的张量,常⻅的标准算术运算符(+、-、*、/、**)都可以被升级为按元素运算。

我们可以在同⼀形状的任意两个张量上调⽤按元素操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([2, 2, 3])
print(a * b)
print(a + b)
print(a - b)
print(a / b)
print(a ** b)

# tensor([2., 4., 9.])
# tensor([3., 4., 6.])
# tensor([-1., 0., 0.])
# tensor([0.5000, 1.0000, 1.0000])
# tensor([ 1., 4., 27.])

广播机制

如果张量维度不同,则会默认使用广播机制进行按元素的操作。

⾸先,通过适当复制元素来扩展⼀个或两个数组,以便在转换之后,两个张量具有相同的形状。其次,对⽣成的数组执⾏按元素操作。这里和numpy的广播机制相同。

numpy中数组或者高维数组之间的点乘(以及元素与元素的其他运算)不要求拥有相同的形状,通过广播机制可以完成正确运算,但是必须遵循一些规则:

1、两个输入数组能不能运算要看每个维度的长度,如果每个维度长度相同或者某个维度长度为1,则可以进行运算;

2、如果输入数组的维数不同,在维数小的数组的形状前面补充1

3、如果某个输入数组在某一维度的长度为1,则将该维度的第一个数与其他输入数组在该维度上依次进行运算

示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
tensor_1 = torch.arange(4).reshape((2, 2))
tensor_2 = torch.arange(2).reshape((1, 2))
tensor_3 = tensor_1 + tensor_2
print(tensor_3.shape) # 可以进行计算,结果size=(2, 2)

tensor_1 = torch.arange(2).reshape((2, 1))
tensor_2 = torch.arange(4).reshape((1, 4))
tensor_3 = tensor_1 + tensor_2
print(tensor_3.shape) # 可以进行计算,结果size=(2, 4)

tensor_1 = torch.arange(8).reshape((2, 4))
tensor_2 = torch.arange(16).reshape((2, 2, 4))
tensor_3 = tensor_1 + tensor_2
print(tensor_3.shape) # 可进行计算,结果size=(2,2,4)

tensor_1 = torch.Tensor([5])
print(tensor_1.shape)
tensor_2 = torch.arange(4).reshape((1, 4))
tensor_3 = tensor_1 + tensor_2
print(tensor_3.shape) # 可进行计算,结果size=(1,4)
print(tensor_3) # tensor([[5., 6., 7., 8.]])

tensor_1 = torch.arange(6).reshape((2, 3))
tensor_2 = torch.arange(4).reshape((1, 4))
tensor_3 = tensor_1 + tensor_2
print(tensor_3.shape) # 不可进行计算

计算函数

相加torch.sum()

torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor

dim (int) – 缩减的维度

从下列示例可以发现:dim参数不为空时,缩减该参数的维度。

当传入的张量为2维时,dim=0则缩减0维,即按行相加

keepdim=True时,维度不会改变,如二维按行相加后不会缩减为1维。如最后一个示例所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
a = torch.arange(24).reshape((2, 3, 4))
torch.sum(a) # tensor(15.)
b = torch.sum(a, dim=1) # tensor([ 3, 12])
a = torch.arange(6).reshape((2, 3))
print(a)
#tensor([[0, 1, 2],
# [3, 4, 5]])
b = torch.sum(a, dim=1) # tensor([ 3, 12])
b = torch.sum(a, dim=0) # tensor([3, 5, 7])
print(b.shape) # torch.Size([3])

b = torch.sum(a, dim=0, keepdim=True)
print(b) # tensor([[3, 5, 7]])
print(b.shape) # torch.Size([1, 3])

相乘torch.prod()

torch.prod(input, dim, keepdim=False, *, dtype=None) → Tensor

参数与用法和相加完全一致

1
2
3
4
5
6
7
8
9
10
11
12
13
14
a = torch.arange(1,7).reshape((2, 3))
print(a)
# tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.prod(a)
print(b) # tensor(720)
print(b.shape) # torch.Size([])
print(float(b)) # 将b转为标量 720.0

b = torch.prod(a, dim=1)
print(b) # tensor([ 6, 120])

b = torch.prod(a, dim=1, keepdim=True) # tensor([ 3, 12])
print(b.shape) # torch.Size([2, 1])

引用

  1. Torch官方文档
  2. PyTorch精品教程与源码算法精讲