Cerebrovascular Disease Knowledge Graph -- Part 2: Model Training


Model training follows the same procedure as the real-estate section.

Note: the Keras and TensorFlow versions need to be compatible with each other.

Reference: https://docs.floydhub.com/guides/environments/
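
As a quick sanity check (a minimal sketch; the exact compatible pairing depends on your environment, and the keras_contrib CRF layer used below expects the standalone Keras 2.x API rather than tf.keras), you can print the installed versions before training:

import tensorflow as tf
import keras

# Confirm the installed pair before training; mismatched Keras / TensorFlow
# versions are a common cause of import or serialization errors here.
print('tensorflow:', tf.__version__)
print('keras:', keras.__version__)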

Model definition

# BiLSTM-CRF sequence-labelling model for entity recognition.
# The CRF layer and its loss come from the keras_contrib package.
import keras
from keras.layers import Input, Embedding, LSTM, Bidirectional
from keras.models import Model
from keras_contrib.layers import CRF


def build_lstm_crf_model(num_cates, seq_len, vocab_size, model_opts=dict()):
    opts = {
        'emb_size': 256,
        'emb_trainable': True,
        'emb_matrix': None,
        'lstm_units': 256,
        'optimizer': keras.optimizers.Adam()
    }
    opts.update(model_opts)

    input_seq = Input(shape=(seq_len,), dtype='int32')
    # Use a pre-trained embedding matrix if one was passed in, otherwise learn
    # the embeddings from scratch.
    if opts.get('emb_matrix') is not None:
        embedding = Embedding(vocab_size, opts['emb_size'],
                              weights=[opts['emb_matrix']],
                              trainable=opts['emb_trainable'])
    else:
        embedding = Embedding(vocab_size, opts['emb_size'])
    x = embedding(input_seq)
    lstm = LSTM(opts['lstm_units'], return_sequences=True)
    x = Bidirectional(lstm)(x)
    crf = CRF(num_cates, sparse_target=True)
    output = crf(x)

    model = Model(input_seq, output)
    model.summary()
    model.compile(opts['optimizer'], loss=crf.loss_function, metrics=[crf.accuracy, 'acc'])
    return model


# Relation-classification model for entity pairs. This function presumably
# lives in a separate module from the BiLSTM-CRF code above; besides the layers
# already imported, it needs the imports below and reads the module-level
# globals ENTITIES, w2v_embeddings, word2idx and max_len, which are set by the
# relation-extraction script further down.
from keras import backend as K
from keras import layers
from keras.layers import (Lambda, Concatenate, Conv1D, MaxPool1D,
                          Flatten, Dense)


def build_model():
    K.clear_session()


    num_ent_classes = len(ENTITIES) + 1
    ent_emb_size = 2
    emb_size = w2v_embeddings.shape[-1]
    vocab_size = len(word2idx)


    inp_sent = Input(shape=(max_len,), dtype='int32')
    inp_ent = Input(shape=(max_len,), dtype='int32')
    inp_f_ent = Input(shape=(max_len,), dtype='float32')
    inp_t_ent = Input(shape=(max_len,), dtype='float32')
    inp_ent_dist = Input(shape=(1,), dtype='float32')
    f_ent = Lambda(lambda x: K.expand_dims(x))(inp_f_ent)
    t_ent = Lambda(lambda x: K.expand_dims(x))(inp_t_ent)
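    # inp_f_ent / inp_t_ent carry one weight per character position (presumably
    # 0/1 masks marking the "from" and "to" candidate entities); expanding their
    # last dimension makes them broadcastable against the (max_len, 64) feature
    # maps that they gate below by element-wise multiplication.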


    ent_embed = Embedding(num_ent_classes, ent_emb_size)(inp_ent)
    sent_embed = Embedding(vocab_size, emb_size, weights=[w2v_embeddings], trainable=False)(inp_sent)


    x = Concatenate()([sent_embed, ent_embed])
    x = Conv1D(64, 1, padding='same', activation='relu')(x)


    f_res = layers.multiply([f_ent, x])
    t_res = layers.multiply([t_ent, x])
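    # Three repeated gated blocks follow for each entity branch: a fresh
    # Conv1D(64, 3) is applied to the shared feature map x, and the
    # entity-masked features from the previous step (f_res / t_res) are
    # added back on top.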


    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])


    f_res = layers.multiply([f_ent, f_x])
    t_res = layers.multiply([t_ent, t_x])
    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])


    f_res = layers.multiply([f_ent, f_x])
    t_res = layers.multiply([t_ent, t_x])
    conv = Conv1D(64, 3, padding='same', activation='relu')
    f_x = conv(x)
    t_x = conv(x)
    f_x = layers.add([f_x, f_res])
    t_x = layers.add([t_x, t_res])


    conv = Conv1D(64, 3, activation='relu')
    f_x = MaxPool1D(3)(conv(f_x))
    t_x = MaxPool1D(3)(conv(t_x))


    conv = Conv1D(64, 3, activation='relu')
    f_x = MaxPool1D(3)(conv(f_x))
    t_x = MaxPool1D(3)(conv(t_x))


    f_x = Flatten()(f_x)
    t_x = Flatten()(t_x)


    x = Concatenate()([f_x, t_x, inp_ent_dist])
    x = Dense(256, activation='relu')(x)
    x = Dense(1, activation='sigmoid')(x)


    model = Model([inp_sent, inp_ent, inp_f_ent, inp_t_ent, inp_ent_dist], x)
    return model
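
As a quick smoke test of the definitions above (a minimal sketch with made-up sizes; it assumes keras, keras_contrib and numpy are installed and does not reflect the real data), the BiLSTM-CRF builder can be called on random token ids:

import numpy as np

# Made-up sizes: 5 label categories, sequences of 84 tokens, vocabulary of 3000.
toy_model = build_lstm_crf_model(num_cates=5, seq_len=84, vocab_size=3000)
dummy_X = np.random.randint(0, 3000, size=(2, 84))
print(toy_model.predict(dummy_X).shape)  # (2, 84, 5): one label vector per token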

Model training and testing

import os
import pickle


import numpy as np
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import ShuffleSplit
from gensim.models import Word2Vec


from ner.bilstm_crf_model import build_lstm_crf_model
from ner.evaluator import Evaluator
from ner.utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
from preprocess.util.data_util import get_ann_text
from resource.config import *


train = False


train_dir = 'data/train/'
test_dir = 'data/test/'


ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))
idx2ent = dict([(v, k) for k, v in ent2idx.items()])
num_cates = max(ent2idx.values()) + 1
sent_len = 64
vocab_size = 3000
emb_size = 100
sent_pad = 10
sent_extractor = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)


model_dir = 'model/model.h5'
seq_len = sent_len + 2 * sent_pad
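# Each extracted sentence window covers sent_len = 64 characters and is padded
# with sent_pad = 10 positions of surrounding context on each side, so the
# model input length is seq_len = 64 + 2 * 10 = 84.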




if train:
    get_ann_text(z1_text_dir, z1_entity_ann, train_dir, ratio=1)
    get_ann_text(z4_text_dir, z4_entity_ann, train_dir, ratio=1)


    docs = Documents(data_dir=train_dir)
    # print(docs.doc_ids)
    rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2020)
    train_doc_ids, test_doc_ids = next(rs.split(docs))
    train_docs, test_docs = docs[train_doc_ids], docs[test_doc_ids]
    # print(test_docs[0].text)


    train_sents = sent_extractor(train_docs)


    train_data = Dataset(train_sents, cate2idx=ent2idx)
    train_data.build_vocab_dict(vocab_size=vocab_size)


    vocab_size = len(train_data.word2idx)


    w2v_train_sents = []
    for doc in docs:
        w2v_train_sents.append(list(doc.text))
    # Train character-level word2vec vectors on all documents
    # (gensim < 4.0 API; in gensim 4+ the `size` argument became `vector_size`).
    w2v_model = Word2Vec(w2v_train_sents, size=emb_size)


    w2v_embeddings = np.zeros((vocab_size, emb_size))
    for char, char_idx in train_data.word2idx.items():
        if char in w2v_model.wv:
            w2v_embeddings[char_idx] = w2v_model.wv[char]
    with open('data/dict.pkl', 'wb') as outp:
        pickle.dump((train_data.word2idx, w2v_embeddings), outp)
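    # dict.pkl persists the vocabulary mapping and the word2vec embedding
    # matrix so that the test branch below can rebuild an identical model
    # without refitting either of them.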


    train_X, train_y = train_data[:]
    print('train_X.shape', train_X.shape)
    print('train_y.shape', train_y.shape)
    checkpoint = ModelCheckpoint(model_dir, monitor='acc', verbose=1, save_best_only=True, mode='auto')
    model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=vocab_size,
                                 model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
    model.fit(train_X, train_y, batch_size=64, epochs=10, callbacks=[checkpoint])
else:
    # test
    # get_ann_text(z1_text_dir, z1_entity_ann, train_dir, ratio=1)
    # get_ann_text(z4_text_dir, z4_entity_ann, train_dir, ratio=1)


    test_docs = Documents(data_dir=test_dir)
    test_sents = sent_extractor(test_docs)
    with open('data/dict.pkl', 'rb') as inp:
        (word2idx, w2v_embeddings) = pickle.load(inp)


    test_data = Dataset(test_sents, word2idx=word2idx, cate2idx=ent2idx)
    test_X, _ = test_data[:]
    # print(w2v_embeddings.shape[0])
    # print(len(word2idx))
    model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=len(word2idx),
                                 model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
    model.load_weights(model_dir)


    preds = model.predict(test_X, batch_size=64, verbose=True)
    pred_docs = make_predictions(preds, test_data, sent_pad, test_docs, idx2ent)


    for k, v in pred_docs.items():
        fname = '{}.ann'.format(v.doc_id)
        test_file = os.path.join(ent_result, fname)
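        # One .ann file per document in brat standoff entity format:
        # <ent_id>\t<category> <start> <end>\t<text>, e.g. a hypothetical line:
        # T1	Disease 12 15	脑梗死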
        with open(test_file, 'w', encoding='utf-8') as f:
            for i in v.ents:
                f.write('{}\t{} {} {}\t{}\n'.format(i.ent_id, i.category, i.start_pos, i.end_pos, i.text))


    f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
    print('f_score: ', f_score)
    print('precision: ', precision)
    print('recall: ', recall)


    # sample_doc_id = list(pred_docs.keys())[0]
    # print(test_docs[sample_doc_id].text)


    # print(pred_docs[sample_doc_id].text)
# Relation-extraction training / inference script. Standard-library imports
# needed by the code below; the project-specific helpers (Documents,
# SentenceExtractor, EntityPairsExtractor, Dataset, train_word_embeddings,
# RELATIONS, ENTITIES, get_ann_text, build_model and the directory constants)
# come from the project's own modules, which are not shown here.
import os
import pickle
import re
import shutil
from collections import defaultdict
from itertools import chain

from keras.callbacks import ModelCheckpoint


def generate_submission(preds, entity_pairs, threshold):
    doc_rels = defaultdict(set)
    for p, ent_pair in zip(preds, entity_pairs):
        if p >= threshold:
            doc_id = ent_pair.doc_id
            f_ent_id = ent_pair.from_ent.ent_id
            t_ent_id = ent_pair.to_ent.ent_id
            category = ent_pair.from_ent.category + '-' + ent_pair.to_ent.category
            doc_rels[doc_id].add((f_ent_id, t_ent_id, category))
    submits = dict()
    tot_num_rels = 0
    for doc_id, rels in doc_rels.items():
        output_str = ''
        for i, rel in enumerate(rels):
            tot_num_rels += 1
            line = 'R{}\t{} Arg1:{} Arg2:{}\n'.format(i + 1, rel[2], rel[0], rel[1])
            output_str += line
        submits[doc_id] = output_str
    print('Total number of relations: {}. In average {} relations per doc.'.format(tot_num_rels,
                                                                                   tot_num_rels / len(submits)))
    return submits
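# Each emitted relation line follows the brat standoff format:
# <rel_id>\t<from-category>-<to-category> Arg1:<from_ent_id> Arg2:<to_ent_id>
# e.g. a hypothetical line: R1	Disease-Drug Arg1:T3 Arg2:T7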




def output_submission(dir_name, submits, test_dir):


    for doc_id, rels_str in submits.items():
        fname = '{}.ann'.format(doc_id)
        test_file = os.path.join(test_dir, fname)
        with open(test_file, 'r', encoding='utf-8') as f:
            content = f.read()
        content += rels_str
        with open(os.path.join(dir_name, fname), 'w', encoding='utf8') as f:
            f.writelines(content)




if __name__ == '__main__':


    train = False
    filepath = 'model/model.h5'  # note: same path as the NER model above; use a distinct filename if both models need to be kept
    sent_extractor = SentenceExtractor(sent_split_char='。', window_size=2, rel_types=RELATIONS,
                                       filter_no_rel_candidates_sents=True)
    max_len = 150
    all_rel_types = set([tuple(re.split('[-]', rel)) for rel in RELATIONS])
    ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))


    ent_pair_extractor = EntityPairsExtractor(all_rel_types, max_len=max_len)


    if train:
        get_ann_text(z4_text_dir, ent_rel, train_data_dir, ratio=0.1)
        get_ann_text(z1_text_dir, ent_rel, test_data_dir, ratio=0.1)


        train_docs = Documents(train_data_dir)
        test_docs = Documents(test_data_dir)


        # print(len(train_docs))
        # print(train_docs.doc_ids)


        doc_ent_pair_ids = set()
        for doc in train_docs:
            for rel in doc.rels:
                doc_ent_pair_id = (doc.doc_id, rel.ent1.ent_id, rel.ent2.ent_id)
                doc_ent_pair_ids.add(doc_ent_pair_id)


        train_sents = sent_extractor(train_docs)
        test_sents = sent_extractor(test_docs)


        # print(train_sents[0].text)


        train_entity_pairs = ent_pair_extractor(train_sents)
        test_entity_pairs = ent_pair_extractor(test_sents)
        word2idx = {'<pad>': 0, '<unk>': 1}
        word2idx, idx2word, w2v_embeddings = train_word_embeddings(
            entity_pairs=chain(train_entity_pairs, test_entity_pairs),
            word2idx=word2idx,
            size=100,
            iter=10
        )
        with open('data/dict.pkl', 'wb') as outp:
            pickle.dump((word2idx, w2v_embeddings), outp)


        model = build_model()


        train_data = Dataset(train_entity_pairs, doc_ent_pair_ids, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)
        tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist, tr_y = train_data[:]
        model.compile('adam', loss='binary_crossentropy', metrics=['acc'])


        checkpoint = ModelCheckpoint(filepath, monitor='acc', verbose=1, save_best_only=True, mode='auto')
        model.fit(x=[tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist],
                  y=tr_y, batch_size=32, epochs=6, callbacks=[checkpoint])


    else:
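        # Inference: split the annotated files into numbered sub-folders of
        # `split` documents each, run relation prediction one folder at a time,
        # and write each document's existing annotations plus the predicted
        # relations into the rel_result directory.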
        def split_data(output_dir, split):
            # split = 20
            # output_dir = 'data/output/'
            ent_list = os.listdir(ent_rel)
            filename_list = [ent.split('.')[0] for ent in ent_list]
            filename_list = [filename_list[i:i + split] for i in range(0, len(filename_list), split)]
            # print(filename_list)
            # print(len(filename_list))


            z1_text_list = os.listdir(z1_text_dir)
            z4_text_list = os.listdir(z4_text_dir)
            z1_list = [text.split('.')[0] for text in z1_text_list]
            z4_list = [text.split('.')[0] for text in z4_text_list]


            folder_list = []
            for i in range(len(filename_list)):
                new_folder_path = output_dir + '%s' % i
                folder_list.append(new_folder_path)
                if not os.path.exists(new_folder_path):
                    os.mkdir(new_folder_path)
                if os.listdir(new_folder_path):
                    # the folder was already populated by an earlier run
                    continue


                for filename in filename_list[i]:
                    old_ann_path = os.path.join(ent_rel, filename + '.ann')
                    shutil.copy(old_ann_path, os.path.join(new_folder_path, filename + '.ann'))


                    old_text_path = ''
                    if filename in z1_list:
                        old_text_path = os.path.join(z1_text_dir, filename + '.txt')
                    if filename in z4_list:
                        old_text_path = os.path.join(z4_text_dir, filename + '.txt')


                    shutil.copy(old_text_path, os.path.join(new_folder_path, filename + '.txt'))
            return folder_list


        folder_list = split_data('data/output/', 20)
        print(folder_list)
        for folder in folder_list:
            print(folder)
            test_docs = Documents(folder)


            sents = sent_extractor(test_docs)
            entity_pairs = ent_pair_extractor(sents)


            with open('data/dict.pkl', 'rb') as inp:
                (word2idx, w2v_embeddings) = pickle.load(inp)


            model = build_model()
            model.load_weights(filepath)


            test_data = Dataset(entity_pairs, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)


            te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist, te_y = test_data[:]
            preds = model.predict(x=[te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist], verbose=1)
            # print(preds)
            submits = generate_submission(preds, entity_pairs, 0.5)
            # print(submits)


            output_submission(rel_result, submits, folder)
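            # A hypothetical diagnostic (not part of the original script): sweep
            # the hard-coded 0.5 decision threshold to see how many candidate
            # relations survive at each cut-off.
            for thr in (0.3, 0.4, 0.5, 0.6, 0.7):
                kept = int((preds >= thr).sum())
                print('threshold {:.1f}: {} candidate relations kept'.format(thr, kept))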


