Keras实现BiLSTM+CRF字符级序列标注

参考:Keras实现BiLSTM+CRF字符级序列标注, 有改动

keras-contrib
keras-contrib库是python深度学习库Keras的官方扩展库。它包含额外的层、激活、丢失函数、优化器等,这些在Keras本身中还不可用。所有这些附加模块都可以与核心Keras模型和模块一起使用。比如实现了CRF层。

pip install git+https://www.github.com/keras-team/keras-contrib.git
或者
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python setup.py install
完整代码:

from collections import Counter

import numpy as np
from keras.layers import Embedding, Bidirectional, LSTM
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences
from keras_contrib.layers import CRF


# 加载数据
def load_data(path):
    with open(path, 'r', encoding='utf8') as f:
        text = f.read().strip()
        lines = text.split('\n\n')
        data = [[word.strip().split() for word in line.strip().split('\n')] for line in lines]
        return data


train = load_data('data/train_data.data')
test = load_data('data/test_data.data')

word_counts = Counter(word[0].lower() for line in train for word in line)
vocab = [w for w, f in word_counts.items() if f >= 2]
tags = list(set(word[1] for line in train for word in line))

word2id = dict((w, i) for i, w in enumerate(vocab))

max_len = max(map(len, train))


# 数据预处理成模型输入格式
def process_data(data):
    x = [[word2id.get(word[0].lower(), 1) for word in line] for line in data]

    y = [[tags.index(word[1]) for word in line] for line in data]

    x = pad_sequences(x, max_len)

    y = pad_sequences(y, max_len, value=-1)

    y = np.expand_dims(y, 2)

    return x, y


train_x, train_y = process_data(train)

test_x, test_y = process_data(test)

# build model
model = Sequential()
model.add(Embedding(len(vocab), 200))
model.add(Bidirectional(LSTM(100, return_sequences=True)))
crf = CRF(len(tags), sparse_target=True)
model.add(crf)
model.summary()
# Model: "sequential_1"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# embedding_1 (Embedding)      (None, None, 200)         851600
# _________________________________________________________________
# bidirectional_1 (Bidirection (None, None, 200)         240800
# _________________________________________________________________
# crf_1 (CRF)                  (None, None, 7)           1470
# =================================================================
# Total params: 1,093,870
# Trainable params: 1,093,870
# Non-trainable params: 0


# 模型训练
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])

# # for small data test
# sample_count = 1000
# model.fit(train_x[:sample_count], train_y[:sample_count], batch_size=16, epochs=10,
#           validation_data=[test_x[:sample_count], test_y[:sample_count]])

# train all data
model.fit(train_x, train_y, batch_size=16, epochs=10, validation_data=[test_x, test_y])

# 预测
predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
x = [word2id.get(word, 1) for word in predict_text]
sent_len = len(x)

print(sent_len)

x = pad_sequences([x], max_len)

y = model.predict(x)
y = y[-sent_len:]
y = [np.argmax(row) for row in y[0]]
y = [tags[i] for i in y]

per, loc, org = '', '', ''
for s, t in zip(predict_text, tags):
    if t in ('B-PER', 'I-PER'):
        per += ' ' + s if (t == 'B-PER') else s
    if t in ('B-ORG', 'I-ORG'):
        org += ' ' + s if (t == 'B-ORG') else s
    if t in ('B-LOC', 'I-LOC'):
        loc += ' ' + s if (t == 'B-LOC') else s

print(['person:' + per, 'location:' + loc, 'organzation:' + org])




个人资料
回味往事
等级:4
文章:2篇
访问:924
排名: 36
推荐圈子
上一篇: 使用keras对鸢尾花进行多分类
猜你感兴趣的圈子:
NLP
标签: keras、crf、train、word、len、面试题
隐藏