TTT(Test-Time-Training layers)架构,在预测过程中训练

论文地址:https://arxiv.org/abs/2407.04620

论文PDF:240704620v1.pdf

在测试时学习:具有表达性隐藏状态的 RNN

自注意力在长上下文中表现良好,但具有二次复杂度。现有的 RNN 层具有线性复杂度,但它们在长上下文中的性能受到隐藏状态表达能力的限制。我们提出了一类新的序列建模层,具有线性复杂性和富有表现力的隐藏状态。关键思想是将隐藏状态本身作为机器学习模型,将更新规则作为自监督学习的步骤。由于隐藏状态甚至可以通过测试序列上的训练来更新,因此我们的层称为测试时训练(TTT)层。我们考虑两个实例:TTT-Linear 和 TTT-MLP,其隐藏状态分别是线性模型和两层 MLP。我们以 125M 到 1.3B 参数的规模评估我们的实例化,与强大的 Transformer 和 Mamba(现代 RNN)进行比较。 TTT-Linear 和 TTT-MLP 均匹配或超过基线。与 Transformer 类似,它们可以通过调节更多 token 来不断减少困惑,而 Mamba 在 16k 上下文之后就不能了。通过初步的系统优化,TTT-Linear 在 8k 环境下已经比 Transformer 更快,并且在挂钟时间上与 Mamba 相当。 TTT-MLP 在内存 I/O 方面仍然面临挑战,但在长期背景下显示出更大的潜力,为未来的研究指明了一个有希望的方向。


全文总结

本文主要介绍了一种新的序列建模层——具有表达性隐藏状态的Test--Time Training(TTT)层,包括其原理、实现方式、性能优势以及相关的训练方法和面临的问题。

重要亮点

  • TTT层的提出:鉴于自注意力机制在长上下文处理中的复杂性和现有RNN层在长上下文表现的局限性,提出了具有线性复杂度和表达性隐藏状态的TTT层,其关键在于使隐藏状态成为机器学习模型,更新规则为自监督学习的一步。
  • TTT层的实例:介绍了TTT-Linear和TTT-MLP两个实例,前者的隐藏状态是线性模型,后者是两层MLP,它们在一定规模的参数评估中匹配或超过了基线模型。
  • TTT层的效率提升:通过采用小批量的TTT和双重形式来提高硬件效率,使TTT-Linear在8k上下文时比Transformer更快,与Mamba相当。
  • TTT层的训练方式:TTT层的前向传播有对应的反向传播,训练网络时分为外循环和内循环,外循环优化网络其他部分的参数,内循环训练TTT层内的参数。
  • TTT层的自监督任务:自监督任务对TTT至关重要,决定了隐藏状态学习的特征类型,通过将输入处理为损坏形式来优化学习问题。
  • 评估与开放问题:对TTT-Linear和TTT-MLP进行了综合评估,提出了一些开放问题,鼓励社区共同探索解决方案。
  • TTT层的实现:给出了TTT层的PyTorch风格的朴素实现示例,说明了如何将其融入更大的网络中进行训练。