Text classification of the Four Great Classical Novels (四大名著) with Keras

This post trains word vectors with word2vec and uses them for text classification, reaching an accuracy above 90%.

Data download: training corpus of the Four Great Classical Novels

Stopword download: common Chinese stopwords

Training accuracy and loss over the run:
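The full script follows. As the parsing loop shows, each line of train.txt is expected to start with an integer class label followed by the space-separated, pre-segmented words of one passage, schematically: <label> <word1> <word2> ... (the concrete tokens depend on the corpus).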

import collections
import time

import matplotlib.pyplot as plt
import numpy as np
from gensim.models import word2vec
from keras.callbacks import EarlyStopping
from keras.layers import Dense, Dropout
from keras.models import Sequential
from keras.utils import np_utils  # Keras 2.x; newer versions expose to_categorical under keras.utils directly


def loadStopwords():
    # One stopword per line
    with open("stopwords.txt", 'r', encoding='utf-8') as fr:
        stopwords = fr.read().splitlines()
    return stopwords


start = time.time()

# Load stopwords
stopwords = set(loadStopwords())

with open('train.txt', 'r', encoding='utf8') as fr:
    data = fr.read().splitlines()

tags = []
sentences = []
all_words = []
for item in data:
    ss = item.split(' ')
    tags.append(int(ss[0]))  # first token is the class label
    valid_words = [x for x in ss[1:] if x not in stopwords]
    sentences.append(valid_words)
    all_words += valid_words

# Keep the top 80% most frequent words
# (note: this vocabulary is only computed and reported here; the sentences
# are never filtered by it -- the pruning is done by word2vec's min_count)
print('total words (with duplicates) =>', len(all_words))
# total words (with duplicates) => 896717

counter = collections.Counter(all_words)
counter_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*counter_pairs)
print('unique words =>', len(words))
# unique words => 110967

reserved_word_count = int(len(words) * 0.8)
words = words[:reserved_word_count]
print('reserved word count =>', len(words))
# reserved word count => 88773
stop = time.time()
print("parse time: ", str(round(stop - start, 4)), "s")
# parse time:  0.6103 s

# Train word vectors
start = stop
# gensim < 4.0 API: in gensim >= 4.0 the `size` argument was renamed `vector_size`
w2v_model = word2vec.Word2Vec(sentences, min_count=2, size=300)
stop = time.time()
print("finished training word2vec, time=>", str(stop - start), 's')
# finished training word2vec, time=> 14.128031492233276 s
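Before building sentence features it can be worth sanity-checking the embedding space. A minimal sketch using gensim's most_similar; the query word here is only a placeholder, any word in the trained vocabulary works:

# Optional sanity check: nearest neighbours in the embedding space
# ('曹操' is just an example query from this corpus's domain)
print(w2v_model.wv.most_similar('曹操', topn=5))

# Optionally persist the embeddings for later reuse:
# w2v_model.save('w2v.model')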

# Sentence vector = the sum of the word vectors of the words it contains
start = stop
features = []
hit_count = 0
missing_count = 0
for sent in sentences:
    feature = np.zeros(300)
    for word in sent:
        try:
            # words below min_count were dropped from the word2vec vocabulary
            feature += w2v_model.wv[word]
            hit_count += 1  # count only successful lookups
        except KeyError:
            missing_count += 1
    features.append(feature.tolist())

tags_count = len(set(tags))
tags = [x - 1 for x in tags]  # labels in the file are 1-based; shift to 0-based
tags = np_utils.to_categorical(tags, tags_count)  # one-hot encode for softmax
features = np.array(features)

print("len:", len(features), len(tags))
print('hit count=>', hit_count, ' missing count=>', missing_count)
# missing count=> 63319 (out of 896717 lookups, so hit count=> 833398);
# the misses are words that fell below word2vec's min_count=2 threshold

stop = time.time()
print('time=>', str(stop - start))
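Summing word vectors means longer passages produce larger-magnitude features. Mean-pooling is a common alternative that removes this length dependence; a minimal sketch, not part of the pipeline above:

# Alternative sentence feature: mean instead of sum, so sentence length
# does not scale the feature magnitude (guarded against empty sentences)
def sentence_vector(sent):
    vecs = [w2v_model.wv[w] for w in sent if w in w2v_model.wv]
    return np.mean(vecs, axis=0) if vecs else np.zeros(300)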

# Build the classifier: a single hidden layer with dropout and a softmax output
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(300,)))
model.add(Dropout(0.5))
model.add(Dense(tags_count, activation='softmax'))

model.summary()  # summary() prints directly and returns None
# Model: "sequential_1"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# dense_1 (Dense)              (None, 512)               154112
# _________________________________________________________________
# dropout_1 (Dropout)          (None, 512)               0
# _________________________________________________________________
# dense_2 (Dense)              (None, 3)                 1539
# =================================================================
# Total params: 155,651
# Trainable params: 155,651
# Non-trainable params: 0
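The parameter counts check out: the hidden layer has 300 × 512 weights plus 512 biases = 154,112, and the output layer has 512 × 3 weights plus 3 biases = 1,539, for 155,651 in total. The 3-unit output layer also shows that tags_count was 3 on this run, i.e. the training file contained three classes.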


model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

history = model.fit(features, tags, batch_size=16, epochs=100, validation_split=0.25,
                    # stop once val_accuracy has not improved by at least 0.0005
                    # for 5 consecutive epochs
                    callbacks=[EarlyStopping(monitor='val_accuracy', patience=5, min_delta=0.0005)]
                    )

print(history.params)

plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['accuracy', 'val_accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.subplot(2, 1, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.legend(['loss', 'val_loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

print(history)
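To classify a new raw passage, the training-time preprocessing has to be replayed: segment, drop stopwords, sum the word vectors, and take the argmax of the softmax output. A minimal sketch; jieba is an assumption here, since the downloadable corpus is already pre-segmented:

import jieba  # assumed segmenter; any tokenizer matching the corpus works

def classify(text):
    # replay the training-time preprocessing on raw text
    words = [w for w in jieba.lcut(text) if w not in stopwords]
    vec = np.zeros(300)
    for w in words:
        if w in w2v_model.wv:
            vec += w2v_model.wv[w]
    probs = model.predict(np.array([vec]))[0]
    return int(np.argmax(probs)) + 1  # back to the file's 1-based labels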
