每日最新頭條.有趣資訊

谷歌大腦:只要網絡足夠寬,激活函數皆可拋

選自arXiv

作者:Jaehoon Lee 等

機器之心編輯部

深度神經網絡以其強大的非線性能力為傲,借助它可以擬合圖像和語音等複雜數據。但最近谷歌大腦的研究者表明只要網絡足夠寬,線性化的網絡可以產生和原版網絡相近的預測結果和準確率,即使沒有激活函數也一樣。這有點反直覺,你現在告訴我 Wide ResNet 那樣的強大模型,在 SGD 中用不用激活函數都一樣?

基於深度神經網絡的機器學習模型在很多任務上達到了前所未有的性能。這些模型一般被視為複雜的系統,很難進行理論分析。此外,由於主導最優化過程的通常是高維非凸損失曲面,因此要描述這些模型在訓練中的梯度動態變化非常具有挑戰性。

就像在物理學中常見的那樣,探索此類系統的理想極限有助於解決這些困難問題。對於神經網絡來說,其中一個理想極限就是無限寬度(infinite width),即全連接層中的隱藏單元數,或者卷積層中的通道數無窮大。在這種限制之下,網絡初始化時的輸出來自於

高斯過程

(GP);此外,在使用平方損失進行精確貝葉斯訓練後,網絡輸出仍然由 GP 控制。除了理論上比較簡單之外,無限寬度的極限也具有實際意義,因為研究者發現更寬的網絡可以更好地泛化。

谷歌大腦的這項研究探索了寬神經網絡在梯度下降時的學習動態,他們發現這一動態過程的權重空間描述可以變得非常簡單:隨著寬度變大,神經網絡在初始化時可以被其參數的一階泰勒展開式(Taylor expansion)有效地代替。這樣我們就可以得到一種線性模型,它的梯度下降過程變得易於分析。雖然線性化只在無限寬度限制下是精確的,但即使在有限寬度的情況下,研究者仍然發現原始網絡的預測與線性化版本的預測非常一致。這種一致性在不同架構、優化方法和損失函數之間持續存在。

對於平方損失,精確的學習動態過程允許存在一個閉式解,它允許我們用 GP 來表征預測分布的演變。這個結果可以看成是「sample-then-optimize」後驗采樣向深度神經網絡訓練的延伸。實驗模擬證實,對於具有不同隨機初始化的有限寬度模型集合,實驗結果可以準確地建模了其預測的變化。

論文的主要貢獻:

作者表明,這項研究工作最重要的貢獻是展示了參數空間中的動態更新過程等價於模型的訓練動態過程(dynamics),且該模型是網絡所有參數(權重項與偏置項)的仿射變換。無論選擇哪種損失函數,該結果都成立。尤其是在使用平方損失時,動態過程允許使用閉式解作為訓練時間的函數。所以像 Wide ResNet 那樣的強大非線性模型,只要足夠寬,它可以直接通過線性的仿射變換直接模擬,激活函數什麽的都沒啥必要了~

這些理論可能看起來太簡單了,不適用於實踐中的神經網絡。儘管如此,作者仍然通過實驗研究了該理論在有限寬度中的適用性,並發現有限寬度線性網絡能表征各種條件下的學習動態過程和後驗函數分布,包括表征實踐中常用的 Wide ResNet。

論文:Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent

論文鏈接:https://arxiv.org/pdf/1902.06720.pdf

摘要:深度學習研究的長期目標是準確描述訓練和泛化過程。但是,神經網絡極其複雜的損失函數表面使動態過程的理論分析撲朔迷離。谷歌大腦的這項研究展示了,寬神經網絡的學習動態過程難度得到了極大簡化;而對於寬度有限的神經網絡,它們受到線性模型的支配,該線性模型通過初始參數附近的一階泰勒展開式進行定義。此外,具備平方損失的寬神經網絡基於梯度的訓練反映了寬貝葉斯神經網絡和高斯過程之間的對應,這種寬神經網絡生成的測試集預測來自具備特定組合核(compositional kernel)的高斯過程。儘管這些理論結果僅適用於無限寬度的神經網絡,但研究者找到了一些實驗證據,證明即使是寬度有限的現實網絡,其原始網絡的預測結果和線性版本的預測結果也符合該理論。這一理論在不同架構、優化方法和損失函數上具備穩健性。

