Reference: "Keras implementation of BiLSTM+CRF character-level sequence labeling", with modifications.
keras-contrib
The keras-contrib library is the official extension repository for the Python deep learning library Keras. It contains additional layers, activations, loss functions, optimizers, and so on that are not yet available in Keras itself, and all of these add-on modules can be used together with core Keras models and modules. For example, it provides a CRF layer.
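As a quick illustration, here is a minimal sketch of plugging the keras-contrib CRF layer onto a standard Keras Sequential model. It uses the same pattern as the full example further below; the vocabulary size, layer widths, and tag count here are placeholder values, not taken from the original post.

from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF

# Minimal sketch: any Keras model that outputs one feature vector per
# timestep can be topped with the keras-contrib CRF layer.
model = Sequential()
model.add(Embedding(1000, 64))                             # placeholder vocab size / dim
model.add(Bidirectional(LSTM(32, return_sequences=True)))  # per-timestep features
crf = CRF(5, sparse_target=True)                           # placeholder: 5 tag classes, integer targets
model.add(crf)
# The CRF layer supplies its own loss and accuracy.
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])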
Install it with pip:

pip install git+https://www.github.com/keras-team/keras-contrib.git

or from source:

git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python setup.py install

Full code:
from collections import Counter

import numpy as np
from keras.layers import Embedding, Bidirectional, LSTM
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences
from keras_contrib.layers import CRF


# Load the data: sentences are separated by blank lines,
# each line holds one character and its tag.
def load_data(path):
    with open(path, 'r', encoding='utf8') as f:
        text = f.read().strip()
    lines = text.split('\n\n')
    data = [[word.strip().split() for word in line.strip().split('\n')] for line in lines]
    return data


train = load_data('data/train_data.data')
test = load_data('data/test_data.data')

# Build the vocabulary from characters that appear at least twice;
# anything else is mapped to the out-of-vocabulary index 1.
word_counts = Counter(word[0].lower() for line in train for word in line)
vocab = [w for w, f in word_counts.items() if f >= 2]
tags = list(set(word[1] for line in train for word in line))
word2id = dict((w, i) for i, w in enumerate(vocab))
max_len = max(map(len, train))


# Preprocess the data into model inputs: pad character ids and tag ids
# to max_len and add a trailing axis for the sparse CRF target.
def process_data(data):
    x = [[word2id.get(word[0].lower(), 1) for word in line] for line in data]
    y = [[tags.index(word[1]) for word in line] for line in data]
    x = pad_sequences(x, max_len)
    y = pad_sequences(y, max_len, value=-1)
    y = np.expand_dims(y, 2)
    return x, y


train_x, train_y = process_data(train)
test_x, test_y = process_data(test)

# Build the model: Embedding -> BiLSTM -> CRF.
model = Sequential()
model.add(Embedding(len(vocab), 200))
model.add(Bidirectional(LSTM(100, return_sequences=True)))
crf = CRF(len(tags), sparse_target=True)
model.add(crf)
model.summary()
# Model: "sequential_1"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# embedding_1 (Embedding)      (None, None, 200)         851600
# _________________________________________________________________
# bidirectional_1 (Bidirection (None, None, 200)         240800
# _________________________________________________________________
# crf_1 (CRF)                  (None, None, 7)           1470
# =================================================================
# Total params: 1,093,870
# Trainable params: 1,093,870
# Non-trainable params: 0

# Train the model with the CRF layer's own loss and accuracy.
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])

# # for a quick test on a small subset of the data
# sample_count = 1000
# model.fit(train_x[:sample_count], train_y[:sample_count], batch_size=16, epochs=10,
#           validation_data=(test_x[:sample_count], test_y[:sample_count]))

# train on all data
model.fit(train_x, train_y, batch_size=16, epochs=10, validation_data=(test_x, test_y))

# Prediction
predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
x = [word2id.get(word, 1) for word in predict_text]
sent_len = len(x)
print(sent_len)
x = pad_sequences([x], max_len)   # pad_sequences pads at the front by default
y = model.predict(x)
y = y[0][-sent_len:]              # keep only the positions of the real sentence
y = [np.argmax(row) for row in y]
y = [tags[i] for i in y]

# Group the predicted BIO tags back into entity strings,
# separating entities of the same type with a space.
per, loc, org = '', '', ''
for s, t in zip(predict_text, y):
    if t in ('B-PER', 'I-PER'):
        per += ' ' + s if (t == 'B-PER') else s
    if t in ('B-ORG', 'I-ORG'):
        org += ' ' + s if (t == 'B-ORG') else s
    if t in ('B-LOC', 'I-LOC'):
        loc += ' ' + s if (t == 'B-LOC') else s
print(['person:' + per, 'location:' + loc, 'organization:' + org])
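The files data/train_data.data and data/test_data.data are not shown in the post. Judging from load_data, the expected format is one character and one BIO tag per line, with a blank line between sentences. The following sketch writes a made-up toy file in that format (file name and contents are hypothetical) and parses it with load_data to show the resulting structure.

# Hedged sketch of the input format assumed by load_data:
# "<character> <tag>" per line, sentences separated by a blank line.
toy = "周 B-PER\n恩 I-PER\n来 I-PER\n说 O\n\n北 B-LOC\n京 I-LOC\n"
with open('data/toy.data', 'w', encoding='utf8') as f:
    f.write(toy)
print(load_data('data/toy.data'))
# [[['周', 'B-PER'], ['恩', 'I-PER'], ['来', 'I-PER'], ['说', 'O']],
#  [['北', 'B-LOC'], ['京', 'I-LOC']]]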
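The entity-grouping loop at the end is hard-coded to PER, LOC, and ORG and concatenates entities of the same type into a single space-separated string. As an optional refinement, not part of the original post, a generic BIO decoder that returns (entity_text, entity_type) pairs for any tag set could look like the sketch below; it assumes standard 'B-X' / 'I-X' / 'O' tags like the ones used above.

def decode_bio(chars, tag_seq):
    # Group characters and their B-/I- tags into (entity_text, entity_type) pairs.
    entities, cur, cur_type = [], '', None
    for ch, tag in zip(chars, tag_seq):
        if tag.startswith('B-'):
            if cur:
                entities.append((cur, cur_type))
            cur, cur_type = ch, tag[2:]
        elif tag.startswith('I-') and cur and tag[2:] == cur_type:
            cur += ch
        else:
            if cur:
                entities.append((cur, cur_type))
            cur, cur_type = '', None
    if cur:
        entities.append((cur, cur_type))
    return entities

# Usage with the prediction above: decode_bio(predict_text, y)
# Deterministic toy example:
# decode_bio('周恩来在北京', ['B-PER', 'I-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC'])
# -> [('周恩来', 'PER'), ('北京', 'LOC')]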