受到人类做决策的思维过程的启发,即通过将一个问题逐个分解为多个子问题,并按照链式的方式串行思考,最终得到思考结果,这个过程被成为”思维链(chain-of-thoughts)“。
研究表明,中间推理过程(intermediate reasoning (“rationales”))可以显著提高语言模型在数学或常识回答等复杂推理任务中的表现。
在中间推理过程中,经过良好预训练的LLM可以使用中间步骤的“暂存器(scratchpads)”可以在算术问题上获得完美的分布性能,以及强大的训练任务数据分布外泛化能力。
而相比之下,专门训练用来直接回答答案的one-shot model则无法很好地应对训练任务数据分布外泛化能力。
这些研究工作表明,在给出最终答案之前先给出明确的推理理由(rationale generation),对LLM在各种任务中都很有帮助,包括数学推理、常识推理、代码评估、社会偏见推理和自然语言推理等任务。
目前进行rationale generation有两种主要方法:
然而,目前进行rationale generation的两种主要方法都有严重的缺点。
在本文中,我们采用了不同的方法:即利用LLM自身包含的推理能力,通过迭代,引导LLM产生高质量的rationales的。
具体的流程大致如下:
以上迭代过程是一个协同进化过程,rationale generation提升了微调数据集的质量,而微调数据集通过增强sft-model的能力,进一步也提升了rationale generation的效果。
综合以上过程,我们开发了自学推理器(Self-Taught Reasoner,STaR,下图 1)方法,
Figure 1: An overview of STaR and a STaR-generated rationale on CommonsenseQA. We indicate the fine-tuning outer loop with a dashed line. The questions and ground truth answers are expected to be present in the dataset, while the rationales are generated using STaR.
这是一种可扩展的引导方法,允许模型学习产生自己的推理过程,同时也能够解决不断出现的新领域问题。
参考链接:
https://www.promptingguide.ai/techniques/cot https://learnprompting.org/docs/intermediate/chain_of_thought
首先,我们有一个预训练 LLM,M。以及一个关于问题 x 的初始数据集D,并包含正确的最终答案 y:
迭代优化从一个小prompt数据集(包含中间推理过程r)P开始:
其中 P ≪ D(例如P = 10),这里表示prompt示例集远小于初始数据集数量,完整的Rationale需要在后续的迭代中逐步补全。
接下来,与标准的few-shot prompting一样,我们将prompt示例集连接到 D 中的每个示例,即:
将拼接后的数据集 xi 输入LLM,基于概率预测原理,LLM生成对应的以及与之对应的
接下来是专家修正过程(Rationalize),我们假设产生正确答案的Rationale,相比产生那些产生错误答案的Rationale,质量更高。因此,我们过滤出能够产生正确答案的Rationale。
接下来,我们基于过滤后的数据集(xi,yi,ri),微调LLM,得到一个新的sft-model。
最后,我们基于新微调的sft-model,继续从prompt开始重复整个流程。
我们不断上述重复这个过程,直到性能达到稳定水平。
从强化学习的角度看,STaR 可以看作是 RL 风格的策略梯度目标(RL-style policy gradient objective)算法的近似。
M 可以被视为离散潜变量模型(discrete latent variable model):
换句话说,M 在预测 y 之前首先对潜在推理原因 r 进行采样。
奖励函数函数来自专家的ground truth反馈,整个数据集的总奖励期望为:
其中梯度是通过策略梯度的标准对数导数技巧获得的。
注意,指标函数会丢弃所有无法得出正确结果的rationales采样的梯度的答案yi。因此, STaR 会采用贪婪模式,通过对采样进行解码,不断缩小当前值和估计值之间的损失,以此完成 J 的近似优化。
这种近似优化方法,使得 STAR 成为一个简单且广泛的通用LLM训练方法。
对于导致失败的rationales,算法无法获得任何训练信号。
为了解决这个问题,我们提出了“合理化(rationalizationb)”的技术。 具体来说,我们通过输入一个hint(合理推理提示词),引导LLM生成显而易见地推理过程以及正确答案。但是,当向我们的数据集添加合理化生成的rationales时,我们不会在其数据集中包含hint(合理推理提示词),就好像模型在没有提示的情况下就得出了基本原理。
过滤后,我们将先前生成的数据集与合理化生成的数据集进行整合,并进行微调训练。
Figure 2: A few-shot prompt hint we use for rationalization (and not for rationale generation), using the rationale from [6], with its hint included in green, followed by the rationale and the answer generated by the model.
参考链接:
https://github.com/ezelikman/STaR