引言

在预训练模型大行其道的时代,fastText作为2016年表示学习的经典开源方法,在现有的任务中依旧能发挥相当重要的作用。主要得益于其预测速度快,且在cpu机器上就可以很好的运行。(在标准的多核CPU上, 能够训练10亿词级别语料库的词向量在10分钟之内,能够分类有着30万多类别的50多万句子在1分钟之内。) 在节约资源的同时,可以作为大模型的前置策略。

在工业界,面临大规模数据分类时,有时对预测速度的需求也很重要,例如在大规模数据的风控场景下,每天需要过检上亿文档,来判断文档是否违规(这里可能是色情、涉政、暴力等)。如果对每条数据都过大模型将消耗大量资源,这时引入fastText做粗分类就非常必要了。

模型简介

这里只针对fastText模型做一个简单介绍,主要介绍其模型架构,及n-gram的实现。

模型架构

fastText模型
从上图可以看出,fastText模型只有三层:输入层、隐含层、输出层,输入是多个单词的embedding(这里如果ngram参数大于1的话,还会有ngram的embedding作为输入),输出是分类的类别,隐含层是对多个词向量的叠加平均。输出层是一个层次Softmax层。注意这里是没有全连接层作为中间层的,且不对hidden层进行学习,这也是fastText计算速度快的原因之一。fastText的训练目的区别于基于预训练模型的fine-tuning,是训练输入层的embedding,故而训练结束后可以同时得到每个词的词向量。

ngram特征

fastText 模型中的 n-gram 是一种文本特征提取方法,n-gram 特征是通过在原始词汇基础上生成一组子词来实现的,这些子词是由原始词汇中的连续子序列构成的。主要实现函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
int32_t Dictionary::getLine(
std::istream& in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels) const {
std::vector<int32_t> word_hashes;
std::string token;
int32_t ntokens = 0;

reset(in);
words.clear();
labels.clear();
while (readWord(in, token)) {
uint32_t h = hash(token);
int32_t wid = getId(token, h);
entry_type type = wid < 0 ? getType(token) : getType(wid);

ntokens++;
if (type == entry_type::word) {
addSubwords(words, token, wid);
word_hashes.push_back(h);
} else if (type == entry_type::label && wid >= 0) {
labels.push_back(wid - nwords_);
}
if (token == EOS) {
break;
}
}
addWordNgrams(words, word_hashes, args_->wordNgrams);
return ntokens;
}


void Dictionary::addWordNgrams(
std::vector<int32_t>& line,
const std::vector<int32_t>& hashes,
int32_t n) const {
for (int32_t i = 0; i < hashes.size(); i++) {
uint64_t h = hashes[i];
for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
h = h * 116049371 + hashes[j];
pushHash(line, h % args_->bucket);
}
}
}
// 上面代码是将词与ngram进行hash的过程。把所有的n-gram都哈希到buckets个桶中,
// 哈希到同一个桶的所有n-gram共享一个embedding vector。
// 不过这种方法潜在的问题是存在哈希冲突,不同的n-gram可能会共享同一个embedding。
// 如果桶大小取的足够大,这种影响会很小。
// 训练时使用的ngram维度为:
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
dict_->nwords() + args_->bucket, args_->dim);
// 即原始单词+ngram之后的组合单词。

ngram示例:
我 爱 北京–>2gram–>我 爱 北京 我爱 爱北京
n-gram 特征可以捕捉到文本中的局部信息和上下文信息,因此可以提高模型的性能和泛化能力。fastText 模型中的 n-gram 特征可以选择使用单词级别或字符级别(上述代码是单词级别代码),具体选择哪种级别取决于应用场景和数据集的特征,将wordngram参数设置为大于2时,有时可以提高算子的精度。
这里可能的原因是ngram提高了模型的拟合能力。例如一个情感分类任务,在训练集中有:我现在非常开心,如果在测试集中出现的待预测句子为我现在非常开心,使用unigram时,由于使用单个词信息,fastText模型大概率会将句子预测为正向情感。而使用2-gram,输入token中有**非常开心,而预测语句中为非常不开心,**可能会导致分类正确。

训练优化

