每日最新頭條.有趣資訊

Jax 生態再添新庫:DeepMind 開源 Haiku、RLax

機器之心報導

參與:一鳴

Jax 是一個優秀的代碼庫,在進行科學計算的同時能夠自動微分,還有 GPU、TPU 的性能加速加持。但是 Jax 的生態還不夠完善,使用者相比 TF、PyTorch 少得多。近日,DeepMind 開源了兩個基於 Jax 的新庫,給這個生態注入了新的活力。

Jax 是谷歌開源的一個科學計算庫,能對 Python 程序與 NumPy 運算執行自動微分,而且能夠在 GPU 和 TPU 上運行,具有很高的性能。基於 Jax 已有很多優秀的開源項目,如 Trax 等。近日,DeepMind 開源了兩個基於 Jax 的新機器學習庫,分別是 Haiku 和 RLax,它們都有著各自的特色,對於豐富深度學習社區框架、提升研究者和開發者的使用體驗有著不小的意義。

Haiku:https://github.com/deepmind/haiku

RLax:https://github.com/deepmind/rlax

Haiku:在 Jax 上進行面向對象開發

首先值得注意的是 Haiku,這是一個面向 Jax 的深度學習代碼庫,它是由 Sonnet 作者——一個谷歌的神經網絡庫團隊開發的。

為什麽要使用 Haiku?這是因為其支持的是 Jax,Jax 在靈活性和性能上具有相當的優勢。但是另一方面,Jax 本身是函數式的,和面向對象的用戶習慣有差別。因此,通過 Haiku,用戶可以在 Jax 上進行面向對象開發了。

此外,Haiku 的 API 和編程模型都是基於 Sonnet,因此使用過 Sonnet 的用戶可以快速上手。項目作者也表示,Sonnet 之於 TensorFlow 的提升就如同 Haiku 之於 Jax。

目前,Haiku 已公開了 Alpha 版本,已完全開源。項目作者歡迎使用者提出建議。

Haiku 怎麽和 Jax 互動

Haiku 主要分為兩個模塊:hk.Modules和hk.transform。下文將會分別介紹。

hk.Modules是 Python 對象,保存著到參數、其他模塊和方法的參照(references)。

hk.transform 則負責將面向對象的模塊轉換為純粹的函數式代碼,然後讓 jax 中的 jax.jit, jax.grad, jax.pmap 等進行處理,從而實現和 Jax 組件的兼容。

Haiku 的功能

Haiku 能夠做到很多機器學習需要完成的任務,相關功能和代碼如下:

自定義你的模塊

在 Haiku 中,類似於 TF2.0 和 PyTorch,你可以自定義模塊,作為 hk.Module 的子類。例如,自定義一個線性層:

可以看出,Haiku 的代碼和 TensorFlow 等非常相似,但是你可以看到包括 numpy 等的方法還可以定義在模塊中。Haiku 的優勢就在於,它不是一個封閉的框架,而是代碼庫,因此可以在定義模塊的過程中調用其他的庫和方法。

當定義好線性層後,我們想要試試自動微分的方法了:

這裡可以看到,定義好模塊和前向傳播的函數後,使用 hk.transform(forward_fn) 可以將這種面向對象的方法轉換成 Jax 底層的函數式代碼進行處理,因此你不需要擔心底層的計算問題。另外,這裡的代碼相比 TensorFlow 還要簡潔。

非訓練狀態

有時候,我們想要在訓練的過程中保持某些內部參數的狀態,在 Haiku 上這也是非常容易實現的。

如上所示,只需要兩行代碼進行設置。

和 jax.pmap 聯合進行分布式訓練

由於所有的代碼都會被轉換成 Jax 的函數,因此它們和 jax.pmap. 是完全兼容的。這說明,我們可以利用 jax.pmap 來進行分布式計算。

如下為進行數據分割的分布式加速代碼,首先,我們先定義模型和訓練步驟:

然後設定將參數拷貝到所有的設備上:

定義數據分批的方法,以及參數更新的方法:

最後開始分布式計算即可:

RLax:Jax 上也有強化學習庫了

除了令人印象深刻的 Haiku 外,DeepMind 還開源了 RLax——這是一個基於 Jax 的強化學習庫。

相比 Haiku,RLax 專門針對強化學習。項目作者認為,儘管強化學習中的算子和函數並不是完全的算法,但是,如果需要構建完全基於函數式的智能體,就需要特定的數學算子。

因此,函數式的 Jax 就成為了一個不錯的選擇。在 Jax 上進行一定的開發後,就可以有專用的強化學習庫了。RLax 目前的資料還較少,但項目已提供了一個示例代碼:使用 RLax 進行 Q-learning 模型的搭建和訓練。

代碼如下,首先,使用 Haiku 構建基本的強化學習模型:

設定訓練的方法:

以下和 Jax 結合,定義策略、獎勵等:

可以看到,RLax 基於 jax.jit 的方法,在性能方面有不錯的提升。更有趣的是,構建模型的過程中使用了前文提到的 Haiku,可見基於 Jax 生態的代碼庫之間都是可以兼容的。

從 DeepMind 近日開源的兩個代碼庫可以看到,雖然現在深度學習框架依然在穩步發展,但是針對高性能的科學計算也漸漸變得更為重要了。而 Jax 這樣的優秀開源項目,無疑也需要更多的生態支持。這次開源的 Haiku 和 RLax,無疑能夠鞏固 Jax 的地位,使其優秀的特性進一步得到發揮。

本文為機器之心報導,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

獲得更多的PTT最新消息
按讚加入粉絲團