人工智能如何克服遺忘困境?
活到老,學到老,人類可以在不斷變化的環境中連續自適應地學習——在新的環境中不斷吸收新知識,並根據不同的環境靈活調整自己的行為。模仿碳基生命的這一特性,針對連續學習(continual learning,CL)的機器學習算法的研究應運而生,並成為大家日益關注的焦點。
那麼,什麼是連續學習?相較於傳統單任務的機器學習方法,連續學習旨在學習一系列任務,即在連續的信息流中,從不斷改變的概率分布中學習和記住多個任務,並隨着時間的推移,不斷學習新知識,同時保留之前學到的知識。
然而,這個領域的技術發展並非一帆風順,面臨着許多難題。《莊子·秋水》中曾描述過一個這樣的故事:戰國時期,燕國有一少年聽聞趙國都城邯鄲人走路姿勢異常優美,心嚮往之。遺憾的是,他在跟隨邯鄲人學步數月後,卻把之前走路姿勢忘記了,最後甚至都不會走路了,無奈只好爬回了燕國。有趣的是,這則寓言故事深蘊着當前連續學習模型的困境之一——災難性遺忘(catastrophic forgetting),模型在學習新任務之後,由於參數更新對模型的干擾,會忘記如何解決舊任務。而對於機器學習技術而言,另一普遍關注的概念便是泛化誤差(generalization error),這是衡量機器學習模型泛化能力的標準,用以評估訓練好的模型對未知數據預測的準確性。泛化誤差越小,說明模型的泛化能力越好。
儘管目前很多實驗研究致力於解決連續學習中的災難性遺忘問題,但是對連續學習的理論研究還十分有限。哪些因素與災難性遺忘和泛化誤差相關?它們如何明確地影響模型的連續學習能力?對此我們所知甚少。
近期,來自美國俄亥俄州立大學Ness Shroff教授團隊的研究工作「Theory on Forgetting and Generalization of Continual Learning」或有望為這一問題提供詳細的解答。他們從理論上解釋了過度參數化(over parameterization)、任務相似性(task similarity)和任務排序(task ordering)對遺忘和泛化誤差的影響,發現更多的模型參數、更低的噪聲水平、更大的相鄰任務間差異,有助於降低遺忘。同時,通過深度神經網絡(DNN),他們在真實數據集上驗證了該理論的可行性。
圖註:論文封面,該論文於2023年2月刊登在ArXiv上
*連續學習線性模型的構建
在經典的機器學習理論中,參數越多,模型越複雜,往往會帶來不期望見到的過擬合。但以DNN為代表的深度學習模型則不然,其參數越多,模型訓練效果越好。為了理解這一現象,作者更加關注在過參數化的情況下(p>n),連續學習模型的表現。文章首次定義了基於過參數化線性模型的連續學習模型,考量其在災難性遺忘和泛化誤差問題上的閉合解(定理1.1)。
定理1.1 當p≥n+2時,則:
T={1,…,T}代表任務序列;||wi∗ - wj∗||2表徵任務i和j之間的相似性;p為模型實際參數的數量;n為模型需要的參數數量;r為過參數化的比例,r=1-n/p;σ為噪聲水平;ci,j =(1-r)(rT-i-rj-i+rT-j),其中1≤i≤j≤T;更多參數介紹詳看原始文獻和附錄部分。
(9)式和(10)式分別為災難性遺忘FT和泛化誤差GT的數學表示。它們不僅描述了連續學習在線性模型中是如何工作的,還為其在一些真實的數據集和DNN中的應用提供指導。
*連續學習中的鼎足三分
在上述數學模型的基礎上,作者還研究了在連續學習過程中,過參數化、任務之間的相似程度和任務的訓練順序三個因素對災難性遺忘和泛化誤差的影響。
1)過參數化
·更多的模型訓練參數將有助於降低遺忘
如定理1.1所示,當表示參數數量的p趨近於0時,E[FT]也將趨近於零。
·噪聲水平和(或)任務間相似度低的情況下,過參數化更好
為了比較過參數化和欠參數化時模型的性能,作者構建了與定理1.1類似的,在欠參數情況下的理論模型定理1.2。
定理1.2 當n≥p+2時,則:
如定理1.2所示,欠參數化的情況下,當噪聲水平σ較大時,以及當訓練的任務間區分度較大時,E[FT]和E[GT]都變大。相反,過參數化的情況下,當噪聲水平σ較大時,以及當訓練的任務間不太相似時,E[FT]和E[GT]都變小。這表明當噪聲水平高和(或)訓練任務相似性較低時,過參數化的情況可能比欠參數化的情況訓練效果更好,即存在良性過擬合。
2)連續訓練任務的相似性
· 泛化誤差隨着任務相似性的增加而降低,而遺忘則可能不會隨之降低
如定理1.1所示,由於公式(10)中G2項的係數始終為正,所以當任務之間越相似,區分度越少時,泛化誤差會相應降低。但是由於公式(9)中,F2項的係數並不總是為正,所以可能出現任務之間的相似性增加模型的遺忘性能也增加的情況。
3)任務訓練順序
· 在早期階段將差異大的任務相鄰訓練,將有助於降低遺忘
為了找到連續學習中,任務的最優訓練順序。作者考慮了兩種特殊情況。情況一,任務集由一個特殊的任務,和剩餘其它完全一模一樣的任務組成。情況二,任務集由數目相同的不同任務組成。通過對兩種情況的比較分析得出:
首先,特殊的任務在訓練時,應優先在前半段執行;
其次,相鄰任務之間應差異較大;這些措施都將有助於降低連續學習模型的遺忘。但是,最小化的遺忘和最小化的泛化誤差的最佳任務訓練排序有時並不相同。
DNN對連續學習模型的驗證
最後,為了驗證上述推論的可靠性,作者使用DNN在真實數據集上進行實驗。後續的實驗結果明確地證實了,任務相似性對連續學習模型災難性遺忘的非單調性影響。而關於任務排序影響的實驗結果也與前面線性模型中的發現一致,即應在模型訓練早期設置區分度較大的任務學習,並安排區分度較大任務相鄰訓練。
表1:使用TRGP和TRGP+兩種任務策略在不同數據集中訓練得到的準確性和反向遷移(用負值表示遺忘;值越大/正,表示知識反向遷移效果越好)結果
正向遷移:在學習新任務的過程中,利用以前的任務中學習到的經驗來幫助新任務的知識學習。
反向遷移:在學習新任務的過程中,學習到的新知識,鞏固了以前任務的知識學習。
PMNIST數據集:MNIST數據集是機器學習模型訓練所使用的經典數據集,包含0-9這10個數字的手寫樣本,其中每個樣本的輸入是一個圖像,標籤是圖像所代表的數字。PMNIST是基於MNIST數據集的變種,由10種不同的MNIST樣本置換順序的連續學習任務組成,可進行連續學習問題的評估。Split CIFAR-100數據集:CIFAR-100數據集也是機器學習模型訓練所使用的經典數據集,包含100種分類任務,如蜜蜂、蝴蝶等。每類有600張彩色圖像,其中500張作為訓練集,100張作為測試集。同樣,為了在該數據集上進行連續學習問題的評估,作者將CIFAR-100數據集等分為10組,每一組由10個完全不同的分類任務組成,重構了Split CIFAR-100連續學習數據集。
更有趣的是,作者發現,相較於賦以不同時間點學習的舊任務相同的權重(TRGP)的策略,賦以最近學習的舊任務更多的權重(TRGP+),可以更好地促進連續學習模型的知識正向遷移和反向遷移(表 1)。這些發現有望為後續連續學習策略的設計提供理論參考。-(文:追問NextQuestion/鈦媒體)