引言

在NLP任务中(后续拓展为多模态任务),顺序信息至关重要,例如: 我借给你300块与你借给我300块具有完全不同的含义。
对于Transformer模型来说,由于Attention模块的无序性(无法区分不同位置的Token),必须加入额外的信息来记录顺序,这里引入了位置编码。位置编码在后续基于Transformer架构的文章中有很多不同的实现方式,尤其是在大语言模型大行其道的现在,在面对长token的输入时,挑选合适的位置编码也会提升训练的效果。本文整理主流模型的位置编码实现方式,并用torch实现以加深理解。
位置编码从实现方式上大致可以分为2类:

  • 绝对位置编码: 将位置信息融入到输入中
  • 相对位置编码: 微调Attention结构,使其可以分辨不同位置的Token

绝对位置编码

在输入的第k个向量xkx_k中加入位置向量变为xk+pkx_{k}+p_{k},其中pkp_{k}只依赖于位置编号k。实现方式类似下图:
position.png

正弦曲线(sinusoidal)位置编码

这种位置编码是Transformer原始论文中实现的位置编码,实现公式如下:
论文中原始公式
其中,PE表示位置编码矩阵,pos表示token的位置,2i表示embedding中的偶数位置,2i+1表示奇数位置(这个最好结合下面的代码进行理解,在下面的代码实现中,exp_value实现了指数部分2i/dmdoel2i/d_{mdoel},变量out实现了三角函数内的值 100002i/dmodel10000^{2i/d_{model}}),d-model表示每个token向量化后的维度。
原论文中还提及作者实验了可学习的位置编码,并且这两种位置编码具有接近的结果,之所以选择正弦曲线位置编码是因为它可以允许模型序列外推到比训练期间遇到的序列更长的位置。(it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class SinPositionEncoding(nn.Module):
def __init__(self, max_sequence_length, d_model, base=10000):
super().__init__()
self.max_sequence_length = max_sequence_length
self.d_model = d_model
self.base = base

def forward(self):
pe = torch.zeros(self.max_sequence_length, self.d_model, dtype=torch.float) # size(max_sequence_length, d_model)
exp_1 = torch.arange(self.d_model // 2, dtype=torch.float) # 初始化一半维度,sin位置编码的维度被分为了两部分
exp_value = exp_1 / (self.d_model / 2)

alpha = 1 / (self.base ** exp_value) # size(dmodel/2)
out = torch.arange(self.max_sequence_length, dtype=torch.float)[:, None] @ alpha[None, :] # size(max_sequence_length, d_model/2)
embedding_sin = torch.sin(out)
embedding_cos = torch.cos(out)

pe[:, 0::2] = embedding_sin # 奇数位置设置为sin
pe[:, 1::2] = embedding_cos # 偶数位置设置为cos
return pe

SinPositionEncoding(d_model=4, max_sequence_length=10, base=10000).forward()

正弦位置编码不需要进行学习,是初始化时直接根据如上公式赋值的常量, 因此有一定的外推性。
又由于位置α+β\alpha+\beta的向量可以表示成位置α和位置β的向量组合,表明正弦编码可以表达相对位置信息。

可学习位置编码

这种位置编码是Bert、GPT、ViT等架构的实现方式,直接将位置编码当作可训练参数,让它随着训练过程更新。实现方式简单,交给模型进行自学习,大力出奇迹。

1
2
3
4
5
6
7
8
9
10
class TrainablePositionEncoding(nn.Module):
def __init__(self, max_sequence_length, d_model):
super().__init__()
self.max_sequence_length = max_sequence_length
self.d_model = d_model

def forward(self):
pe = nn.Embedding(self.max_sequence_length, self.d_model)
nn.init.constant(pe.weight, 0.)
return pe

相对位置编码

相对位置并没有完整建模每个输入的位置信息,而是根据Attention中K,V矩阵的偏移量产生不同的Embedding,计算Attention时考虑当前位置与被Attention位置的相对距离。相对位置编码几乎都是在Softmax之前的Attention矩阵上进行操作的

经典相对位置编码

相对位置编码起源于Google的论文《Self-Attention with Relative Position Representations》,华为开源的NEZHA模型也用到了这种位置编码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class RelativePosition(nn.Module):
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
nn.init.xavier_uniform_(self.embeddings_table)

def forward(self, length_q, length_k):
range_vec_q = torch.arange(length_q)
range_vec_k = torch.arange(length_k)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = torch.LongTensor(final_mat).cuda()
embeddings = self.embeddings_table[final_mat].cuda()

return embeddings

class RelativeMultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1, batch_size=6):
"Take in model size and number of heads."
super(RelativeMultiHeadAttention, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
self.batch_size = batch_size

assert d_model % n_heads == 0
self.head_dim = d_model // n_heads

self.linears = _get_clones(nn.Linear(d_model, d_model), 4)
self.dropout = nn.Dropout(p=dropout)
self.relative_position_k = RelativePosition(self.head_dim, max_relative_position=16)
self.relative_position_v = RelativePosition(self.head_dim, max_relative_position=16)

self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).cuda()

def forward(self, query, key, value):
# embedding
# query, key, value = [batch_size, len, hid_dim]
query, key, value = [l(x).view(self.batch_size, -1, self.d_model) for l, x in
zip(self.linears, (query, key, value))]

len_k = query.shape[1]
len_q = query.shape[1]
len_v = value.shape[1]

# Self-Attention
# r_q1, r_k1 = [batch_size, len, n_heads, head_dim]
r_q1 = query.view(self.batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
r_k1 = key.view(self.batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, self.batch_size * self.n_heads, self.head_dim)
r_k2 = self.relative_position_k(len_q, len_k)
attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
attn2 = attn2.contiguous().view(self.batch_size, self.n_heads, len_q, len_k)
attn = (attn1 + attn2) / self.scale

attn = self.dropout(torch.softmax(attn, dim=-1))
# attn = [batch_size, n_heads, len, len]
r_v1 = value.view(self.batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
weight1 = torch.matmul(attn, r_v1)
r_v2 = self.relative_position_v(len_q, len_v)
weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, self.batch_size * self.n_heads, len_k)
weight2 = torch.matmul(weight2, r_v2)
weight2 = weight2.transpose(0, 1).contiguous().view(self.batch_size, self.n_heads, len_q, self.head_dim)

x = weight1 + weight2
# x = [batch size, n heads, query len, head dim]

x = x.permute(0, 2, 1, 3).contiguous()
# x = [batch size, query len, n heads, head dim]

x = x.view(self.batch_size * len_q, self.d_model)
# x = [batch size * query len, hid dim]

return self.linears[-1](x)

旋转位置编码

在Llama及Llama2,QWen等模型中,使用了这种位置编码的方式。在论文RoFormer: Enhanced Transformer with Rotary Position Embedding中有详细的解释。
Rope是将绝对位置编码与相对位置编码进行结合,通过绝对位置编码的方式实现相对位置编码。
image.png
这部分内容可以参考文末的参考文章,写的非常详细。
Rope有不同的实现方式,这里是Llama源码中的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# 生成旋转矩阵
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
# 计算词向量元素两两分组之后,每组元素对应的旋转角度\theta_i
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成 token 序列索引 t = [0, 1,..., seq_len-1]
t = torch.arange(seq_len, device=freqs.device)
# freqs.shape = [seq_len, dim // 2]
freqs = torch.outer(t, freqs).float() # 计算m * \theta

# 计算结果是个复数向量
# 假设 freqs = [x, y]
# 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# xq.shape = [batch_size, seq_len, dim]
# xq_.shape = [batch_size, seq_len, dim // 2, 2]
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

# 转为复数域
xq_ = torch.view_as_complex(xq_)
xk_ = torch.view_as_complex(xk_)

# 应用旋转操作,然后将结果转回实数域
# xq_out.shape = [batch_size, seq_len, dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()

self.wq = Linear(...)
self.wk = Linear(...)
self.wv = Linear(...)

self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

def forward(self, x: torch.Tensor):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(batch_size, seq_len, dim)
xk = xk.view(batch_size, seq_len, dim)
xv = xv.view(batch_size, seq_len, dim)

# attention 操作之前,应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# scores.shape = (bs, seqlen, seqlen)
scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, xv) # (batch_size, seq_len, dim)

参考

  1. 科学空间-位置编码
  2. github
  3. T5模型中的位置编码
  4. 一步一步,推导旋转位置编码 (Rotary Position Embedding, RoPE)
  5. 一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)
  6. 十分钟读懂旋转编码(RoPE)