分层Softmax
fastText 使用哈夫曼树(Huffman Tree)对词表进行编码,使得出现频率高的词被编码为较短的二进制码,而出现频率低的词被编码为较长的二进制码。这样,词表中出现频率高的词就可以用较短的编码表示,从而减少计算复杂度。层次 softmax 的计算复杂度为 O(log V),可以大大降低计算时间。

实战应用

数据说明

使用百度飞桨学习赛数据: 中文新闻文本标题分类,完成预测后提交到比赛网址可以查看在测试集上的得分,直观验证模型的效果。

简介

THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档(2.19 GB),均为UTF-8纯文本格式。在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐

环境搭建

1
2
3
4
5
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ sudo pip install .
$ # or :
$ sudo python setup.py install

使用pip install有时候会安装不成功,推荐使用官网的安装方式。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#! python3
# -*- encoding: utf-8 -*-
###############################################################
# @File : train_fastText_main.py
# @Time : 2022/12/01 17:32:28
# @Author : heng
# @Version : 1.0
# @Contact : hengsblog@163.com
###############################################################
"""
@comment:使用fastText进行多分类
"""
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import random
import re
import fasttext
import jieba

class FastText(object):
"simple fastText model"
def __init__(self):
"init"
self.train_path = f"data/train.txt"
self.dev_path = f"data/dev.txt"
self.test_path = f"data/test.txt"
self.get_label()

def preprocess(sentence):
"""前期预处理"""
sentence = Sent.pre_sentence(sentence, remove_stopwords=False, replace_punc=True, replace_num=False, cut_words=False)
return sentence

def get_label(self):
"label"
labels = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
self.label_id = {i:idx for idx, i in enumerate(labels)}
self.id_label = {id:label for label, id in self.label_id.items()}

def load_to_list(self, file_in):
"load"
with open(file_in) as fin:
return [i.strip() for i in fin]

def process_2_fastText(self, file_in):
"""formate to fastText
__label__0 text
__label__1 text
"""
with open(f"{file_in[:-4]}_fast.txt", "w", encoding='utf-8') as fout:
for line in self.load_to_list(file_in):
text, label = line.strip().split("\t")
text = " ".join(jieba.lcut(text))
id = self.label_id[label]
fout.write(f"__label__{id} {text}\n")

def train(self):
"train"
# self.process_2_fastText(self.train_path)
model = fasttext.train_supervised(f"{self.train_path[:-4]}_fast.txt", wordNgrams=1, minCount=2)
model.save_model("model_fastText.bin")
return model
# model = fasttext.load_model("model_filename.bin")

def prediction(self, mode, model):
"""预测服务"""
def __predict(text):
label = model.predict(text)
return int(label[0][0].split("__")[-1])

if mode == "dev":
label_pre = []
content_label = self.load_to_list(self.dev_path)
content = [i.split("\t")[0] for i in content_label]
label_true = [self.label_id[i.split("\t")[1]] for i in content_label]
for text in content:
text = " ".join(jieba.lcut(text))
temp = __predict(text)
label_pre.append(temp)
print("accuracy_score", accuracy_score(label_true, label_pre))
else:
content = self.load_to_list(self.test_path)
with open("submit.txt", 'w', encoding='utf-8') as fout:
for text in content:
text = " ".join(jieba.lcut(text))
label = self.id_label[__predict(text)]
fout.write(f"{label}\n")

def main():
"""main"""
FastText_ = FastText()
model = FastText_.train()
# model = fasttext.load_model("model_fastText.bin")
FastText_.prediction("dev", model)
FastText_.prediction("test", model)

if __name__ == "__main__":
main()

参数调整:

对训练集调整参数后,在验证集上的效果如下表所示:
使用基础预处理后的结果:82.71
去掉停用词结果:81.27

提交分数

提交系统后得分为82.06

需要注意

  1. fastText输入时,未分词的结果很差。未分词调参结果如下:
    | wordNgrams | minCount | 准确率 |
    | — | — | — |
    | 1 | 2 | 0.203 |
    | 2 | 2 | 0.482 |
    | 2 | 3 | 0.470 |
    | 3 | 2 | 0.481 |
    | 4 | 1 | 0.519 |
    | 5 | 1 | 0.524 |

即:使用5gram都无法达到分词的效果,线上测试集只有0.2左右的准确率。

  1. 当验证集的准确率超过99%后,需要构造新的验证集。

参考