理論結果

線性化網絡

在實驗部分,本論文展示了線性化網絡(linearized network)能獲得和原始深度非線性網絡相同的輸出結果和準確率等。這一部分簡單介紹了什麽是線性化的網絡,更多理論分析可以查看原論文的第二章節。對於線性化網絡的訓練動態過程,首先我們需要將神經網絡的輸出替換為一階泰勒展開式:

其中 ω_t ≡ θ_t ? θ_0 表示模型參數從初始值到終值的變化。表達式 (6) 左邊的 f_t 為兩項之和:第一項為網絡的初始化輸出,根據泰勒公式,它在訓練過程中是不改變的;第二項則會捕捉初始值在訓練過程中的變化。如果我們使用線性函數,那麽梯度流的動態過程可以表示為:

因為 f_0 對θ的梯度 ?f_0 在整個訓練中都為常數,這些動態過程會顯得比較簡單。在使用 MSE 損失函數時,常微分方程有閉式解:

因此,儘管沒有訓練該網絡,我們同樣能獲得線性化神經網絡沿時間的演化過程。我們只需要計算正切核函數 Θ_0 hat 和初始狀態的輸出 f_0,並根據方程 11、12 和 9 計算模型輸出和權重的動態變化過程。重要的是,這樣計算出來的值竟然和對應非線性深度網絡迭代學習出來的值非常相似。

實驗

本研究進行了實驗,以證明寬神經網絡的訓練動態能夠被線性模型很好地捕捉。實驗包括使用全批量和小批量梯度下降的全連接、卷積和 wide ResNet 架構(梯度下降的學習率非常小),以使連續時間逼近(continuous time approximation)能夠發揮作用。實驗考慮在 CIFAR10 數據集上進行二分類(馬和飛機)、在 MNIST 和 CIFAR-10 數據集上進行十個類別的分類。在使用 MSE 損失時,研究者將二分類任務作為回歸任務來看待,一個類別的回歸值是+1,另一個類別的回歸值是-1。

原始網絡與線性網絡之間的訓練動態過程對比

圖 5、6、7 對比了線性網絡和實際網絡的訓練動態過程。所有示例中二者都達到了很好的一致。

圖 4 展示了線性模型可以很好地描述在 CIFAR-10 數據集上使用交叉熵損失執行分類任務時的學習動態。圖 6 使用交叉熵損失測試 MNIST 分類任務,且使用動量法優化器進行訓練。圖 5 和圖 7 對比了對線性網絡和原始網絡直接進行訓練時二者的訓練動態過程。

圖 4:在模型上執行全批量梯度下降與線性版本上的分析動態過程(analytic dynamics)類似,不管是網絡輸出,還是單個權重。

圖 5:使用具備帶有動量的最優化器進行全批量梯度下降時,卷積網絡和其線性版本的表現類似。

圖 6:神經網絡及其線性版本在 MNIST 數據集上通過具備動量的 SGD 和交叉熵損失進行訓練時,表現類似。

圖 7 對比了使用 MSE 損失和具備動量的 SGD 訓練的 Wide ResNet 的線性動態過程和真實動態過程。研究者稍微修改了圖 7 中的殘差模塊結構,使每一層的通道數保持固定(該示例中通道數為 1024),其他與原始實現一致。

圖 7:Wide ResNet 及其線性化版本表現類似,二者都是通過帶有動量的 SGD 和 MSE 損失在 CIFAR-10 數據集上訓練的。

圖 8 為一系列模型繪製了平台均方根誤差(plateau RMSE),它是寬度和數據集大小的函數。總體來看,誤差隨寬度的增加而降低。全連接網絡的誤差降幅約為 1/N,卷積和 WRN 架構的誤差降幅更加模糊。

圖 8:誤差取決於深度和數據集大小。

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

獲得更多的PTT最新消息
按讚加入粉絲團