nilboy/rlt2t


Reinforcement learning text to text.

Algorithm

Overall approach

Overall architecture: multiple models generate several candidate answers, a rerank model ranks them, and the best answer is selected as the final output.

Both the generation model and the rerank model use the BART architecture. The generation model and the rerank model are the same model, trained with RL:

Reinforcement learning: the checkpoint saved from the previous epoch acts as the policy (generation) model and produces two answers, a and b; the evaluation metric gives score_a and score_b. The rerank model compares a and b, and rank_loss is computed.

The final loss is a weighted combination of rank_loss and the autoregressive loss.
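A rough, self-contained sketch of how one training pair is formed; the candidate strings, the toy score_fn metric, and the weighting coefficient mentioned at the end are illustrative assumptions, not the repo's actual evaluation metric or configuration:

def score_fn(candidate: str, reference: str) -> float:
    # Stand-in metric: character-overlap ratio with the reference.
    return len(set(candidate) & set(reference)) / max(len(set(reference)), 1)

reference = "the quick brown fox"
answer_a = "the quick brown dog"   # sampled by last epoch's policy model
answer_b = "a quick red fox"       # sampled by last epoch's policy model

score_a, score_b = score_fn(answer_a, reference), score_fn(answer_b, reference)

# Order the pair so that answer_a is always the better candidate; ties are skipped.
use_rank = score_a != score_b
if score_b > score_a:
    answer_a, answer_b = answer_b, answer_a

# get_rank_outputs (below) then pushes the model to score answer_a's tokens
# higher than answer_b's, and the final loss is
# autoregressive_loss + weight * rank_loss.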

Rank loss computation

def get_rank_outputs(self, orig_input_ids=None,
                     orig_attention_mask=None,
                     rank_a_labels=None,
                     rank_b_labels=None,
                     use_rank=None):
    # Run the shared model on the same source with each candidate's labels.
    output_rank_a = self.model(orig_input_ids,
                               attention_mask=orig_attention_mask,
                               labels=rank_a_labels)
    output_rank_b = self.model(orig_input_ids,
                               attention_mask=orig_attention_mask,
                               labels=rank_b_labels)
    # Replace the -100 padding value so the labels can be used as gather indices.
    rank_a_labels_pad = rank_a_labels.masked_fill(rank_a_labels == -100, 0)
    rank_b_labels_pad = rank_b_labels.masked_fill(rank_b_labels == -100, 0)
    # Per-token scores of each candidate's own tokens, shape (batch_size, seq_len).
    rank_a_logits = torch.gather(output_rank_a.logits, 2,
                                 rank_a_labels_pad.unsqueeze(2)).squeeze(-1)
    rank_b_logits = torch.gather(output_rank_b.logits, 2,
                                 rank_b_labels_pad.unsqueeze(2)).squeeze(-1)
    # Pairwise logistic loss: candidate a (the better one) should score higher than b.
    diff_logits = rank_a_logits - rank_b_logits
    rank_loss = -torch.log(torch.sigmoid(diff_logits))
    # Down-weight positions after the point where the two candidates diverge.
    mask_rate = self.build_mask_rate(rank_a_labels, rank_b_labels,
                                     rank_a_logits.dtype, alpha=self.delay_alpha)
    rank_loss = rank_loss * mask_rate
    # Build select mask: keep non-padding positions with non-zero weight,
    # and only samples where use_rank is set.
    select_mask = (rank_a_labels != -100) & (rank_b_labels != -100) & (mask_rate > 1e-8)
    select_mask = select_mask & (use_rank.unsqueeze(-1).to(select_mask.dtype))
    rank_loss = torch.masked_select(rank_loss, select_mask).sum() / \
        (torch.masked_select(mask_rate, select_mask).sum() + 1e-8)
    rank_acc = torch.sum(diff_logits > 0) / diff_logits.shape[0]
    return rank_loss, rank_acc

def build_mask_rate(self, rank_a_labels, rank_b_labels,
                    dtype,
                    alpha=0.9):
    batch_size, seq_len = rank_a_labels.shape
    mask_rate = torch.zeros_like(rank_a_labels, dtype=dtype)
    for i in range(batch_size):
        # First position where the two candidates' labels differ.
        items = (rank_a_labels[i] != rank_b_labels[i]).nonzero()
        start_idx = items[0].item() if len(items) > 0 else seq_len
        # Exponentially decaying weights alpha^k from that position onward.
        mask_rate[i, start_idx:] = alpha ** torch.arange(seq_len - start_idx,
                                                         device=rank_a_labels.device)
    return mask_rate
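A tiny numeric check of the pairwise term used above, -log(sigmoid(rank_a_logits - rank_b_logits)): it is small where the preferred candidate a scores higher and large where it scores lower (the logit values are made up for illustration):

import torch

logit_a = torch.tensor([2.0, 0.5])   # token scores of the better answer
logit_b = torch.tensor([1.0, 1.5])   # token scores of the worse answer
print(-torch.log(torch.sigmoid(logit_a - logit_b)))
# tensor([0.3133, 1.3133])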

Training pipeline

  • Prepare the training data
python tools/construct_data_stage_2.py
  • Convert the model
python tasks/convert_models/convert_model.py \
  --input_model_name=pretrain_models/uer_bart_base \
  --output_model_name=pmodels/uer_bart_base \
  --model_type=bart \
  --vocab_size=2000
  • DAE (denoising autoencoder) pretraining; a toy noising illustration follows this list
python tasks/pl-pretrain-t2t/run.py fit \
  --config tasks/pl-pretrain-t2t/config_uer_bart_base.yaml
  • Train the generation model
python tasks/pl-sft-t2t/run.py fit --config tasks/pl-sft-t2t/base/uer_bart_base.yaml
  • Train the rank model: initialize from the generation model trained above and continue fine-tuning with rank_loss added
# Start the policy generation service
mkdir -p output-models/uer_bart_base-rank
cp -r output-models/uer_bart_base/last.ckpt.dir output-models/uer_bart_base-rank/epoch_0.ckpt.dir
python tools/rank_data_app.py output-models/uer_bart_base-rank 9550
python tasks/pl-sft-t2t/run.py fit --config tasks/pl-sft-t2t/rank/uer_bart_base.yaml
  • Export the model and convert it to CTranslate2 format
ct2-transformers-converter --model output-models/uer_bart_base-rank --output_dir sub-models/uer_bart_base_rank --quantization int8
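For the DAE pretraining step above, a generic illustration of what a denoising training pair looks like: the encoder sees a corrupted sequence and the decoder reconstructs the original. The random token-masking scheme, mask token id, and token ids below are assumptions; the repo's pl-pretrain-t2t task defines its own noising configuration.

import random

MASK_TOKEN = 3   # assumed mask token id

def corrupt(tokens, mask_prob=0.3, seed=0):
    # Randomly replace tokens with the mask token.
    rng = random.Random(seed)
    return [MASK_TOKEN if rng.random() < mask_prob else t for t in tokens]

original = [15, 278, 942, 61, 1203, 77, 530]   # token ids within the 2000-token vocabulary
noised = corrupt(original)
# DAE training pair: encoder input = noised, decoder target = original.
print(noised, "->", original)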

Inference pipeline

python index.py
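As a hedged sketch of what inference with the exported model could look like, loading the CTranslate2 model with ctranslate2.Translator, generating several hypotheses, and keeping the best one; the tokenizer source, decoding parameters, and selection logic are assumptions, not the actual contents of index.py:

import ctranslate2
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("pmodels/uer_bart_base")
translator = ctranslate2.Translator("sub-models/uer_bart_base_rank", device="cpu")

def generate_best(text, num_candidates=4):
    source_tokens = tokenizer.tokenize(text)
    results = translator.translate_batch(
        [source_tokens],
        beam_size=num_candidates,
        num_hypotheses=num_candidates,
        return_scores=True,
    )
    # Keep the highest-scoring hypothesis; the repo's approach reranks the
    # candidates with the rank model rather than relying on beam scores alone.
    hypotheses, scores = results[0].hypotheses, results[0].scores
    best = hypotheses[scores.index(max(scores))]
    return tokenizer.convert_tokens_to_string(best)

print(generate_best("example input text"))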
