Cloze-Style (Masked Token) Prediction with bert4keras

1. Install bert4keras:

pip install bert4keras  -i https://pypi.tuna.tsinghua.edu.cn/simple/

2. Download a pretrained BERT model (the code below uses the Chinese chinese_L-12_H-768_A-12 checkpoint):

https://github.com/google-research/bert/

3. Code:

# coding=utf-8
import os

import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer

# Paths to the downloaded pretrained model (adjust to your local directory)
pretrained_path = '/path/to/chinese_L-12_H-768_A-12'
config_path = os.path.join(pretrained_path, 'bert_config.json')
checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')
vocab_path = os.path.join(pretrained_path, 'vocab.txt')

# Build a token -> id mapping from the vocab file (used to look up [MASK])
token_dict = {}
with open(vocab_path, 'r', encoding='utf8') as f:
    for line in f:
        token = line.strip()
        token_dict[token] = len(token_dict)

tokenizer = Tokenizer(vocab_path, do_lower_case=True)
# with_mlm=True adds the masked-language-model head, so the model outputs
# a probability distribution over the vocabulary at every position
model = build_transformer_model(config_path, checkpoint_path, with_mlm=True)

# Use the MLM head to predict the masked-out tokens
token_ids, segment_ids = tokenizer.encode(u'中华人民共和国简称中国')
# Mask out "人民" (positions 3 and 4; position 0 is [CLS])
token_ids[3] = token_ids[4] = token_dict['[MASK]']
probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
print(tokenizer.decode(probas[3:5].argmax(axis=1)))
# output: 人民

token_ids, segment_ids = tokenizer.encode(u'中国的首都是北京')
# Mask out "中国" (positions 1 and 2)
token_ids[1] = token_ids[2] = token_dict['[MASK]']
probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
print(tokenizer.decode(probas[1:3].argmax(axis=1)))
# output: 中国
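The decoding step above works by taking, for each masked position, the argmax over the vocabulary axis of the MLM output (shape: sequence length × vocabulary size) and mapping the winning ids back to tokens. A minimal numpy sketch of just that step, using a hypothetical three-token vocabulary and made-up probabilities (not real model output):

```python
import numpy as np

# Toy id -> token mapping (hypothetical, for illustration only)
id_to_token = {0: '[MASK]', 1: '人', 2: '民'}

# Fake MLM output for two masked positions: shape (positions, vocab_size).
# Each row is a probability distribution over the toy vocabulary.
probas = np.array([
    [0.1, 0.8, 0.1],  # first masked position: '人' most likely
    [0.2, 0.1, 0.7],  # second masked position: '民' most likely
])

# Same decoding as in the script: argmax over the vocab axis per position
pred_ids = probas.argmax(axis=1)
pred = ''.join(id_to_token[i] for i in pred_ids)
print(pred)  # 人民
```

This is why the script slices `probas[3:5]` and `probas[1:3]`: the slice selects only the masked positions before the per-position argmax.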

