关键词:提示学习,P-tuning v2,BERT
前言
P-tuning v2是清华团队在P-tuning基础上提出的一种提示微调大模型方法,它旨在解决提示学习在小尺寸模型上效果不佳,以及无法对下游NLU任务通用的问题,本文对该方法进行简要介绍和实践。
内容摘要
- P-tuning v2理论方法简介
- P-tuning v2微调BERT实践
- P-tuning v2、PET、Fine-Tuning效果对比
P-tuning v2理论方法简介
相比于现有的Prompt tuning方式,P-tuning v2的调整主要体现在:
- 1.为了增强对下游任务的通用性,使用类似Fine-tuning的[CLS]为作为任务的预测表征
- 2.引入Deep Prompt Tuning,在Transformer的每一层Block中的输入层,对输入添加一定长度的前缀Prompt Embedding,让模型自适应学习Prompt的表征
模型的结构图如下
以330M的BERT预训练模型为例,Transformer的Encoder模块一共12层,token的维度表征为768,设置提示长度为20,则要学习的连续提示Embedding表征为12 * [20, 768],相比于P-tuning v1可学习的参数数量明显增多,同时这些参数嵌入在模型网络的每一层,相比于P-tuning v1仅在输入层添加参数,中间网络没有任何参数添加的形式,P-tuning v2中参数对模型结果的影响更加直接。
P-tuning v2微调BERT实践
论文团队只提供了P-tuning v2在BERT结构上的方案和源码,在源码中作者并没有改造Bert的代码结构来给每一层创建随机Embedding再做自注意力,而是采用了类似交叉注意力的方式,对每一层的Key和Value额外拼接了一定长度可学习的Prompt Embedding,让输入Query和拼接Prompt后的Key、Value做交叉注意力。
采用HuggingFace的模型类源码能够很容易的实现,代码如下
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
)
其中past_key_values是额外给Key,Value添加的Prompt Embedding,attention_mask也同步增加前缀。
在Bert内部,会把past_key_values拼接在每一层经过Key、Value线性变换后的向量的前面,代码如下,原始输入分别经过Key、Value线性映射后直接在头部拼接可学习的参数Embedding,来达到P-tuning v2的效果。
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
除此之外P-tuning v2和Fine-tuning的实现无明显区别,取[CLS]的池化输出计算损失。本文采用和前文提示学习系列:P-Tuning微调BERT/GPT-2实现文本多分类同样的数据集,在新闻数据上通过P-tuning v2提示微调来实现文本多分类,模型网络代码如下
class Model(nn.Module):
def __init__(self, num_labels, pre_seq_len=40, hidden_size=PRE_TRAIN_CONFIG.hidden_size, hidden_dropout_prob=0.1):
super(Model, self).__init__()
self.num_labels = num_labels
self.pre_seq_len = pre_seq_len
self.n_layer = PRE_TRAIN_CONFIG.num_hidden_layers
self.n_head = PRE_TRAIN_CONFIG.num_attention_heads
self.n_embd = PRE_TRAIN_CONFIG.hidden_size // PRE_TRAIN_CONFIG.num_attention_heads
self.bert = PRE_TRAIN
self.dropout = torch.nn.Dropout(hidden_dropout_prob)
self.classifier = torch.nn.Linear(hidden_size, num_labels)
for param in self.bert.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(self.pre_seq_len, PRE_TRAIN_CONFIG.num_hidden_layers,
PRE_TRAIN_CONFIG.hidden_size)
requires_grad_param = 0
total_param = 0
for name, param in self.named_parameters():
total_param += param.numel()
if param.requires_grad:
requires_grad_param += param.numel()
print('total param: {}, trainable param: {}, trainable/total: {}'.format(total_param, requires_grad_param,
requires_grad_param / total_param))
def get_prompt(self, batch_size):
# TODO 统一构造embedding并且改造为对应的维度
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
past_key_values = self.prefix_encoder(prefix_tokens)
past_key_values = past_key_values.view(
batch_size, # 128
self.pre_seq_len, # 40
self.n_layer * 2, # 24
self.n_head, #
self.n_embd
)
past_key_values = self.dropout(past_key_values)
# TODO 根据n_layer * 2分为2个
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None
):
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
其中PrefixEncoder为创建的随机初始化的Prompt Embedding,实现如下
class PrefixEncoder(nn.Module):
def __init__(self, pre_seq_len, num_hidden_layers, hidden_size, prefix_projection=False):
super().__init__()
self.prefix_projection = prefix_projection # false
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(pre_seq_len, hidden_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(hidden_size, prefix_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(prefix_hidden_size, num_hidden_layers * 2 * hidden_size)
)
else:
self.embedding = torch.nn.Embedding(pre_seq_len, num_hidden_layers * 2 * hidden_size)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
根据Bert的层数、隐藏层维度,Prompt长度来确定需要训练的参数量,初始化之后为了和Bert注意力源码中的Key和Value拼接,需要额外做注意力头维度分割和转置。
P-tuning v2、PET、Fine-Tuning效果对比
笔者在不同样本数量下对Bert采用P-tuning v2,PET和Fine-Tuning微调,其中P-tuning v2冻结大模型,仅微调Prompt Embedding,PET和Fine-Tuning采用全参微调,以20000条样本为例,F1和模型训练参量对比如下
其中P-tuning v2的预测精度略高于Fine-Tuning,明显高于PET,同时训练参数量为74万,而其他两种全参微调参数量达到1亿,P-tuning v2仅需要约0.1%的参数微调量就能达到全参微调的效果。笔者在不同样本量的多次训练测试下,P-tuning v2的F1值接近Fine-Tuning,仍普遍低于Fine-Tuning,但是明显优秀于PET和P-tuning v1。
本文暂时没有评论,来添加一个吧(●'◡'●)