```python
import re
from collections import Counter, defaultdict

def get_ori_vocab(corpus_path):
    """Build the initial vocabulary, treating all non-Chinese symbols and
    Chinese punctuation as delimiters."""
    vocab_all = []
    with open(corpus_path, encoding="utf-8") as fin:
        for line in fin:
            # Everything outside the CJK Unified Ideographs range becomes a separator.
            line = re.sub("[^\u4e00-\u9fa5]", " ", line)
            # Chinese punctuation also becomes a separator.
            line = re.sub(r"[《》“”,。?!()\-=+:]", " ", line)
            # Split each word into single characters, marking the end of the word.
            vocab_all.extend(" ".join(list(word)) + "</w>" for word in line.split())
    word_freqs = Counter(vocab_all)
    splits = {word: word.split() for word in word_freqs}
    return word_freqs, splits
```
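For a quick sanity check of the two return values, here is a hypothetical call (the file name and its contents are made up for illustration):

```python
# Suppose corpus.txt contains the single line: 我爱自然语言 我爱学习
word_freqs, splits = get_ori_vocab("corpus.txt")
# word_freqs counts each space-joined word:
#   Counter({'我 爱 自 然 语 言</w>': 1, '我 爱 学 习</w>': 1})
# splits maps each word to its list of symbols:
#   {'我 爱 自 然 语 言</w>': ['我', '爱', '自', '然', '语', '言</w>'], ...}
```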
```python
def compute_pair_freqs(word_freqs, splits):
    """Count the frequency of every pair of adjacent symbols."""
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            pair_freqs[pair] += freq
    # Print a few entries as a sanity check.
    for i, key in enumerate(pair_freqs):
        print(f"{key}: {pair_freqs[key]}")
        if i >= 5:
            break
    return pair_freqs
```
```python
def find_max(pair_freqs):
    """Return the most frequent pair."""
    best_pair = None
    max_freq = None
    for pair, freq in pair_freqs.items():
        if max_freq is None or max_freq < freq:
            best_pair = pair
            max_freq = freq
    print(f"best pair is {best_pair} and max_freq is {max_freq}")
    return best_pair

pair_freqs = compute_pair_freqs(word_freqs, splits)
best_pair = find_max(pair_freqs)
```
Merge the highest-frequency character pair into a new subword and update the vocabulary:
```python
def merge_pair(best_pair, word_freqs, splits):
    """Merge the best pair into a single symbol in every word's split."""
    a, b = best_pair
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # Replace the two adjacent symbols with their concatenation.
                split = split[:i] + [a + b] + split[i + 2:]
            else:
                i += 1
        splits[word] = split
    return splits

splits = merge_pair(best_pair, word_freqs, splits)
```
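The `tokenize` function below replays a `merges` dict that the snippets above never build, since they only walk through a single merge step. A minimal training-loop sketch that repeats the three steps and records each merge might look like this (`num_merges` and the corpus file name are assumptions for illustration):

```python
# Sketch of the full training loop; num_merges and "corpus.txt" are assumptions.
num_merges = 100
merges = {}  # insertion-ordered (Python 3.7+), so tokenize can replay merges in order
word_freqs, splits = get_ori_vocab("corpus.txt")
for _ in range(num_merges):
    pair_freqs = compute_pair_freqs(word_freqs, splits)
    if not pair_freqs:  # every word already collapsed to a single symbol
        break
    best_pair = find_max(pair_freqs)
    splits = merge_pair(best_pair, word_freqs, splits)
    merges[best_pair] = best_pair[0] + best_pair[1]
```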
```python
def tokenize(text):
    """Tokenize new text by replaying the learned merges in order."""
    pre_tokenized_text = text.split()
    # Start from single characters; append "</w>" to the last character of each
    # word so the merges learned during training (which include "</w>") can apply.
    splits = [list(word[:-1]) + [word[-1] + "</w>"] for word in pre_tokenized_text]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merge] + split[i + 2:]
                else:
                    i += 1
            splits[idx] = split
    # Flatten the per-word splits into a single token list.
    return sum(splits, [])
```
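A hypothetical usage example, assuming the training loop above has populated `merges` (the input must be pre-split on whitespace, just like the training corpus after punctuation stripping):

```python
print(tokenize("我爱自然语言 我爱学习"))
# With enough merges the output might look like
# ['我爱', '自然', '语言</w>', '我爱', '学习</w>'];
# the exact tokens depend on the corpus and on num_merges.
```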