Model training (the real-estate part of the project is used as the reference example).
Note: the Keras and TensorFlow versions need to be aligned, i.e. mutually compatible.
Reference: https://docs.floydhub.com/guides/environments/
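Before running anything below, it helps to confirm at runtime which versions are actually installed and compare them with a known-compatible pairing from the page above. A minimal check, assuming a standalone Keras install next to TensorFlow plus keras-contrib (which, judging from the CRF API used below, is where the CRF layer comes from):

import tensorflow as tf
import keras
import keras_contrib  # supplies the CRF layer used below; must be built against a matching Keras

# Print the installed versions so they can be checked against a compatible pairing.
print('TensorFlow:', tf.__version__)
print('Keras:', keras.__version__)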
Model definition
# Imports assumed for this module; the CRF layer and its loss/metrics come from keras-contrib.
import keras
from keras.layers import Input, Embedding, LSTM, Bidirectional
from keras.models import Model
from keras_contrib.layers import CRF


def build_lstm_crf_model(num_cates, seq_len, vocab_size, model_opts=dict()):
    opts = {
        'emb_size': 256,
        'emb_trainable': True,
        'emb_matrix': None,
        'lstm_units': 256,
        'optimizer': keras.optimizers.Adam()
    }
    opts.update(model_opts)

    input_seq = Input(shape=(seq_len,), dtype='int32')
    if opts.get('emb_matrix') is not None:
        # Initialize the embedding layer from a pre-trained matrix (e.g. Word2Vec).
        embedding = Embedding(vocab_size, opts['emb_size'],
                              weights=[opts['emb_matrix']],
                              trainable=opts['emb_trainable'])
    else:
        embedding = Embedding(vocab_size, opts['emb_size'])
    x = embedding(input_seq)

    lstm = LSTM(opts['lstm_units'], return_sequences=True)
    x = Bidirectional(lstm)(x)

    # CRF decoding layer on top of the BiLSTM outputs; sparse_target expects integer labels.
    crf = CRF(num_cates, sparse_target=True)
    output = crf(x)

    model = Model(input_seq, output)
    model.summary()
    model.compile(opts['optimizer'], loss=crf.loss_function, metrics=[crf.accuracy, 'acc'])
    return model
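A quick sanity check of the builder (not part of the original pipeline): call it with small dummy dimensions and push a random batch through it. The numbers below are made up; 84 only mirrors the seq_len = 64 + 2 × 10 used in the training script further down.

import numpy as np

# Toy smoke test: 5 tag classes, sequence length 84, vocabulary of 3000 ids.
toy_model = build_lstm_crf_model(num_cates=5, seq_len=84, vocab_size=3000)
dummy_X = np.random.randint(0, 3000, size=(2, 84))
print(toy_model.predict(dummy_X).shape)  # expected: (2, 84, 5), one score vector per token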
# Imports assumed for this relation-classification model; w2v_embeddings, word2idx,
# ENTITIES and max_len are module-level globals set up by the training script below.
from keras import backend as K
from keras import layers
from keras.layers import (Input, Embedding, Lambda, Concatenate, Conv1D,
                          MaxPool1D, Flatten, Dense)
from keras.models import Model


def build_model():
    K.clear_session()
    num_ent_classes = len(ENTITIES) + 1
    ent_emb_size = 2
    emb_size = w2v_embeddings.shape[-1]
    vocab_size = len(word2idx)

    # Inputs: token ids, entity-tag ids, two 0/1 masks marking the candidate
    # "from" and "to" entities, and the scalar distance between them.
    inp_sent = Input(shape=(max_len,), dtype='int32')
    inp_ent = Input(shape=(max_len,), dtype='int32')
    inp_f_ent = Input(shape=(max_len,), dtype='float32')
    inp_t_ent = Input(shape=(max_len,), dtype='float32')
    inp_ent_dist = Input(shape=(1,), dtype='float32')

    f_ent = Lambda(lambda x: K.expand_dims(x))(inp_f_ent)
    t_ent = Lambda(lambda x: K.expand_dims(x))(inp_t_ent)
    ent_embed = Embedding(num_ent_classes, ent_emb_size)(inp_ent)
    sent_embed = Embedding(vocab_size, emb_size, weights=[w2v_embeddings], trainable=False)(inp_sent)

    x = Concatenate()([sent_embed, ent_embed])
    x = Conv1D(64, 1, padding='same', activation='relu')(x)

    # Three gated residual blocks: each convolves the shared features x and adds
    # the entity-masked residual carried over from the previous step.
    f_res = layers.multiply([f_ent, x])
    t_res = layers.multiply([t_ent, x])
    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])

    f_res = layers.multiply([f_ent, f_x])
    t_res = layers.multiply([t_ent, t_x])
    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])

    f_res = layers.multiply([f_ent, f_x])
    t_res = layers.multiply([t_ent, t_x])
    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])

    # Two shared conv + max-pool stages, then flatten both branches.
    conv = Conv1D(64, 3, activation='relu')
    f_x = MaxPool1D(3)(conv(f_x))
    t_x = MaxPool1D(3)(conv(t_x))
    conv = Conv1D(64, 3, activation='relu')
    f_x = MaxPool1D(3)(conv(f_x))
    t_x = MaxPool1D(3)(conv(t_x))
    f_x = Flatten()(f_x)
    t_x = Flatten()(t_x)

    x = Concatenate()([f_x, t_x, inp_ent_dist])
    x = Dense(256, activation='relu')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model([inp_sent, inp_ent, inp_f_ent, inp_t_ent, inp_ent_dist], x)
    return model
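To see the input contract of build_model without the full data pipeline, the sketch below stubs the module-level globals with toy values (all names and sizes here are made up) and checks that five inputs of the expected shapes produce one sigmoid score per entity pair. It assumes build_model above is defined in the same session.

import numpy as np

# Hypothetical stand-ins for the globals normally produced by the training script below.
ENTITIES = ['EntityTypeA', 'EntityTypeB']
max_len = 150
word2idx = {'<pad>': 0, '<unk>': 1}
w2v_embeddings = np.zeros((len(word2idx), 100))

m = build_model()
batch = 4
dummy_inputs = [
    np.zeros((batch, max_len), dtype='int32'),    # inp_sent: token ids
    np.zeros((batch, max_len), dtype='int32'),    # inp_ent: entity-tag ids
    np.zeros((batch, max_len), dtype='float32'),  # inp_f_ent: mask of the "from" entity
    np.zeros((batch, max_len), dtype='float32'),  # inp_t_ent: mask of the "to" entity
    np.zeros((batch, 1), dtype='float32'),        # inp_ent_dist: distance between the entities
]
print(m.predict(dummy_inputs).shape)  # expected: (4, 1)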
Model training and testing
import os
import pickle
import numpy as np
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import ShuffleSplit
from gensim.models import Word2Vec
from ner.bilstm_crf_model import build_lstm_crf_model
from ner.evaluator import Evaluator
from ner.utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
from preprocess.util.data_util import get_ann_text
from resource.config import *

train = False
train_dir = 'data/train/'
test_dir = 'data/test/'

ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))
idx2ent = dict([(v, k) for k, v in ent2idx.items()])
num_cates = max(ent2idx.values()) + 1
sent_len = 64
vocab_size = 3000
emb_size = 100
sent_pad = 10
sent_extractor = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)
model_dir = 'model/model.h5'
seq_len = sent_len + 2 * sent_pad

if train:
    get_ann_text(z1_text_dir, z1_entity_ann, train_dir, ratio=1)
    get_ann_text(z4_text_dir, z4_entity_ann, train_dir, ratio=1)
    docs = Documents(data_dir=train_dir)
    # print(docs.doc_ids)
    rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2020)
    train_doc_ids, test_doc_ids = next(rs.split(docs))
    train_docs, test_docs = docs[train_doc_ids], docs[test_doc_ids]
    # print(test_docs[0].text)
    train_sents = sent_extractor(train_docs)
    train_data = Dataset(train_sents, cate2idx=ent2idx)
    train_data.build_vocab_dict(vocab_size=vocab_size)
    vocab_size = len(train_data.word2idx)

    # Train character-level Word2Vec embeddings on the full document set.
    w2v_train_sents = []
    for doc in docs:
        w2v_train_sents.append(list(doc.text))
    w2v_model = Word2Vec(w2v_train_sents, size=emb_size)  # `size=` is the gensim <4.0 argument
    w2v_embeddings = np.zeros((vocab_size, emb_size))
    for char, char_idx in train_data.word2idx.items():
        if char in w2v_model.wv:
            w2v_embeddings[char_idx] = w2v_model.wv[char]
    with open('data/dict.pkl', 'wb') as outp:
        pickle.dump((train_data.word2idx, w2v_embeddings), outp)

    train_X, train_y = train_data[:]
    print('train_X.shape', train_X.shape)
    print('train_y.shape', train_y.shape)
    checkpoint = ModelCheckpoint(model_dir, monitor='acc', verbose=1, save_best_only=True, mode='auto')
    model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=vocab_size,
                                 model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
    model.fit(train_X, train_y, batch_size=64, epochs=10, callbacks=[checkpoint])
else:
    # test
    # get_ann_text(z1_text_dir, z1_entity_ann, train_dir, ratio=1)
    # get_ann_text(z4_text_dir, z4_entity_ann, train_dir, ratio=1)
    test_docs = Documents(data_dir=test_dir)
    test_sents = sent_extractor(test_docs)
    with open('data/dict.pkl', 'rb') as inp:
        (word2idx, w2v_embeddings) = pickle.load(inp)
    test_data = Dataset(test_sents, word2idx=word2idx, cate2idx=ent2idx)
    test_X, _ = test_data[:]
    # print(w2v_embeddings.shape[0])
    # print(len(word2idx))
    model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=len(word2idx),
                                 model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
    model.load_weights(model_dir)
    preds = model.predict(test_X, batch_size=64, verbose=True)
    pred_docs = make_predictions(preds, test_data, sent_pad, test_docs, idx2ent)

    # Write the predicted entities of each document to a .ann file.
    for k, v in pred_docs.items():
        fname = '{}.ann'.format(v.doc_id)
        test_file = os.path.join(ent_result, fname)
        with open(test_file, 'w', encoding='utf-8') as f:
            for i in v.ents:
                f.write('{}\t{} {} {}\t{}\n'.format(i.ent_id, i.category, i.start_pos, i.end_pos, i.text))

    f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
    print('f_score: ', f_score)
    print('precision: ', precision)
    print('recall: ', recall)
    # sample_doc_id = list(pred_docs.keys())[0]
    # print(test_docs[sample_doc_id].text)
    # print(pred_docs[sample_doc_id].text)
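Each entity line written above follows the brat-style standoff layout: entity id, tab, category plus character offsets, tab, mention text. For illustration only (id, category, offsets and text are all made up):

# Hypothetical values mirroring the format string used above.
ent_id, category, start_pos, end_pos, text = 'T1', 'SomeCategory', 10, 14, '示例实体'
print('{}\t{} {} {}\t{}'.format(ent_id, category, start_pos, end_pos, text))
# -> T1	SomeCategory 10 14	示例实体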
# Relation extraction: train the CNN relation classifier (build_model above),
# or load it and append predicted relations to the entity .ann files.
import os
import re
import shutil
import pickle
from collections import defaultdict
from itertools import chain

from keras.callbacks import ModelCheckpoint

# Project-level helpers assumed to be importable as in the NER script above:
# Documents, Dataset, SentenceExtractor, EntityPairsExtractor, train_word_embeddings,
# get_ann_text, build_model, ENTITIES, RELATIONS and the path constants from resource.config.


def generate_submission(preds, entity_pairs, threshold):
    # Keep every entity pair whose predicted score reaches the threshold and
    # group the resulting relations by document.
    doc_rels = defaultdict(set)
    for p, ent_pair in zip(preds, entity_pairs):
        if p >= threshold:
            doc_id = ent_pair.doc_id
            f_ent_id = ent_pair.from_ent.ent_id
            t_ent_id = ent_pair.to_ent.ent_id
            category = ent_pair.from_ent.category + '-' + ent_pair.to_ent.category
            doc_rels[doc_id].add((f_ent_id, t_ent_id, category))

    submits = dict()
    tot_num_rels = 0
    for doc_id, rels in doc_rels.items():
        output_str = ''
        for i, rel in enumerate(rels):
            tot_num_rels += 1
            line = 'R{}\t{} Arg1:{} Arg2:{}\n'.format(i + 1, rel[2], rel[0], rel[1])
            output_str += line
        submits[doc_id] = output_str
    print('Total number of relations: {}. On average {} relations per doc.'.format(
        tot_num_rels, tot_num_rels / len(submits)))
    return submits


def output_submission(dir_name, submits, test_dir):
    # Append the predicted relation lines to a copy of each document's existing .ann file.
    for doc_id, rels_str in submits.items():
        fname = '{}.ann'.format(doc_id)
        test_file = os.path.join(test_dir, fname)
        with open(test_file, 'r', encoding='utf-8') as f:
            content = f.read()
        content += rels_str
        with open(os.path.join(dir_name, fname), 'w', encoding='utf8') as f:
            f.writelines(content)


if __name__ == '__main__':
    train = False
    filepath = 'model/model.h5'
    sent_extractor = SentenceExtractor(sent_split_char='。', window_size=2, rel_types=RELATIONS,
                                       filter_no_rel_candidates_sents=True)
    max_len = 150
    all_rel_types = set([tuple(re.split('[-]', rel)) for rel in RELATIONS])
    ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))
    ent_pair_extractor = EntityPairsExtractor(all_rel_types, max_len=max_len)

    if train:
        get_ann_text(z4_text_dir, ent_rel, train_data_dir, ratio=0.1)
        get_ann_text(z1_text_dir, ent_rel, test_data_dir, ratio=0.1)
        train_docs = Documents(train_data_dir)
        test_docs = Documents(test_data_dir)
        # print(len(train_docs))
        # print(train_docs.doc_ids)

        # Collect the (doc, entity, entity) triples that are true relations in the training data.
        doc_ent_pair_ids = set()
        for doc in train_docs:
            for rel in doc.rels:
                doc_ent_pair_id = (doc.doc_id, rel.ent1.ent_id, rel.ent2.ent_id)
                doc_ent_pair_ids.add(doc_ent_pair_id)

        train_sents = sent_extractor(train_docs)
        test_sents = sent_extractor(test_docs)
        # print(train_sents[0].text)
        train_entity_pairs = ent_pair_extractor(train_sents)
        test_entity_pairs = ent_pair_extractor(test_sents)

        word2idx = {'<pad>': 0, '<unk>': 1}
        word2idx, idx2word, w2v_embeddings = train_word_embeddings(
            entity_pairs=chain(train_entity_pairs, test_entity_pairs),
            word2idx=word2idx,
            size=100,
            iter=10
        )
        with open('data/dict.pkl', 'wb') as outp:
            pickle.dump((word2idx, w2v_embeddings), outp)

        model = build_model()
        train_data = Dataset(train_entity_pairs, doc_ent_pair_ids, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)
        tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist, tr_y = train_data[:]
        model.compile('adam', loss='binary_crossentropy', metrics=['acc'])
        checkpoint = ModelCheckpoint(filepath, monitor='acc', verbose=1, save_best_only=True, mode='auto')
        model.fit(x=[tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist],
                  y=tr_y, batch_size=32, epochs=6, callbacks=[checkpoint])
    else:
        def split_data(output_dir, split):
            # Split the annotated files into sub-folders of `split` documents each,
            # copying every .ann file together with its matching .txt file.
            ent_list = os.listdir(ent_rel)
            filename_list = [ent.split('.')[0] for ent in ent_list]
            filename_list = [filename_list[i:i + split] for i in range(0, len(filename_list), split)]
            # print(filename_list)
            # print(len(filename_list))
            z1_text_list = os.listdir(z1_text_dir)
            z4_text_list = os.listdir(z4_text_dir)
            z1_list = [text.split('.')[0] for text in z1_text_list]
            z4_list = [text.split('.')[0] for text in z4_text_list]
            folder_list = []
            for i in range(len(filename_list)):
                new_folder_path = os.path.join(output_dir, str(i))
                folder_list.append(new_folder_path)
                if not os.path.exists(new_folder_path):
                    os.mkdir(new_folder_path)
                if os.listdir(new_folder_path):
                    # Folder already populated (e.g. from a previous run): skip copying.
                    continue
                for filename in filename_list[i]:
                    old_ann_path = os.path.join(ent_rel, filename + '.ann')
                    shutil.copy(old_ann_path, os.path.join(new_folder_path, filename + '.ann'))
                    old_text_path = ''
                    if filename in z1_list:
                        old_text_path = os.path.join(z1_text_dir, filename + '.txt')
                    if filename in z4_list:
                        old_text_path = os.path.join(z4_text_dir, filename + '.txt')
                    shutil.copy(old_text_path, os.path.join(new_folder_path, filename + '.txt'))
            return folder_list

        folder_list = split_data('data/output/', 20)
        print(folder_list)

        # Predict relations folder by folder and append them to the entity annotations.
        for folder in folder_list:
            print(folder)
            test_docs = Documents(folder)
            sents = sent_extractor(test_docs)
            entity_pairs = ent_pair_extractor(sents)
            with open('data/dict.pkl', 'rb') as inp:
                (word2idx, w2v_embeddings) = pickle.load(inp)
            model = build_model()
            model.load_weights(filepath)
            test_data = Dataset(entity_pairs, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)
            te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist, te_y = test_data[:]
            preds = model.predict(x=[te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist], verbose=1)
            # print(preds)
            submits = generate_submission(preds, entity_pairs, 0.5)
            # print(submits)
            output_submission(rel_result, submits, folder)
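The relation lines that output_submission appends use the matching brat-style convention: a relation id, the category (the "from" entity category joined to the "to" category with a hyphen), and the two argument entity ids. A made-up example:

# Hypothetical ids and categories; real ones come from the entity .ann files and RELATIONS.
f_ent_id, t_ent_id, category = 'T3', 'T7', 'CategoryA-CategoryB'
print('R{}\t{} Arg1:{} Arg2:{}'.format(1, category, f_ent_id, t_ent_id))
# -> R1	CategoryA-CategoryB Arg1:T3 Arg2:T7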