1、bert4keras安装:
pip install bert4keras -i https://pypi.tuna.tsinghua.edu.cn/simple/
2、bert预训练模型下载:
https://github.com/google-research/bert/
3、代码如下:
# coding=utf-8
"""Demo: use a pretrained Chinese BERT with an MLM head to fill in masked tokens.

Loads the google-research ``chinese_L-12_H-768_A-12`` checkpoint via
bert4keras, masks a two-character span in a sentence, and prints the
model's prediction for the masked positions.
"""
import os

import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer

# Paths into the downloaded pretrained checkpoint directory.
pretrained_path = '/path/to/chinese_L-12_H-768_A-12'
config_path = os.path.join(pretrained_path, 'bert_config.json')
checkpoint_path = os.path.join(pretrained_path, 'bert_model.ckpt')
vocab_path = os.path.join(pretrained_path, 'vocab.txt')

# The Tokenizer parses vocab.txt itself, so there is no need to build a
# token->id dict by hand; look the [MASK] id up once and reuse it.
tokenizer = Tokenizer(vocab_path, do_lower_case=True)
mask_id = tokenizer._token_dict['[MASK]']

# with_mlm=True attaches the masked-language-model head, so predict()
# returns per-position probability distributions over the vocabulary.
model = build_transformer_model(config_path, checkpoint_path, with_mlm=True)

# Use the MLM model to predict the masked-out span.
token_ids, segment_ids = tokenizer.encode(u'中华人民共和国简称中国')
# Mask out "人民" (positions 3-4; position 0 is [CLS]).
token_ids[3] = token_ids[4] = mask_id
probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
print(tokenizer.decode(probas[3:5].argmax(axis=1)))  # output: 人民

token_ids, segment_ids = tokenizer.encode(u'中国的首都是北京')
# Mask out "中国" (positions 1-2).
token_ids[1] = token_ids[2] = mask_id
probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
print(tokenizer.decode(probas[1:3].argmax(axis=1)))  # output: 中国