概念
想象一下,你在教一个孩子造句,最直接的方法是你说一个词,让他跟读一个词,而不是让他从头猜到尾。
这种老师引导的思想就是Teacher Forcing。
Teacher Forcing 是训练自回归序列模型(比如RNN、LSTM、GRU、seq2seq、transformer 的自回归解码器等)时常用的一种训练策略
核心思想:
- 在训练时,模型每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。换句话说,训练时给模型看老师提供的正确前文,而不是模型自己预测的前文。这里的真实答案我们一般称作{% hint 'ground truth' '地面实况:' '通过直接观察获得的信息,而不是通过推断获得的信息。' %}
用机器翻译任务简单举例:
- 任务:将英文“I love you”翻译成中文“我爱你”。
- 没有Teacher Forcing(自由运行):
- 输入“I love you”,模型首先预测第一个词。假设它预测错了,输出成了“我恨”。
- 下一步,它将“我恨”作为输入,去预测下一个词。这已经走上了歧途,最终可能输出“我恨你这个世界”。
- 模型从这个完全错误的序列中学习,梯度会非常不稳定,训练效率极低。
- 使用Teacher Forcing:
- 输入“I love you”和起始符
<sos>,模型预测第一个词。它的输出是“我恨”,但计算损失时是与真实值“我”比较。 - 关键一步:在下一个时间步,我们忽略模型输出的“我恨”,而是直接将真实值“我” 作为输入,让模型预测下一个词。
- 模型现在基于“我”来预测,更有可能输出“爱”。
- 再下一步,我们又将真实值“爱”作为输入,让模型预测,它更可能输出“你”。
- 这样,模型每一步都是在“正确上下文”的引导下进行学习。
- 输入“I love you”和起始符
由来
直到什么是Teacher Forcing后,来思考一个问题,为什么要用Teacher Forcing?
- 加快收敛:使用真实前文能减少训练时错误级联,使梯度更稳定,学习更快 → 训练迭代过程早期的RNN预测能力非常弱,几乎不能给出好的生成成绩,如果某一个unit产生了垃圾结果,必然会影响后面一片unit的学习,teacher forcing最早的动机就是用来解决RNN的这个问题。
- 有效防止误差累积:在自由运行模式下,早期的一个小错误会作为后续步骤的输入,导致错误累计像滚雪球一样越来越大,这会使得模型难以从严重的错误中恢复学习。Teacher Forcing 彻底切断了错误传播的链条,确保每个时间步的输入都是干净的。
这是我认为最关键的两个原因,当然,网上还有其他说法,比如让训练过程高效并行化等等。
核心问题
Exposure bias
暴露偏差问题:在训练时,模型习惯于在完美的“真实数据”环境下进行预测;但在推理(测试)时,模型必须使用自己上一步生成的(可能不完美的)输出作为当前步的输入。这种训练和推理之间的环境差异被称为“曝光偏差”。这可能导致模型在推理时非常脆弱,一旦产生一个错误,后续输出就容易崩溃。
解决方案:
1. 计划采样(Scheduled Sampling)
这是一种经典的解决方案。在训练过程中,我们并不总是100%使用真实标签,而是以一个概率 p 使用真实标签,以概率 1-p 使用模型自己上一步的预测结果作为输入。这个概率 p 可以随着训练的进行逐渐减小(例如,从1.0线性衰减到0.5),让模型逐步从“有辅导”过渡到“自主推理”。
2.集束搜索(Beam Search)
在预测单词这种离散值的输出时,一种常用方法是对词表中每一个单词的预测概率执行搜索,生成多个候选的输出序列。 这个方法常用于机器翻译(MT)等问题,以优化翻译的输出序列。 beam search是完成此任务应用最广的方法,通过这种启发式搜索(heuristic search),可减小模型学习阶段performance与测试阶段performance的差异。 ———————————————— 版权声明:本文为CSDN博主「Alanaker」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 原文链接:https://blog.csdn.net/qq_30219017/article/details/89090690
3.课程学习(Curriculum Learning)
核心思想:先易后难
- 课程学习的灵感来源于人类的教育方式。我们不会教小学生微积分,而是从加减乘除开始。同样地,课程学习在训练模型时,先让模型学习“简单”的样本,再逐步过渡到“复杂”的样本。
如何定义“难”与“易”?
这是课程学习的关键。难度衡量标准因任务而异,常见的有:
- 序列长度:对于文本,短句子通常比长句子简单。
- 词汇复杂度:句子中罕见词的比例越低,句子越简单。
- ......
对于课程学习的思想,需要将训练集由易到难排列,同时根据epoch调整数据难度,对于目前任务场景来说,训练集少说也需要百万量级,对齐进行难易排序,难度和成本都非常高,所以不易实现。
代码实现
以一个机器翻译任务来简单实现teacher forcing,其实也比较简单,重点是思想
# 定义teacher_forcing
teacher_forcing_ratio = 0.5
def train(x, y, encoder, decoder, encoder_optimizer, decoder_optimizer, loss):
# 对数据进行编码 [1, 6] -> [1, 6, 256]
encoder_output, encoder_hidden = encoder(x, encoder.init_hidden())
# print(f'x: {x.shape}') # [1, 6]
# print(f'encoder_output: {encoder_output.shape}') # [1, 6, 256]
# print(f'encoder_hidden: {encoder_hidden.shape}') # [1, 1, 256]
# 准备解码器参数
# 第一个参数: 中间语义张量C
encoder_output_c = torch.zeros(MAX_LENGTH, encoder.hidden_size, device=device)
# 将真是编码结果赋值
for idx in range(encoder_output.shape[1]):
encoder_output_c[idx] = encoder_output[0, idx]
# 第二个参数
decoder_hidden = encoder_hidden
# 第三个参数: 开始字符
input_y = torch.tensor([[SOS_token]], device=device)
# 定义损失值
my_loss = 0.0
# 定义真实翻译句子长度
y_len = y.shape[1]
# 是否使用 teacher_forcing
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
if use_teacher_forcing:
# 逐字符解码
for idx in range(y_len):
# 送入解码器进行预测
output_y, decoder_hidden, atten_weight = decoder(input_y, decoder_hidden, encoder_output_c)
# 当前时间步真实值 -> 取出对应位置单词
target_y = y[0][idx].view(1)
# 计算损失
my_loss += loss(output_y, target_y)
# 更新input_y -> 使用真实值
input_y = y[0][idx].view(1, -1)
else:
# 逐字符解码
for idx in range(y_len):
# 送入解码器进行预测
output_y, decoder_hidden, atten_weight = decoder(input_y, decoder_hidden, encoder_output_c)
# 当前时间步真实值 -> 取出对应位置单词
target_y = y[0][idx].view(1)
# 计算损失
my_loss += loss(output_y, target_y)
# 得出当前时间步预测结果
topv, topi = torch.topk(output_y, 1)
input_y = topi.detach()
# 梯度清零 反向传播 参数更新
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
my_loss.backward()
encoder_optimizer.step()
decoder_optimizer.step()
# 返回损失
return my_loss.item() / y_len # 当前样本平均损失