每日最新頭條.有趣資訊

158行代碼!程序員複現DeepMind圖像生成神器

新智元推薦

整理編輯:張佳

【新智元導讀】最近,谷歌 DeepMInd 發表論文提出了一個用於圖像生成的遞歸神經網絡,該系統大大提高了 MNIST 上生成模型的質量。為更加深入了解 DRAW,本文作者基於 Eric Jang 用 158 行 Python 代碼實現該系統的思路,詳細闡述了 DRAW 的概念、架構和優勢等。

遞歸神經網絡是一種用於圖像生成的神經網絡結構。Draw Networks 結合了一種新的空間注意機制,該機制模擬了人眼的中心位置,採用了一個順序變化的自動編碼框架,使之對複雜圖像進行迭代構造。

該系統大大提高了 MNIST 上生成模型的質量,特別是當對街景房屋編號數據集進行訓練時,肉眼竟然無法將它生成的圖像與真實數據區別開來。

Draw 體系結構的核心是一對遞歸神經網絡:一個是壓縮用於訓練的真實圖像的編碼器,另一個是在接收到代碼後重建圖像的解碼器。這一組合系統採用隨機梯度下降的端到端訓練,損失函數的最大值變分主要取決於對數似然函數的數據。

Draw 網絡類似於其他變分自動編碼器,它包含一個編碼器網絡,該編碼器網絡決定著潛在代碼上的 distribution(潛在代碼主要捕獲有關輸入數據的顯著信息),解碼器網絡接收來自 code distribution 的樣本,並利用它們來調節其自身圖像的 distribution 。

DRAW 與其他自動解碼器的三大區別

編碼器和解碼器都是 DRAW 中的遞歸網絡,解碼器的輸出依次添加到 distribution 中以生成數據,而不是一步一步地生成 distribution 。動態更新的注意機制用於限制由編碼器負責的輸入區域和由解碼器更新的輸出區域 。簡單地說,這一網絡在每個 time-step 都能決定“讀到哪裡”和“寫到哪裡”以及“寫什麽”。

左:傳統變分自動編碼器

在生成過程中,從先前的 P(z)中提取一個樣本 z ,並通過前饋譯碼器網絡來計算給定樣本的輸入 P(x_z)的概率。

