Enhancing linear attention with residual learning
利用殘差學(xué)習(xí)增強(qiáng)線性注意力
https://arxiv.org/pdf/2509.25223v1
![]()
摘要
線性注意力為自注意力機(jī)制提供了一種線性時(shí)間復(fù)雜度的替代方案,但往往難以捕捉長距離模式。我們通過"預(yù)測-校正"的視角重新審視線性注意力,發(fā)現(xiàn)主流變體都可以被表示為歷史預(yù)測與單令牌校正的組合,這造成了表達(dá)能力瓶頸。為解決這一瓶頸,我們提出了殘差線性注意力(RLA),這是一個(gè)為線性注意力配備顯式殘差擬合機(jī)制的框架。RLA 維護(hù)一個(gè)輔助循環(huán)狀態(tài),用于學(xué)習(xí)隨時(shí)間累積殘差誤差并校正基礎(chǔ)預(yù)測。我們進(jìn)一步實(shí)例化了一個(gè) delta 規(guī)則版本——?dú)埐?Delta 網(wǎng)絡(luò)(RDN),結(jié)合了自適應(yīng)門控和殘差裁剪以增強(qiáng)校正控制和穩(wěn)定性。我們的實(shí)現(xiàn)利用了高度優(yōu)化的線性注意力核函數(shù),并保持線性的時(shí)間和內(nèi)存復(fù)雜度。在語言建模和召回密集型評估中,RLA 和 RDN 始終優(yōu)于各自的基線模型及其他現(xiàn)代線性注意力方法,在保持線性擴(kuò)展性的同時(shí)縮小了與標(biāo)準(zhǔn) Transformer 的差距。
1 引言
Transformer(Vaswani 等人,2017)架構(gòu)已成為大型語言模型的標(biāo)準(zhǔn)。然而,其自注意力機(jī)制的二次時(shí)間復(fù)雜度仍然是一個(gè)關(guān)鍵瓶頸,限制了其在長序列上的應(yīng)用(Li 等人,2024)。線性注意力最近作為標(biāo)準(zhǔn)自注意力的高效替代方案涌現(xiàn),直接解決了其過高的二次復(fù)雜度問題。通過將注意力計(jì)算重構(gòu)為循環(huán)過程,這些模型實(shí)現(xiàn)了線性時(shí)間的訓(xùn)練和推理,使其非常適合處理長序列。RetNet(Sun 等人,2023)和 Mamba(Gu & Dao,2023;Dao & Gu,2024)等架構(gòu)已展現(xiàn)出具有競爭力的性能。GLA(Yang 等人,2023)和 DeltaNet(Yang 等人,2024b)等方法通過引入數(shù)據(jù)依賴的門控和狀態(tài)更新規(guī)則來管理單一狀態(tài)矩陣內(nèi)的信息流,進(jìn)一步改進(jìn)了性能。
現(xiàn)代線性注意力方法可以被統(tǒng)一為學(xué)習(xí)從鍵到值的直接映射(Sun 等人,2024),這一過程類似于測試時(shí)訓(xùn)練。例如,delta 更新規(guī)則(Schlag 等人,2021)可以從二次損失目標(biāo)的單步在線梯度下降推導(dǎo)得出。這一視角開辟了若干改進(jìn)途徑,包括探索不同的在線學(xué)習(xí)損失函數(shù)以推導(dǎo)新的更新規(guī)則(Schlag 等人,2021;Yang 等人,2024b)、采用更復(fù)雜的映射函數(shù),或修改在線梯度更新機(jī)制(von Oswald 等人,2025;Siems 等人,2025)。例如,TTT-MLP(Sun 等人,2024)和 Titans(Behrouz 等人,2024)等近期工作利用多層感知機(jī)(MLP)作為深度記憶模塊來實(shí)現(xiàn)更強(qiáng)大的映射。然而,這種方法犧牲了模型的線性循環(huán)特性,從而使并行訓(xùn)練變得復(fù)雜。
基于這一視角,我們對注意力輸出提供了一種新的解釋。我們證明,主流線性注意力模型的輸出可以分解為來自歷史狀態(tài)的基礎(chǔ)分量和僅源自當(dāng)前令牌的校正項(xiàng)(見第 2.3 節(jié))。依賴單一令牌來執(zhí)行這種系統(tǒng)性校正造成了瓶頸,損害了模型的表達(dá)能力。為解決這些問題,我們引入了殘差線性注意力,這是一個(gè)用顯式殘差擬合機(jī)制增強(qiáng)線性注意力模型的框架。我們的方法不依賴單一令牌進(jìn)行校正,而是采用輔助狀態(tài)矩陣來顯式建模和校正基礎(chǔ)線性注意力的系統(tǒng)性預(yù)測誤差。最終輸出是基礎(chǔ)預(yù)測與該學(xué)習(xí)誤差校正的組合。我們的方法可以推廣為適用于各種線性注意力方法的統(tǒng)一框架,為構(gòu)建更強(qiáng)大的序列模型提供了一種強(qiáng)大而高效的策略。
在現(xiàn)有線性注意力方法的基礎(chǔ)上,我們提出了兩種增強(qiáng)殘差擬合的變體:殘差線性注意力(RLA)和殘差 Delta 網(wǎng)絡(luò)(RDN)。我們在一系列基準(zhǔn)測試上評估了它們,包括語言建模和召回密集型任務(wù)。我們的結(jié)果表明,這些模型優(yōu)于各自的基線模型和其他現(xiàn)代線性注意力方法,而我們的消融分析證實(shí)了框架內(nèi)每個(gè)關(guān)鍵設(shè)計(jì)選擇的重要性。
2 預(yù)備知識
2.1 作為循環(huán)模型的線性注意力
Softmax 注意力機(jī)制相對于序列長度表現(xiàn)出二次計(jì)算復(fù)雜度,在處理長序列時(shí)構(gòu)成了顯著的瓶頸。線性注意力(Katharopoulos 等人,2020)架構(gòu)通過移除 softmax 函數(shù)來解決這一問題,從而允許對計(jì)算進(jìn)行重新排序。
![]()
![]()
這種循環(huán)形式在推理過程中保持每步恒定的時(shí)間和內(nèi)存復(fù)雜度,并通過分塊并行算法實(shí)現(xiàn)高效訓(xùn)練(Yang 等人,2023)。此外,門控機(jī)制的使用催生了更多變體的發(fā)展,如 RetNet(Sun 等人,2023)、Lightning Attention(Qin 等人,2024a)和 Mamba-2(Dao & Gu,2024)。
2.2 在線學(xué)習(xí)視角
![]()
![]()
這種形式化使 Delta Net(Yang 等人,2024b;Schlag 等人,2021)等模型能夠?qū)崿F(xiàn)細(xì)粒度的記憶控制。門控 Delta Net(Yang 等人,2024a)進(jìn)一步通過在學(xué)習(xí)過程中引入權(quán)重衰減來增強(qiáng)這一方法。
2.3 分解為預(yù)測與校正
![]()
![]()
![]()
![]()
基于預(yù)測-校正的視角,我們引入了一個(gè)殘差擬合框架來增強(qiáng)線性注意力。我們的框架通過顯式擬合超出當(dāng)前令牌的上下文信息,學(xué)習(xí)一個(gè)更具表達(dá)力的校正項(xiàng)。
3 方法
本節(jié)介紹我們提出的方法,該方法通過殘差擬合過程來增強(qiáng)線性注意力。我們首先描述支撐我們方法的基礎(chǔ)殘差學(xué)習(xí)框架。接下來,我們引入自適應(yīng)校正因子以增強(qiáng)建模能力,并引入裁剪方法來穩(wěn)定殘差擬合過程。最后,我們展示我們方法的兩種最終變體。
3.1 顯式殘差擬合
![]()
![]()
利用第 2 節(jié)中線性注意力的在線學(xué)習(xí)視角,我們對輔助狀態(tài)應(yīng)用類似的更新規(guī)則。這產(chǎn)生了以下循環(huán)過程:
![]()
![]()
3.2 自適應(yīng)門控與校正因子
![]()
![]()
![]()
![]()
這種形式化使用衰減因子和校正因子來分別對來自基礎(chǔ)狀態(tài)和輔助狀態(tài)的檢索進(jìn)行動態(tài)門控。
3.3 歸一化與殘差裁剪
為確保計(jì)算穩(wěn)定性,我們引入兩種機(jī)制。首先,我們對查詢和鍵向量應(yīng)用 L2 歸一化以提高數(shù)值穩(wěn)定性。其次,我們通過裁剪殘差來解決輔助狀態(tài)中的潛在不穩(wěn)定性:
![]()
這確保了誤差校正狀態(tài)保持穩(wěn)定的學(xué)習(xí)軌跡,即使基礎(chǔ)模型產(chǎn)生瞬態(tài)的、較大的預(yù)測誤差。該裁剪方法的詳細(xì)推導(dǎo)見附錄 B。
3.4 最終形式化
殘差擬合原理是一種通用技術(shù),可以與各種線性注意力主干網(wǎng)絡(luò)集成。通過將我們的殘差機(jī)制應(yīng)用于標(biāo)準(zhǔn)加法更新規(guī)則和 delta 更新規(guī)則,我們推導(dǎo)出兩種強(qiáng)大的變體。這導(dǎo)出了我們的最終模型:
![]()
![]()
![]()
4 實(shí)驗(yàn)
4.1 實(shí)驗(yàn)設(shè)置
實(shí)現(xiàn) 為了最大化效率,我們在 Triton(Tillet 等人,2019)中實(shí)現(xiàn)了自定義注意力核函數(shù),基于 flash-linear-attention 庫(Yang & Zhang,2024)構(gòu)建。我們利用了這樣一個(gè)事實(shí):我們的狀態(tài)更新規(guī)則與線性注意力的相同,只需對其核函數(shù)進(jìn)行微小修改:我們將其增強(qiáng)為返回注意力結(jié)果和中間殘差。這種設(shè)計(jì)允許在所有殘差擬合階段重用相同的高度優(yōu)化核函數(shù),確保高吞吐量。
![]()
4.2 主要結(jié)果
核函數(shù)效率 我們將我們的核函數(shù)運(yùn)行時(shí)間與線性注意力基線和 FlashAttention(Dao 等人,2022;Dao,2023)進(jìn)行基準(zhǔn)測試,如圖 2 所示。盡管殘差擬合過程增加了計(jì)算開銷,但我們方法的運(yùn)行時(shí)間隨序列長度線性擴(kuò)展。這使其在較長序列上顯著快于二次擴(kuò)展的 FlashAttention。關(guān)于吞吐量,我們的方法與其他線性注意力機(jī)制一樣,保持幾乎恒定的高吞吐量。相反,計(jì)算受限的 FlashAttention 的吞吐量隨序列長度增加而迅速下降。
![]()
語言建模與常識推理 我們在 WikiText(Merity 等人,2016)困惑度以及一系列評估推理和常識理解的基準(zhǔn)測試上評估 RLA 和 RDN。推理任務(wù)包括 ARC-Easy、ARC-Challenge(Clark 等人,2018)、PIQA(Bisk 等人,2020)和 MMLU(Hendrycks 等人,2020),而常識理解則在 HellaSwag(Zellers 等人,2019)、Winogrande(Sakaguchi 等人,2021)、SocialIQA(Sap 等人,2019)和 LAMBADA(Paperno 等人,2016)上進(jìn)行評估。我們的主要結(jié)果總結(jié)于表 2,顯示我們提出的殘差學(xué)習(xí)變體 RLA 和 RDN 在困惑度上相對于各自的基線 sGLA 和 GDN 取得了一致的改進(jìn)。此外,我們的模型在多個(gè)基準(zhǔn)測試上優(yōu)于其他領(lǐng)先的線性注意力方法,并提供與標(biāo)準(zhǔn) Transformer 相當(dāng)?shù)男阅堋?/p>
![]()
召回密集型任務(wù) 為了評估記憶容量,我們在 Arora 等人(2024)的召回密集型任務(wù)上對我們的模型進(jìn)行基準(zhǔn)測試。此外,我們還直接使用"大海撈針"任務(wù)(NIAH)(gkamradt,2023)評估模型的檢索能力,該任務(wù)需要檢索插入在長文檔不同深度的鍵值對。這些基準(zhǔn)測試對線性注意力模型具有挑戰(zhàn)性,因?yàn)樗鼈兊挠邢逘顟B(tài)空間造成了信息瓶頸,如表 3 所示。結(jié)果表明,我們提出的 RLA 和 RDN 始終優(yōu)于其相應(yīng)的基線,在 DROP 和 FDA 基準(zhǔn)測試上取得了特別顯著的收益。此外,它們在 NIAH 任務(wù)上大幅優(yōu)于其他模型,突顯了增強(qiáng)的信息召回能力。
4.3 消融研究
在本節(jié)中,我們進(jìn)行一系列消融研究以驗(yàn)證關(guān)鍵組件的貢獻(xiàn)。我們首先量化我們學(xué)習(xí)的殘差擬合方法相對于預(yù)定義校正的優(yōu)勢。接下來,我們研究使用專用校正因子的重要性,然后分析將基礎(chǔ)預(yù)測與校正相結(jié)合的門控機(jī)制的必要性。最后,我們檢查歸一化和殘差裁剪的效果。
![]()
如表 4 所示,缺乏顯式殘差擬合的變體表現(xiàn)不如我們的完整方法。盡管該消融變體在某些基準(zhǔn)測試上保持競爭力,但它在訓(xùn)練集和評估集上的困惑度都顯著增加。這種性能下降延伸到專業(yè)領(lǐng)域,在 GSM8k(Cobbe 等人,2021)和 HumanEval(Chen 等人,2021)的困惑度測量中,其數(shù)學(xué)和代碼能力顯著退化。這證明了輔助狀態(tài)在累積過去殘差以有效細(xì)化模型輸出方面的關(guān)鍵作用。
![]()
專用校正因子 我們通過將我們的完整模型與 γ 綁定到更新因子 β 的變體進(jìn)行比較,分析使用專用校正因子 γ 的益處。在圖 3a 中,具有獨(dú)立 γ 的模型始終實(shí)現(xiàn)更低的評估損失,其中 RDN 變體顯示出更大的改進(jìn)。這一趨勢延伸到下游性能,如圖 3b 的結(jié)果所示,該圖還顯示專用校正因子在多個(gè)基準(zhǔn)測試上帶來性能提升。值得注意的是,我們的基礎(chǔ)架構(gòu)(不需要額外的 γ)仍然比基線線性注意力方法有顯著改進(jìn)。
![]()
![]()
![]()
歸一化與殘差裁剪 最后,我們研究歸一化和殘差裁剪的重要性。我們通過對 RLA 移除歸一化和裁剪來進(jìn)行消融研究。如圖 4 所示,兩個(gè)組件對穩(wěn)定訓(xùn)練都至關(guān)重要;移除它們會導(dǎo)致無界激活和性能退化。相比之下,RDN 模型對殘差裁剪很大程度上不敏感。這種魯棒性歸因于其 delta 規(guī)則更新的固有穩(wěn)定性,即使沒有殘差裁剪也能保持一致的損失曲線(圖 4b)。
![]()
5 相關(guān)工作
序列建模歷史上由循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)(Lipton 等人,2015)主導(dǎo),包括長短期記憶網(wǎng)絡(luò)(LSTM)(Hochreiter & Schmidhuber,1997)和門控循環(huán)單元(GRU)(Cho 等人,2014)等變體。雖然有效,但其固有的順序性質(zhì)阻礙了訓(xùn)練并行化。Transformer 架構(gòu)(Vaswani 等人,2017)克服了這一限制,成為序列建模的事實(shí)標(biāo)準(zhǔn)。然而,其自注意力機(jī)制具有相對于序列長度的二次計(jì)算復(fù)雜度,對長上下文應(yīng)用構(gòu)成了顯著瓶頸。
為解決這些挑戰(zhàn),近期研究重新審視了線性 RNN,將其作為高效 Transformer 替代方案的基礎(chǔ)。通過將序列處理形式化為線性循環(huán),這些模型實(shí)現(xiàn)了可并行化訓(xùn)練和線性時(shí)間推理。該領(lǐng)域的早期探索,如 S4(Gu 等人,2021)、LRU(Orvieto 等人,2023)和 RetNet(Sun 等人,2023),利用了結(jié)構(gòu)化狀態(tài)轉(zhuǎn)移矩陣。通過引入數(shù)據(jù)依賴的動態(tài)特性,后續(xù)實(shí)現(xiàn)了性能飛躍。Mamba(Gu & Dao,2023;Dao & Gu,2024)、HGRN(Qin 等人,2023;2024b)和門控線性注意力(Yang 等人,2023)等模型利用輸入依賴的門控來動態(tài)控制狀態(tài)轉(zhuǎn)移,從而增強(qiáng)其表達(dá)能力。
更先進(jìn)的方法引入了 delta 學(xué)習(xí)規(guī)則,將狀態(tài)更新從簡單的門控衰減重新框架為細(xì)粒度的記憶校正。這種方法以 DeltaNet(Yang 等人,2024b;Schlag 等人,2021)和門控 DeltaNet(Yang 等人,2024a)為代表,實(shí)現(xiàn)了更精確的動態(tài)記憶修改。該機(jī)制可以從在線學(xué)習(xí)視角理解,其中狀態(tài)更新被框架為優(yōu)化過程,如 TTT(Sun 等人,2024)所探索的。這一觀點(diǎn)啟發(fā)了進(jìn)一步的工作,旨在發(fā)現(xiàn)和改進(jìn)序列模型內(nèi)的內(nèi)在學(xué)習(xí)算法(von Oswald 等人,2023;2025)。
同期研究聚焦于增加狀態(tài)轉(zhuǎn)移的表達(dá)能力。例如,RWKV-7(Peng 等人,2025)采用對角加低秩結(jié)構(gòu),而 DeltaProduct(Siems 等人,2025)通過每令牌執(zhí)行多步更新來推廣 DeltaNet。為進(jìn)一步提升容量,近期架構(gòu)如 Titans(Behrouz 等人,2024)和 Miras(Behrouz 等人,2025)引入了非線性深度記憶,用 MLP 對狀態(tài)進(jìn)行參數(shù)化。
6 結(jié)論
在本文中,我們介紹了殘差線性注意力,這是一個(gè)通過顯式殘差擬合過程來增強(qiáng)線性注意力模型的框架。我們的方法利用輔助狀態(tài)來校正基礎(chǔ)模型的預(yù)測誤差,從而構(gòu)建更魯棒和準(zhǔn)確的上下文表示。該框架具有高度適應(yīng)性,可應(yīng)用于各種線性注意力方法。我們的實(shí)驗(yàn)證明了這種多功能性,顯示我們的方法始終優(yōu)于各自的基線。雖然這種改進(jìn)以擬合過程的額外計(jì)算為代價(jià),但平衡這一權(quán)衡為未來的研究提供了一個(gè)有前景的方向。
原文鏈接:https://arxiv.org/pdf/2509.25223v1
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.