FastText天池新闻文本分类

fasttext是facebook开源的一个词向量与文本分类工具

FastText使用

# coding=utf-8
import os

import fasttext
import pandas as pd
from pandas import DataFrame
from sklearn.metrics import f1_score
from sklearn.utils import shuffle

root_dir = 'D:\\data\\aliyun\\news'

# 加载带标签的数据
labeled_data = pd.read_csv(os.path.join(root_dir, 'train_set.csv'), sep='\t')
print(labeled_data.columns.values)
# ['label' 'text']
print(labeled_data.head(1))
# 2  2967 6758 339 2021 1854...

# 将数据打乱
labeled_data: DataFrame = shuffle(labeled_data, random_state=0)

# 将label添加前缀满足fasttext格式要求
labeled_data['label'] = labeled_data['label'].apply(lambda x: '__label__' + str(x))
print(labeled_data.head(1))
# __label__0  3137 4464 5602 6644...

# 将带标签的数据分成训练数据和测试数据,训练数据占70%
train_size = int(labeled_data.shape[0] * 0.7)
train_data = labeled_data[:train_size]
test_data = labeled_data[train_size:]

print(train_data.shape, test_data.shape)
# (140000, 2) (60000, 2)

labeled_data.to_csv('train.csv', index=False, header=False, sep='\t')

model = fasttext.train_supervised('train.csv', lr=1.0, wordNgrams=2,
                                  verbose=2, minCount=1, epoch=25, loss="hs")

# 保存模型
model.save_model('model.bin')

# 加载模型
model = fasttext.load_model('model.bin')

pred = model.predict(test_data['text'].values.tolist())

pred = [x[0] for x in pred[0]]

print(f1_score(test_data['label'], pred, average='macro'))
# 0.9979669090846849





个人资料
时海
等级:8
文章:272篇
访问:16.0w
排名: 2
推荐圈子
上一篇: 科技周报20200907
下一篇:FastText 构建四大名著词向量
猜你感兴趣的圈子:
AI比赛交流圈
标签: labeled、train、fasttext、data、model、面试题
隐藏