计算机系统应用教程网站

网站首页 > 技术文章 正文

提示学习系列:P-tuning v2微调BERT实现文本多分类

btikc 2024-09-11 02:02:00 技术文章 15 ℃ 0 评论

关键词:提示学习,P-tuning v2,BERT

前言

P-tuning v2是清华团队在P-tuning基础上提出的一种提示微调大模型方法,它旨在解决提示学习在小尺寸模型上效果不佳,以及无法对下游NLU任务通用的问题,本文对该方法进行简要介绍和实践。


内容摘要

  • P-tuning v2理论方法简介
  • P-tuning v2微调BERT实践
  • P-tuning v2、PET、Fine-Tuning效果对比

P-tuning v2理论方法简介

相比于现有的Prompt tuning方式,P-tuning v2的调整主要体现在:

  • 1.为了增强对下游任务的通用性,使用类似Fine-tuning的[CLS]为作为任务的预测表征
  • 2.引入Deep Prompt Tuning,在Transformer的每一层Block中的输入层,对输入添加一定长度的前缀Prompt Embedding,让模型自适应学习Prompt的表征

模型的结构图如下

以330M的BERT预训练模型为例,Transformer的Encoder模块一共12层,token的维度表征为768,设置提示长度为20,则要学习的连续提示Embedding表征为12 * [20, 768],相比于P-tuning v1可学习的参数数量明显增多,同时这些参数嵌入在模型网络的每一层,相比于P-tuning v1仅在输入层添加参数,中间网络没有任何参数添加的形式,P-tuning v2中参数对模型结果的影响更加直接。


P-tuning v2微调BERT实践

论文团队只提供了P-tuning v2在BERT结构上的方案和源码,在源码中作者并没有改造Bert的代码结构来给每一层创建随机Embedding再做自注意力,而是采用了类似交叉注意力的方式,对每一层的Key和Value额外拼接了一定长度可学习的Prompt Embedding,让输入Query和拼接Prompt后的Key、Value做交叉注意力。

采用HuggingFace的模型类源码能够很容易的实现,代码如下

batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

其中past_key_values是额外给Key,Value添加的Prompt Embedding,attention_mask也同步增加前缀。

在Bert内部,会把past_key_values拼接在每一层经过Key、Value线性变换后的向量的前面,代码如下,原始输入分别经过Key、Value线性映射后直接在头部拼接可学习的参数Embedding,来达到P-tuning v2的效果。

        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

除此之外P-tuning v2和Fine-tuning的实现无明显区别,取[CLS]的池化输出计算损失。本文采用和前文提示学习系列:P-Tuning微调BERT/GPT-2实现文本多分类同样的数据集,在新闻数据上通过P-tuning v2提示微调来实现文本多分类,模型网络代码如下

class Model(nn.Module):
    def __init__(self, num_labels, pre_seq_len=40, hidden_size=PRE_TRAIN_CONFIG.hidden_size, hidden_dropout_prob=0.1):
        super(Model, self).__init__()
        self.num_labels = num_labels
        self.pre_seq_len = pre_seq_len
        self.n_layer = PRE_TRAIN_CONFIG.num_hidden_layers
        self.n_head = PRE_TRAIN_CONFIG.num_attention_heads
        self.n_embd = PRE_TRAIN_CONFIG.hidden_size // PRE_TRAIN_CONFIG.num_attention_heads
        self.bert = PRE_TRAIN
        self.dropout = torch.nn.Dropout(hidden_dropout_prob)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(self.pre_seq_len, PRE_TRAIN_CONFIG.num_hidden_layers,
                                            PRE_TRAIN_CONFIG.hidden_size)
        requires_grad_param = 0
        total_param = 0
        for name, param in self.named_parameters():
            total_param += param.numel()
            if param.requires_grad:
                requires_grad_param += param.numel()

        print('total param: {}, trainable param: {}, trainable/total: {}'.format(total_param, requires_grad_param,
                                                                                 requires_grad_param / total_param))

    def get_prompt(self, batch_size):
        # TODO 统一构造embedding并且改造为对应的维度
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,  # 128
            self.pre_seq_len,  # 40
            self.n_layer * 2,  # 24
            self.n_head,  #
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # TODO 根据n_layer * 2分为2个
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None
    ):
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

其中PrefixEncoder为创建的随机初始化的Prompt Embedding,实现如下

class PrefixEncoder(nn.Module):
    def __init__(self, pre_seq_len, num_hidden_layers, hidden_size, prefix_projection=False):
        super().__init__()
        self.prefix_projection = prefix_projection  # false
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(pre_seq_len, hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(prefix_hidden_size, num_hidden_layers * 2 * hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(pre_seq_len, num_hidden_layers * 2 * hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

根据Bert的层数、隐藏层维度,Prompt长度来确定需要训练的参数量,初始化之后为了和Bert注意力源码中的Key和Value拼接,需要额外做注意力头维度分割和转置。


P-tuning v2、PET、Fine-Tuning效果对比

笔者在不同样本数量下对Bert采用P-tuning v2,PET和Fine-Tuning微调,其中P-tuning v2冻结大模型,仅微调Prompt Embedding,PET和Fine-Tuning采用全参微调,以20000条样本为例,F1和模型训练参量对比如下

其中P-tuning v2的预测精度略高于Fine-Tuning,明显高于PET,同时训练参数量为74万,而其他两种全参微调参数量达到1亿,P-tuning v2仅需要约0.1%的参数微调量就能达到全参微调的效果。笔者在不同样本量的多次训练测试下,P-tuning v2的F1值接近Fine-Tuning,仍普遍低于Fine-Tuning,但是明显优秀于PET和P-tuning v1。

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表