在推理過程中,輸入 x 被傳遞到編碼器網絡,在潛在變量上產生一個近似的後驗 Q(z|x) 。在訓練過程中,從 Q(z|x) 中抽取 z,然後用它計算總描述長度 KL ( Q (Z|x)∣∣ P(Z)−log(P(x|z)),該長度隨隨機梯度的下降(https://en.wikipedia.org/wiki/Stochastic_gradient_descent)而減小至最小值。

右:DRAW網絡

在每一個步驟中,都會將先前 P(z)中的一個樣本 z_t 傳遞給遞歸解碼器網絡,該網絡隨後會修改 canvas matrix 的一部分。最後一個 canvas matrix cT 用於計算 P(x|z_1:t)。

在推理過程中,每個 time-step 都會讀取輸入,並將結果傳遞給編碼器 RNN,然後從上一 time-step 中的 RNN 指定讀取位置,編碼器 RNN 的輸出用於計算該 time-step 的潛在變量的近似後驗值。

損失函數

最後一個 canvas matrix cT 用於確定輸入數據的模型 D(X | cT) 的參數。如果輸入是二進製的,D 的自然選擇呈伯努利分布,means 由 σ(cT) 給出。重建損失 Lx 定義為 D 下 x 的負對數概率:

The latent loss 潛在distributions序列

的潛在損失被定義為源自

的潛在先驗 P(Z_t)的簡要 KL散度。

鑒於這一損失取決於由

繪製的潛在樣本 z_t ,因此其反過來又決定了輸入 x。如果潛在 distribution是一個這樣的 diagonal Gaussian ,P(Z_t) 便是一個均值為 0,且具有標準離差的標準 Gaussian,這種情況下方程則變為

網絡的總損失 L 是重建和潛在損失之和的期望值:

對於每個隨機梯度下降,我們使用單個 z 樣本進行優化。

L^Z 可以解釋為從之前的序列向解碼器傳輸潛在樣本序列 z_1:T 所需的 NAT 數量,並且(如果 x 是離散的)L^x 是解碼器重建給定 z_1:T 的 x 所需的 NAT 數量。因此,總損失等於解碼器和之前數據的預期壓縮量。

改善圖片

正如 EricJang 在他的文章中提到的,讓我們的神經網絡僅僅“改善圖像”而不是“一次完成圖像”會更容易些。正如人類藝術家在畫布上塗塗畫畫,並從繪畫過程中推斷出要修改什麽,以及下一步要繪製什麽。

改進圖像或逐步細化只是一次又一次地破壞我們的聯合 distribution P(C) ,導致潛在變量鏈 C1,C2,…CT−1 呈現新的變量分布 P(CT) 。

訣竅是多次從迭代細化分布 P(Ct|Ct−1)中取樣,而不是直接從 P(C) 中取樣。

在 DRAW 模型中, P(Ct|Ct−1) 是所有 t 的同一 distribution,因此我們可以將其表示為以下遞歸關係(如果不是,那麽就是Markov Chain而不是遞歸網絡了)。

DRAW模型的實際應用

假設你正在嘗試對數字 8 的圖像進行編碼。每個手寫數字的繪製方式都不同,有的樣本 8 可能看起來寬一些,有的可能長一些。如果不注意,編碼器將被迫同時捕獲所有這些小的差異。

但是……如果編碼器可以在每一幀上選擇一小段圖像並一次檢查數字 8 的每一部分呢?這會使工作更容易,對吧?

同樣的邏輯也適用於生成數字。注意力單元將決定在哪裡繪製數字 8 的下一部分-或任何其他部分-而傳遞的潛在矢量將決定解碼器生成多大的區域。

基本上,如果我們把變分的自動編碼器(VAE)中的潛在代碼看作是表示整個圖像的矢量,那麽繪圖中的潛在代碼就可以看作是表示筆畫的矢量。最後,這些向量的序列實現了原始圖像的再現。

好吧,那麽它是如何工作的呢?

在一個遞歸的 VAE 模型中,編碼器在每一個 timestep 會接收整個輸入圖像。在 Draw 中,我們需要將焦點集中在它們之間的 attention gate 上,因此編碼器隻接收到網絡認為在該 timestep 重要的圖像部分。第一個 attention gate 被稱為“Read”attention。

“Read”attention分為兩部分:

選擇圖像的重要部分和裁剪圖像

選擇圖像的重要部分

為了確定圖像的哪一部分最重要,我們需要做些觀察,並根據這些觀察做出決定。在 DRAW中,我們使用前一個 timestep 的解碼器隱藏狀態。通過使用一個簡單的完全連接的圖層,我們可以將隱藏狀態映射到三個決定方形裁剪的參數:中心 X、中心 Y 和比例。

裁剪圖像

現在,我們不再對整個圖像進行編碼,而是對其進行裁剪,隻對圖像的一小部分進行編碼。然後,這個編碼通過系統解碼成一個小補丁。

現在我們到達 attention gate 的第二部分, “write”attention,(與“read”部分的設置相同),只是“write”attention 使用當前的解碼器,而不是前一個 timestep 的解碼器。

雖然可以直觀地將注意力機制描述為一種裁剪,但實踐中使用了一種不同的方法。在上面描述的模型結構仍然精確的前提下,使用了gaussian filters矩陣,沒有利用裁剪的方式。我們在DRAW 中取了一組每個 filter 的中心間距都均勻的gaussian filters 矩陣 。

代碼一覽

我們在 Eric Jang 的代碼的基礎上,對其進行一些清理和注釋,以便於理解.

Eric 為我們提供了一些偉大的功能,可以幫助我們構建 “read” 和 “write” 注意門徑,還有過濾我們將使用的初始狀態功能,但是首先,我們需要添加新的功能,來使我們能創建一個密集層並合並圖像。並將它們保存到本地計算機中,以獲取更新的代碼。

現在讓我們把代碼放在一起以便完成。

獲得更多的PTT最新消息
按讚加入粉絲團