Writing Poetry with an LSTM

Reposted, with modifications, from: 神经网络学习小记录40——春节到了,用LSTM写古诗不? (Neural Network Study Notes 40: Spring Festival is here, why not write classical poetry with an LSTM?)

Data download: classical Chinese poetry corpus (古诗词语料)
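
The parsing in load_data below assumes each line of poetry.txt is of the form title:body (an ASCII colon after the title, ASCII commas between verses) and keeps only poems whose first verse has exactly 5 characters. A sample line in that assumed format (Li Bai's 静夜思):

静夜思:床前明月光,疑是地上霜。举头望明月,低头思故乡。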

Reference code (written against standalone Keras 2.x; fit_generator and ModelCheckpoint's period argument belong to that older API):

import collections
import math
import os
import time

import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.layers import Input, LSTM, Dropout, Dense
from keras.models import Model
from keras.models import load_model
from keras.optimizers import Adam

END_CHAR = '\n'     # appended to each poem as an end-of-poem marker
UNKNOWN_CHAR = ' '  # stands in for characters outside the vocabulary
unit_sentence = 6   # window size: 6 characters predict the 7th
UNITS = 256         # LSTM hidden units
batch_size = 64
epochs = 3
data_path = 'data1'

if not os.path.exists(data_path):
    os.mkdir(data_path)


# Load and preprocess the data
def load_data():
    sents = []
    with open(data_path + '/poetry.txt', 'r', encoding='utf8') as f:
        for line in f:
            parts = line.strip().split(':')
            if len(parts) < 2:  # skip malformed lines without a ':' separator
                continue
            line = parts[1]
            # keep only poems whose first verse has exactly 5 characters
            if len(line) > 5:
                ss = line.split(',')
                if len(ss[0]) == 5:
                    sents.append(line + END_CHAR)
    words = []
    for sent in sents:
        words += [w for w in sent]

    counter = collections.Counter(words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])

    words, _ = zip(*count_pairs)
    max_words = 3000
    words = words[:max_words] + (UNKNOWN_CHAR,)

    char2id_dict = {w: i for i, w in enumerate(words)}
    id2char_dict = {i: w for i, w in enumerate(words)}

    sents = sorted(sents, key=lambda sent: len(sent))

    unknown_char = char2id_dict.get(UNKNOWN_CHAR)

    char2id = lambda char: char2id_dict.get(char, unknown_char)

    # poems have different lengths, so this must be an object array
    sents_vector = np.array([list(map(char2id, sent)) for sent in sents], dtype=object)

    words_size = len(char2id_dict)
    print("vocab size=>", words_size)
    print("sents count=>", len(sents_vector))
    return sents_vector, char2id_dict, id2char_dict


def get_6to1(x_data, char2id_dict):
    # slide a window over one poem: unit_sentence characters in, the next one out
    inputs = []
    targets = []
    end_id = char2id_dict[END_CHAR]
    for i in range(len(x_data) - unit_sentence):
        x = x_data[i:i + unit_sentence]
        y = x_data[i + unit_sentence]
        # stop once the window or the target hits the end-of-poem id
        # (x_data holds ids, so the comparison must use ids, not raw characters)
        if end_id in x or y == end_id:
            break
        inputs.append(x)
        targets.append(y)
    return np.array(inputs), np.array(targets)
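
# Illustrative example (hypothetical ids, unit_sentence = 6): for
# x_data = [3, 1, 4, 1, 5, 9, 2, 6, end_id], get_6to1 produces
#   input [3, 1, 4, 1, 5, 9] -> target 2
#   input [1, 4, 1, 5, 9, 2] -> target 6
# and stops at the window that would include end_id.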


def get_batch(batch_size, x_data, char2id_dict):
    num_vocab = len(char2id_dict)
    num_samples = len(x_data)
    num_batch = math.ceil(num_samples / batch_size)
    print("batch count=>", num_batch)
    index = 0  # index of the next poem to consume
    for i in range(num_batch):
        one_hot_x = []
        one_hot_y = []
        for j in range(batch_size):
            if index >= num_samples:
                break
            inputs, targets = get_6to1(x_data[index], char2id_dict)
            index += 1
            for k in range(len(inputs)):
                one_hot_x.append(inputs[k])
                one_hot_y.append(targets[k])
        batch_size_new = len(one_hot_x)
        input_data = np.zeros((batch_size_new, unit_sentence, num_vocab))
        target_data = np.zeros((batch_size_new, num_vocab))

        # one-hot encode the windows and their targets
        for k, (x, y) in enumerate(zip(one_hot_x, one_hot_y)):
            for t, char_id in enumerate(x):
                input_data[k, t, char_id] = 1.
            target_data[k, y] = 1.
        yield input_data, target_data
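
# Each yielded batch has input shape (samples, unit_sentence, vocab) and
# target shape (samples, vocab); `samples` varies from batch to batch since
# each poem contributes a different number of windows.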


# Build the model: two stacked LSTMs over one-hot characters, softmax over the vocabulary
def build_model(num_vocab):
    inputs = Input(shape=(None, num_vocab))
    x = LSTM(UNITS, return_sequences=True)(inputs)
    x = Dropout(0.6)(x)
    x = LSTM(UNITS)(x)
    x = Dropout(0.6)(x)
    x = Dense(num_vocab, activation='softmax')(x)
    model = Model(inputs, x)
    # print(model.summary())
    return model


# Save the model
def save_model(model, path):
    model.save(path)


# Load the model and, optionally, checkpointed weights
def load_my_model(path, checkpoint_file_path=None):
    model = load_model(path)
    if checkpoint_file_path:
        model.load_weights(checkpoint_file_path)
    return model


# Train the model
def train(model: Model, sents_vector, char2id_dict, train_size_ratio=0.2):
    start_time = time.time()
    model.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy', metrics=['accuracy'])

    checkpoint_path = data_path + '/logs'
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path)
    checkpoint = ModelCheckpoint(checkpoint_path + '/loss{loss:.3f}-val_loss{val_loss:.3f}.h5', save_weights_only=True,
                                 period=1)

    num_train = int(len(sents_vector) * train_size_ratio)
    print("train size=>", num_train)

    for i in range(epochs):
        # a fresh finite generator each epoch; note that validation reuses the training split
        model.fit_generator(get_batch(batch_size, sents_vector[:num_train], char2id_dict),
                            steps_per_epoch=max(1, num_train // batch_size),
                            validation_data=get_batch(batch_size, sents_vector[:num_train], char2id_dict),
                            validation_steps=max(1, num_train // batch_size),
                            epochs=1,
                            initial_epoch=0,
                            callbacks=[checkpoint])

    stop_time = time.time()
    print('time=>', str(stop_time - start_time))


# Predict: seed with the first unit_sentence characters of a random poem, then generate
def predict(model: Model, sents_vector, id2char_dict):
    vocab = len(id2char_dict)
    index = np.random.randint(0, len(sents_vector))
    sent = sents_vector[index][:unit_sentence]

    # print the original poem for comparison
    ori_output = id2char(sents_vector[index], id2char_dict)
    print(ori_output)

    def _pred(text):
        temp = text[-unit_sentence:]
        x_pred = np.zeros((1, unit_sentence, vocab))
        for t, char_id in enumerate(temp):
            x_pred[0, t, char_id] = 1
        y_pred = model.predict(x_pred)[0]
        # renormalize in float64 so np.random.choice accepts the probabilities
        y_pred = np.asarray(y_pred, dtype='float64')
        y_pred /= y_pred.sum()
        choice_id = np.random.choice(range(len(y_pred)), 1, p=y_pred)
        # if the unknown char comes up, redraw uniformly until the pick is
        # neither punctuation nor the unknown char
        if id2char_dict[choice_id[0]] == ' ':
            while id2char_dict[choice_id[0]] in [',', '。', ' ']:
                choice_id = np.random.randint(0, len(id2char_dict), 1)
        return choice_id

    # a 5-character quatrain is 24 characters long, punctuation included
    for i in range(24 - unit_sentence):
        pred = _pred(sent)
        sent = np.append(sent, pred)

    output = id2char(sent, id2char_dict)
    print(output)


def id2char(sent, id2char_dict):
    # map a sequence of ids back to a string
    return ''.join(id2char_dict[i] for i in sent)


####################train####################
# sents_vector,char2id_dict,id2char_dict=load_data()
# num_vocab=len(char2id_dict)
# model = build_model(num_vocab)
# save_model(model, data_path + "/model.h5")
# train(model,sents_vector,char2id_dict,0.7)

####################predict####################
sents_vector, char2id_dict, id2char_dict = load_data()
num_vocab = len(char2id_dict)
# the checkpoint filename below comes from one particular training run; adjust to your own
model = load_my_model(data_path + "/model.h5", data_path + "/logs/loss3.615-val_loss3.074.h5")
predict(model, sents_vector, id2char_dict)
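
The _pred helper above draws the next character straight from the softmax distribution. A common refinement, not in the original code, is temperature sampling, which sharpens or flattens the distribution before drawing. A minimal sketch under that assumption; the helper name and the default temperature here are made up:

# Hypothetical helper: rescale the distribution with a temperature, then sample.
# temperature < 1 makes generation more conservative, > 1 more varied.
def sample_with_temperature(y_pred, temperature=0.8):
    y_pred = np.asarray(y_pred, dtype='float64')
    logits = np.log(y_pred + 1e-10) / temperature
    probs = np.exp(logits) / np.sum(np.exp(logits))
    return np.random.choice(len(probs), p=probs)

Swapping this in for the np.random.choice call inside _pred trades fluency for variety as the temperature rises.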



