Reposted (with modifications) from: Neural Network Learning Notes 40 — Spring Festival is here, how about writing classical Chinese poems with an LSTM?
Data download: classical Chinese poetry corpus
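Each line of poetry.txt is expected to look like "title:content", because load_data below splits on ':' to drop the title and on ',' to keep only poems whose phrases have five characters. A minimal sketch of that assumption (the sample poem and the half-width ':' / ',' punctuation are illustrative; adjust them to whatever your corpus actually uses):

# Hypothetical sample line in the "title:content" format that load_data expects.
sample = "静夜思:床前明月光,疑是地上霜。举头望明月,低头思故乡。"
content = sample.strip().split(':')[1]            # drop the title, keep the poem body
first_phrase = content.split(',')[0]              # "床前明月光"
print(len(content) > 5, len(first_phrase) == 5)   # True True -> this poem would be kept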
Reference code:
import collections
import math
import os
import time

import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.layers import Input, LSTM, Dropout, Dense
from keras.models import Model
from keras.models import load_model
from keras.optimizers import Adam

END_CHAR = '\n'
UNKNOWN_CHAR = ' '
unit_sentence = 6
UNITS = 256
batch_size = 64
epochs = 3
data_path = 'data1'
if not os.path.exists(data_path):
    os.mkdir(data_path)


# Load the data, build the char <-> id vocabularies, and vectorize every poem.
def load_data():
    sents = []
    with open(data_path + '/poetry.txt', 'r', encoding='utf8') as f:
        for line in f:
            # Each line is "title:content"; keep only poems whose phrases have 5 characters.
            line = line.strip().split(':')[1]
            if len(line) > 5:
                ss = line.split(',')
                if len(ss[0]) == 5:
                    sents.append(line + END_CHAR)
    words = []
    for sent in sents:
        words += [w for w in sent]
    # Keep the 3000 most frequent characters; everything else maps to UNKNOWN_CHAR.
    counter = collections.Counter(words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    words, _ = zip(*count_pairs)
    max_words = 3000
    words = words[:max_words] + (UNKNOWN_CHAR,)
    char2id_dict = {w: i for i, w in enumerate(words)}
    id2char_dict = {i: w for i, w in enumerate(words)}
    sents = sorted(sents, key=lambda sent: len(sent))
    unknown_char = char2id_dict.get(UNKNOWN_CHAR)
    char2id = lambda char: char2id_dict.get(char, unknown_char)
    # dtype=object because the poems have different lengths (ragged array).
    sents_vector = np.array([list(map(char2id, sent)) for sent in sents], dtype=object)
    words_size = len(char2id_dict)
    print("vocab size=>", words_size)
    print("sents count=>", len(sents_vector))
    return sents_vector, char2id_dict, id2char_dict


# Slide a 6-character window over one vectorized poem:
# every 6 consecutive character ids form an input, the 7th id is the target.
def get_6to1(x_data, char2id_dict):
    inputs = []
    targets = []
    end_id = char2id_dict[END_CHAR]
    for i in range(len(x_data) - 6):
        x = x_data[i:i + 6]
        y = x_data[i + 6]
        # Stop once the window or the target reaches the end-of-poem marker.
        if (end_id in x) or y == end_id:
            break
        else:
            inputs.append(x)
            targets.append(y)
    return np.array(inputs), np.array(targets)


# Generator that yields one-hot encoded batches of (input windows, target characters).
def get_batch(batch_size, x_data, char2id_dict):
    num_vocab = len(char2id_dict)
    num_samples = len(x_data)
    num_batch = math.ceil(num_samples / batch_size)
    print("batch count=>", num_batch)
    index = 0
    for i in range(num_batch):
        one_hot_x = []
        one_hot_y = []
        for j in range(batch_size):
            if index >= num_samples:
                break
            inputs, targets = get_6to1(x_data[index], char2id_dict)
            index += 1
            for k in range(len(inputs)):
                one_hot_x.append(inputs[k])
                one_hot_y.append(targets[k])
        batch_size_new = len(one_hot_x)
        input_data = np.zeros((batch_size_new, 6, num_vocab))
        target_data = np.zeros((batch_size_new, num_vocab))
        for k, (x, y) in enumerate(zip(one_hot_x, one_hot_y)):
            for t, char_id in enumerate(x):
                input_data[k, t, char_id] = 1.
            target_data[k, y] = 1
        yield input_data, target_data


# Build the model: two stacked LSTM layers with dropout, then a softmax over the vocabulary.
def build_model(num_vocab):
    inputs = Input(shape=(None, num_vocab))
    x = LSTM(UNITS, return_sequences=True)(inputs)
    x = Dropout(0.6)(x)
    x = LSTM(UNITS)(x)
    x = Dropout(0.6)(x)
    x = Dense(num_vocab, activation='softmax')(x)
    model = Model(inputs, x)
    # print(model.summary())
    return model


# Save the model
def save_model(model, path):
    model.save(path)


# Load the model and, optionally, checkpointed weights
def load_my_model(path, checkpoint_file_path=None):
    model = load_model(path)
    if checkpoint_file_path:
        model.load_weights(checkpoint_file_path)
    return model


# Train the model
def train(model: Model, sents_vector, char2id_dict, train_size_ratio=0.2):
    start_time = time.time()
    model.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    checkpoint_path = data_path + '/logs'
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path)
    checkpoint = ModelCheckpoint(checkpoint_path + '/loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
                                 save_weights_only=True, period=1)
    num_train = int(len(sents_vector) * train_size_ratio)
    print("train size=>", num_train)
    for i in range(epochs):
        # Note: validation reuses the training slice here, so val_loss is not a held-out estimate.
        model.fit_generator(get_batch(batch_size, sents_vector[:num_train], char2id_dict),
                            steps_per_epoch=max(1, num_train // batch_size),
                            validation_data=get_batch(batch_size, sents_vector[:num_train], char2id_dict),
                            validation_steps=max(1, num_train // batch_size),
                            epochs=1,
                            initial_epoch=0,
                            callbacks=[checkpoint])
    stop_time = time.time()
    print('time=>', str(stop_time - start_time))


# Prediction: seed with the first 6 characters of a random poem, then sample one character at a time.
def predict(model: Model, sents_vector, id2char_dict):
    vocab = len(id2char_dict)
    index = np.random.randint(0, len(sents_vector))
    sent = sents_vector[index][:unit_sentence]
    # Print the original poem
    ori_output = id2char(sents_vector[index], id2char_dict)
    print(ori_output)

    def _pred(text):
        # One-hot encode the last 6 characters and sample the next character id
        # from the softmax distribution.
        temp = text[-unit_sentence:]
        x_pred = np.zeros((1, unit_sentence, vocab))
        for t, char_id in enumerate(temp):
            x_pred[0, t, char_id] = 1
        y_pred = model.predict(x_pred)[0]
        choice_id = np.random.choice(range(len(y_pred)), 1, p=y_pred)
        # If the unknown marker was sampled, re-draw until the result is not punctuation or space.
        if id2char_dict[choice_id[0]] == ' ':
            while id2char_dict[choice_id[0]] in [',', '。', ' ']:
                choice_id = np.random.randint(0, len(id2char_dict), 1)
        return choice_id

    for i in range(24 - unit_sentence):
        pred = _pred(sent)
        sent = np.append(sent, pred)
    output = id2char(sent, id2char_dict)
    print(output)


# Convert a sequence of ids back into a string.
def id2char(sent, id2char_dict):
    output = ''
    for i in range(len(sent)):
        output = output + id2char_dict[sent[i]]
    return output


#################### train ####################
# sents_vector, char2id_dict, id2char_dict = load_data()
# num_vocab = len(char2id_dict)
# model = build_model(num_vocab)
# save_model(model, data_path + "/model.h5")
# train(model, sents_vector, char2id_dict, 0.7)

#################### predict ####################
sents_vector, char2id_dict, id2char_dict = load_data()
num_vocab = len(char2id_dict)
model = load_my_model(data_path + "/model.h5", data_path + "/logs/loss3.615-val_loss3.074.h5")
predict(model, sents_vector, id2char_dict)
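To make the data flow above concrete, here is a small standalone sketch (toy ids and a hypothetical vocabulary size, not taken from the corpus) of the sliding-window scheme: every 6 consecutive character ids predict the 7th, and each window is one-hot encoded into the (samples, 6, vocab) shape that build_model's Input layer consumes.

import numpy as np

vocab = 10                      # hypothetical vocabulary size, for illustration only
ids = [3, 1, 4, 1, 5, 9, 2, 6]  # toy character-id sequence (8 characters)

# Sliding window: 6 ids as input, the following id as the target (mirrors get_6to1).
inputs = [ids[i:i + 6] for i in range(len(ids) - 6)]   # [[3,1,4,1,5,9], [1,4,1,5,9,2]]
targets = [ids[i + 6] for i in range(len(ids) - 6)]    # [2, 6]

# One-hot encode to the shape the model expects: (samples, 6, vocab), mirrors get_batch.
x = np.zeros((len(inputs), 6, vocab))
y = np.zeros((len(targets), vocab))
for k, (window, target) in enumerate(zip(inputs, targets)):
    for t, char_id in enumerate(window):
        x[k, t, char_id] = 1.0
    y[k, target] = 1.0

print(x.shape, y.shape)  # (2, 6, 10) (2, 10)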