每日最新頭條.有趣資訊

MIT\北大\CMU合作 找到深度神經網絡全局最優解

新智元報導

來源:Arxiv

編譯:大明

【新智元導讀】深度學習的網絡訓練損失問題一直是學術界關注的熱點。過去,利用梯度下降法找到的一般都是局部最優解。近日,CMU、MIT和北京大學的研究人員分別對深度全連接前饋神經網絡、ResNet和卷積ResNet進行了分析,並表明利用梯度下降可以找到全局最小值,在多項式時間內實現零訓練損失。

在目標函數非凸的情況下,梯度下降在訓練深度神經網絡中也能夠找到全局最小值。本文證明,對於具有殘差連接的超參數化的深度神經網絡(ResNet),採用梯度下降可以在多項式時間內實現零訓練損失。

本文的分析基於由神經網絡架構建立的Gram矩陣的特定結構。該結構顯示在整個訓練過程中,Gram矩陣是穩定的,並且這種穩定性意味著梯度下降算法的全局最優性。使用ResNet可以獲得相對於全連接的前饋網絡架構的優勢。

對於前饋神經網絡,邊界要求每層網絡中的神經元數量隨網絡深度的增加呈指數級增長。對於ResNet,只要求每層的神經元數量隨著網絡深度的實現多項式縮放。我們進一步將此類分析擴展到深度殘余卷積神經網絡上,並獲得了類似的收斂結果。

找到梯度下降全局最優解,實現訓練零損失

深度學習中的一個難題是隨機初始化的一階方法,即使目標函數是非凸的,梯度下降也會實現零訓練損失。一般認為過參數化是這種現象的主要原因,因為只有當神經網絡具有足夠大的容量時,該神經網絡才有可能適合所有訓練數據。在實踐中,許多神經網絡架構呈現高度的過參數化。

訓練深度神經網絡的第二個神秘現象是“越深層的網絡越難訓練”。為了解決這個問題,採用提出了深度殘差網絡(ResNet)架構,該架構使得隨機初始化的一階方法能夠訓練具有更多層數的數量級的神經網絡。

從理論上講,線性網絡中的殘余鏈路可以防止大的零鄰域中的梯度消失,但對於具有非線性激活的神經網絡,使用殘差連接的優勢還不是很清楚。

本文揭開了這兩個現象的神秘面紗。我們考慮設定n個數據點,神經網絡有H層,寬度為m。然後考慮最小二乘損失,假設激活函數是Lipschitz和平滑的。這個假設適用於許多激活函數,包括soft-plus。

論文鏈接:

https://arxiv.org/pdf/1811.03804.pdf

首先考慮全連接前饋神經網絡,在神經元數量m=Ω(poly(n)2O(H))的情況下,隨機初始化的梯度下降會以線性速度收斂至零訓練損失。

接下來考慮ResNet架構。只要神經元數量m =Ω(poly(n,H)),那麽隨機初始化的梯度下降會以線性速率收斂到零訓練損失。與第一個結果相比,ResNet對網絡層數的依賴性呈指數級上升。這證明了使用殘差連接的優勢。

最後,用相同的技術來分析卷積ResNet。結果表明,如果m = poly(n,p,H),其中p是patch數量,則隨機初始化的梯度下降也可以實現零訓練損失。

本文的研究證據建立在先前關於兩層神經網絡梯度下降的研究理念之上。首先,作者分析了預測的動力學情況,其收斂性由神經網絡結構引出的Gram矩陣的最小特徵值確定,為了降低其最小特徵值的下限,從初始化階段限制每個權重矩陣的距離就可以了。

其次,作者使用Li和Liang[2018]的觀察結果,如果神經網絡是過參數化的,那麽每個權重矩陣都接近其初始化狀態。本文在分析深度神經網絡時,需要構建更多深度神經網絡的架構屬性和新技術。

本文附錄中給出了詳細的數學證明過程

接下來,論文分別給出了全連接前饋神經網絡、ResNet和卷積ResNet的分析過程,並在長達20余頁的附錄部分(本文含附錄共計45頁)給出了詳細的數學證明過程,對自己的數學功底有自信的讀者可以自行參看論文。這裡僅就ResNet分析過程中,Gram矩陣的構建和研究假設做簡要說明。

Gram矩陣的構建

以上是網絡寬度m趨於無限時的漸進Gram矩陣。我們特做出如下假設,該假設條件決定了收斂速度和過參數化數量。

注意,這裡的λ和全連接前饋神經網絡中的不同,因為這裡的λ隻由K(0)決定,一般來說,除非兩個數據點是平行的,否則λ總是正數。

研究結論和局限:目前還不是隨機梯度下降

在本文中,我們表明深度過度參數化網絡上的梯度下降可以獲得零訓練損失。其中關鍵是證明了Gram矩陣在過參數化條件下會越來越穩定,因此梯度下降的每一步都以幾何速率減少損失。

最後列出未來的一些潛在研究方向:

1.本文主要關注訓練損失,但沒有解決測試損失的問題。如何找到梯度下降的低測試損失的解決方案將是一個重要問題。尤其是現有的成果隻表明梯度下降在與kernel方法和隨機特徵方法相同的情況下才起作用。

2.網絡層的寬度m是ResNet架構的所有參數的多項式,但仍然非常大。而在現實網絡中,數量較大的是參數的數量,而不是網絡層的寬度,數據點數量n是個很大的常量。如何改進帳析過程,使其涵蓋常用的網絡,是一個重要的、有待解決的問題。

3、目前的分析只是梯度下降,不是隨機梯度下降。我們認為這一分析可以擴展到隨機梯度下降,同時仍然保持線性收斂速度。

論文鏈接:

https://arxiv.org/pdf/1811.03804.pdf

獲得更多的PTT最新消息
按讚加入粉絲團