def get_label(self):
    """Build the label <-> integer-id lookup tables.

    Sets ``self.label_id`` (label name -> id) and ``self.id_label``
    (id -> label name). An id is the label's position in the fixed list
    below, so the list order must stay the same once a model has been
    trained on files produced by ``process_2_fastText``.
    """
    labels = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
              "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
    self.label_id = {label: idx for idx, label in enumerate(labels)}
    self.id_label = {idx: label for label, idx in self.label_id.items()}


def load_to_list(self, file_in, encoding="utf-8"):
    """Return the lines of ``file_in`` with surrounding whitespace stripped.

    Args:
        file_in: Path of the text file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 (the same
            encoding ``process_2_fastText`` writes with) instead of relying
            on the platform's locale-dependent default.

    Returns:
        One stripped string per line. Blank lines are kept as empty strings,
        so positions still line up with the file (callers that emit one
        output line per input line rely on that).
    """
    with open(file_in, encoding=encoding) as fin:
        return [line.strip() for line in fin]


def process_2_fastText(self, file_in):
    """Convert a ``text<TAB>label`` file into fastText's supervised format.

    Each non-blank input line becomes one output line::

        __label__<id> <token> <token> ...

    where ``<id>`` is looked up in ``self.label_id`` (built by
    ``get_label``) and the tokens come from ``jieba.lcut``. The output is
    written to ``<file_in without its last 4 characters>_fast.txt``
    (e.g. ``train.txt`` -> ``train_fast.txt``).

    Raises:
        KeyError: if a line's label is not in ``self.label_id``.
        ValueError: if a non-blank line contains no tab separator.
    """
    with open(f"{file_in[:-4]}_fast.txt", "w", encoding='utf-8') as fout:
        for line in self.load_to_list(file_in):
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no sample;
                # skip them instead of failing to unpack.
                continue
            # The label is the last tab-separated field; rsplit keeps a tab
            # inside the text from breaking the unpacking.
            text, label = line.rsplit("\t", 1)
            tokens = " ".join(jieba.lcut(text))
            label_idx = self.label_id[label]
            # fastText requires whitespace between the label and the text;
            # without it the first token is fused into the label name.
            fout.write(f"__label__{label_idx} {tokens}\n")
def train(self, word_ngrams=1, min_count=2, model_path="model_fastText.bin"):
    """Train a supervised fastText classifier on the training set.

    Trains on ``<self.train_path without its last 4 characters>_fast.txt``,
    the file ``process_2_fastText`` produces. If that file does not exist
    yet, it is generated first. An existing file is reused as-is; delete it
    to force regeneration after the raw training data changes. Generation
    needs ``self.label_id`` (built by ``get_label``); presumably the
    constructor calls it — confirm against ``__init__``.

    Args:
        word_ngrams: Max length of word n-grams (fastText ``wordNgrams``).
        min_count: Minimal number of word occurrences (fastText ``minCount``).
        model_path: File the trained model is saved to.

    Returns:
        The trained fastText model.
    """
    import os

    fast_path = f"{self.train_path[:-4]}_fast.txt"
    if not os.path.exists(fast_path):
        # Preprocessing (jieba segmentation) is slow, so it only runs when
        # the formatted file is missing.
        self.process_2_fastText(self.train_path)
    model = fasttext.train_supervised(
        fast_path, wordNgrams=word_ngrams, minCount=min_count
    )
    model.save_model(model_path)
    return model
if mode == "dev": label_pre = [] content_label = self.load_to_list(self.dev_path) content = [i.split("\t")[0] for i in content_label] label_true = [self.label_id[i.split("\t")[1]] for i in content_label] for text in content: text = " ".join(jieba.lcut(text)) temp = __predict(text) label_pre.append(temp) print("accuracy_score", accuracy_score(label_true, label_pre)) else: content = self.load_to_list(self.test_path) with open("submit.txt", 'w', encoding='utf-8') as fout: for text in content: text = " ".join(jieba.lcut(text)) label = self.id_label[__predict(text)] fout.write(f"{label}\n") defmain(): """main""" FastText_ = FastText() model = FastText_.train() # model = fasttext.load_model("model_fastText.bin") FastText_.prediction("dev", model) FastText_.prediction("test", model) if __name__ == "__main__": main()