jax - Python + NumPyプログラムの構成可能な変換:差別化、ベクトル化、JITからGPU/TPUなど

(Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more)

Created at: 2018-10-26 05:25:02
Language: Python
License: NOASSERTION
ロゴ

JAX:AutogradおよびXLA

継続的インテグレーション PyPIバージョン

クイックスタート | 変換 | インストールガイド | ニューラルネットライブラリ | ログの変更 | リファレンスドキュメント

JAXとは何ですか?

JAXはAutogradXLAであり、高性能の機械学習研究のために統合されています。

Autogradの更新バージョンにより、JAXはネイティブのPython関数とNumPy関数を自動的に区別できます。ループ、分岐、再帰、クロージャを介して区別でき、導関数の導関数の導関数を取ることができます。フォワードモード微分だけでなく、リバースモード微分(バックプロパゲーション)もサポートして

grad
おり、この2つは任意の順序で任意に構成できます。

新機能は、JAXがXLAを使用 してGPUおよびTPUでNumPyプログラムをコンパイルおよび実行することです。コンパイルはデフォルトで内部で行われ、ライブラリ呼び出しはジャストインタイムでコンパイルおよび実行されます。ただし、JAXでは、1関数APIを使用して、独自のPython関数をXLA最適化カーネルにジャストインタイムでコンパイルすることもできます

jit
。コンパイルと自動微分は任意に構成できるため、Pythonを離れることなく、高度なアルゴリズムを表現して最大のパフォーマンスを得ることができます。を使用して、複数のGPUまたはTPUコアを一度にプログラム
pmap
し、全体を区別することもできます。

もう少し掘り下げてみると、JAXは実際には 構成可能な関数変換のための拡張可能なシステムであることがわかります。

grad
とは両方とも
jit
そのような変換のインスタンスです。その他は
vmap
、自動ベクトル化および
pmap
複数のアクセラレータの単一プログラム複数データ(SPMD)並列プログラミング用であり、今後さらに増える予定です。

これは研究プロジェクトであり、公式のGoogle製品ではありません。バグや 鋭いエッジが予想されます。試してみて、バグを報告し、あなたの考えを私たちに知らせてください!

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads

コンテンツ

クイックスタート:クラウドでのコラボ

Google Cloud GPUに接続された、ブラウザでノートブックを使用してすぐにジャンプします。スターターノートブックは次のとおりです。

JAXはCloudTPUで実行されるようになりました。プレビューを試すには、CloudTPUColabsを参照してください。

JAXをさらに深く掘り下げるには、次のようにします。

ニューラルネットワークの構築一次確率的最適化などのミニライブラリ

jax.example_libraries
、またはを確認することもできます。
stax
optimizers

変換

JAXは、その中核として、数値関数を変換するための拡張可能なシステムです。主な関心のある4つの変換は、、、、、

grad
およびです 。
jit
vmap
pmap

自動微分
grad

JAXにはAutogradとほぼ同じAPIがあります。最も一般的な関数は

grad
、リバースモードグラデーション用です。

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

で任意の順序に区別できます

grad

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673

より高度なautodiffの場合は

jax.vjp
、逆モードのJacobian製品と
jax.jvp
、順モードのJacobianベクトル製品に使用できます。この2つは、相互に、および他のJAX変換を使用して任意に構成できます。完全なヘッセ行列を効率的に計算する関数を作成するためにそれらを構成する1つの方法は次のとおりです。

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

Autogradと同様に、Pythonコントロール構造で差別化を自由に使用できます。

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

詳細については、自動微分に関するリファレンスドキュメントJAXAutodiffクックブックを参照してください 。

とのコンパイル
jit

XLAを使用して、デコレータまたは高階関数として

jit
使用される関数をエンドツーエンドでコンパイルできます 。
@jit

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

jit
と、
grad
およびその他のJAX変換を好きなように組み合わせることができます。

を使用

jit
すると、関数が使用できるPython制御フローの種類に制約が課せられます。詳細については、GotchasNotebookを参照してください 。

による自動ベクトル化
vmap

vmap
ベクトル化マップです。配列軸に沿って関数をマッピングするというおなじみのセマンティクスがありますが、ループを外側に保持する代わりに、パフォーマンスを向上させるためにループを関数のプリミティブ操作にプッシュダウンします。

を使用

vmap
すると、コードでバッチディメンションを持ち歩く必要がなくなります。たとえば、次の単純なバッチ処理されていないニューラルネットワーク予測関数について考えてみます

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer

jnp.dot(activations, W)
代わりに、の左側のバッチディメンションを許可するように記述することがよくあります
activations
が、この特定の予測関数は、単一の入力ベクトルにのみ適用されるように記述されています。この関数を一度に入力のバッチに適用したい場合は、意味的には次のように書くことができます。

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

ただし、一度に1つの例をネットワーク経由でプッシュするのは時間がかかります。計算をベクトル化することをお勧めします。これにより、すべてのレイヤーで、行列-ベクトルの乗算ではなく、行列-行列の乗算を実行します。

vmap
関数は私たちのためにその変換を行います。つまり、私たちが書く場合

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

次に、

vmap
関数は関数内の外側のループをプッシュし、マシンは、手動でバッチ処理を行ったかのように、行列と行列の乗算を実行することになります。

なしで単純なニューラルネットワークを手動でバッチ処理するのは簡単です

vmap
が、それ以外の場合、手動のベクトル化は非現実的または不可能な場合があります。例ごとの勾配を効率的に計算するという問題を取り上げます。つまり、パラメーターの固定セットに対して、バッチ内の各例で個別に評価された損失関数の勾配を計算する必要があります。を使用
vmap
すると、簡単です。

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

もちろん、、、およびその他のJAX変換を使用

vmap
して任意に構成できます。、、、およびの高速ジャコビアン行列とヘッセ行列の計算には、順方向モードと逆方向モードの両方の自動微分 を使用します。
jit
grad
vmap
jax.jacfwd
jax.jacrev
jax.hessian

SPMDプログラミング
pmap

複数のGPUなど、複数のアクセラレータの並列プログラミングには、を使用します

pmap
pmap
高速並列集合通信操作を含む、単一プログラム複数データ(SPMD)プログラムを作成します。適用する
pmap
とは、作成した関数がXLAによってコンパイルされ(と同様に
jit
)、デバイス間で並行して複製および実行されることを意味します。

8GPUマシンでの例を次に示します。

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

純粋なマップを表現することに加えて 、デバイス間で高速の集合通信操作を使用できます。

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

より洗練された通信パターンのために関数をネスト

pmap
することもできます。

すべてが構成されているため、並列計算によって自由に区別できます。

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

関数を逆モードで微分する

pmap
場合(たとえば、を使用
grad
)、計算の逆方向パスは順方向パスと同じように並列化されます。

詳細については、SPMDクックブックSPMDMNIST分類子のゼロからの例 を参照してください。

現在の落とし穴

例と説明を含む現在の落とし穴のより徹底的な調査については、GotchasNotebookを読むことを強くお勧めします。いくつかの傑出したもの:

  1. JAX変換は、副作用がなく、参照透過性を尊重する純粋関数でのみ機能します(つまり、を使用したオブジェクトIDテストは保持されません)。不純なPython関数でJAX変換を使用すると、 またはのようなエラーが表示される場合があります。
    is
    Exception: Can't lift Traced...
    Exception: Different traces at same level
  2. のような配列のインプレースミューテーション更新は
    x[i] += y
    サポートされていませんが、機能的な代替手段があります。の下で
    jit
    、これらの機能的な代替手段は、バッファをインプレースで自動的に再利用します。
  3. 乱数は異なりますが、正当な理由があります。
  4. 畳み込み演算子を探している場合は、
    jax.lax
    パッケージに含まれています。
  5. JAXは、デフォルトで単精度(32ビットなど
    float32
    )の値を適用します。 倍精度 (64ビットなど)を有効にするには、起動時に変数を設定する(または環境変数
    float64
    を設定する)必要があります。TPUでは、JAXは、やなどの「matmulのような」操作で内部一時変数を除くすべてにデフォルトで32ビット値を使用します。これらの操作には、真の32ビットをシミュレートするために使用できるパラメーターがありますが、実行時間が遅くなる可能性があります。
    jax_enable_x64
    JAX_ENABLE_X64=True
    jax.numpy.dot
    lax.conv
    precision
  6. PythonスカラーとNumPyタイプの組み合わせを含むNumPyのdtypeプロモーションセマンティクスの一部は保持されません。つまり、
    np.add(1, np.array([2], np.float32)).dtype
    では
    float64
    なくです
    float32
  7. のようないくつかの変換は、Python制御フローの使用方法を制約し
    jit
    ます。何か問題が発生すると、常に大きなエラーが発生します。 パラメータ、 の ような 構造化された制御フロープリミティブを使用 するか、より小さなサブ関数で使用する必要がある場合があります。
    jit
    static_argnums
    lax.scan
    jit

インストール

jaxlib
JAXは純粋なPythonで記述されていますが、パッケージとしてインストールする必要があるXLAに依存します。次の手順を使用して、を使用してバイナリパッケージをインストールする
pip
か、ソースからJAXをビルドします。

jaxlib
Linux(Ubuntu 16.04以降)およびmacOS(10.12以降)プラットフォームでのインストールまたはビルドをサポートしています。Windowsユーザーは、 Windows SubsystemforLinuxを介してCPUおよびGPUでJAXを使用できます。初期のネイティブWindowsサポートがいくつかありますが、まだやや未成熟であるため、バイナリリリースはなく 、ソースからビルドする必要があります。

pipのインストール:CPU

ラップトップでローカル開発を行うのに役立つ可能性のあるCPUのみのバージョンのJAXをインストールするには、次のコマンドを実行できます。

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Linuxでは、多くの場合、最初にホイール

pip
をサポートするバージョンに更新する 必要があります。
manylinux2014

pipのインストール:GPU(CUDA)

CPUとNVidiaGPUの両方をサポートするJAXをインストールする場合、 CUDACuDNNがまだインストールされていない場合は、最初にインストールする必要があります。他の一般的なディープラーニングシステムとは異なり、JAXは

pip
パッケージの一部としてCUDAまたはCuDNNをバンドルしていません。

JAXは、Linux専用のビルド済みのCUDA互換ホイール、CUDA 11.1以降、およびCuDNN8.0.5以降を提供します。オペレーティングシステム、CUDA、およびCuDNNの他の組み合わせも可能ですが、ソースから構築する必要があります。

  • CUDA11.1以降が必要です
    • ソースからビルドする場合は古いバージョンのCUDAを使用できる場合がありますが、11.1より古いすべてのCUDAバージョンのCUDAには既知のバグがあるため、古いバージョンのCUDA用にビルド済みのバイナリは出荷されません。
  • ビルド済みホイールでサポートされているcuDNNバージョンは次のとおりです。
    • cuDNN8.2以降。cuDNNのインストールが十分に新しい場合は、追加機能をサポートしているため、cuDNN8.2ホイールを使用することをお勧めします。
    • cuDNN8.0.5以降。
  • 少なくとも CUDAツールキットの対応するドライバーバージョンと同じくらい新しいNVidiaドライバーバージョンを使用する必要があります。たとえば、CUDA 11.4 update 4がインストールされている場合、Linuxの場合はNVidiaドライバー470.82.01以降を使用する必要があります。JAXはJITコンパイルコードに依存しているため、これは厳格な要件です。古いドライバーは障害につながる可能性があります。

次に、実行します

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

jaxlibのバージョンは、使用する既存のCUDAインストールのバージョンに対応している必要があります。jaxlibの特定のCUDAおよびCuDNNバージョンを明示的に指定できます。

pip install --upgrade pip

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

次のコマンドでCUDAバージョンを見つけることができます:

nvcc --version

一部のGPU機能は、CUDAのインストールがであると想定しています。XX

/usr/local/cuda-X.X
はCUDAのバージョン番号(例
cuda-11.1
)に置き換える必要があります。CUDAがシステムの他の場所にインストールされている場合は、シンボリックリンクを作成できます。

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

ビルド済みのホイールでエラーや問題が発生した場合は、課題追跡システムでお知らせください。

pipのインストール:Google Cloud TPU

JAXは、 GoogleCloudTPU用のビルド済みホイールも提供します 。JAXを適切なバージョンの

jaxlib
およびと一緒にインストールする
libtpu
には、クラウドTPUVMで次を実行できます。

pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

pipのインストール:Colab TPU

Colab TPUランタイムにはJAXがプリインストールされていますが、JAXをインポートする前に、次のコードを実行してTPUを初期化する必要があります。

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

ColabTPUランタイムはCloudTPUVMよりも古いTPUアーキテクチャを使用するため、

jax[tpu]
Colabへのインストールは避ける必要があります。何らかの理由でColabTPUランタイムでjaxおよびjaxlibライブラリを更新する場合は、上記のCPUの指示に従ってください(つまり、インストールし
jax[cpu]
ます)。

ソースからJAXを構築する

ソースからのJAXの構築を参照してください。

ニューラルネットワークライブラリ

複数のGoogle研究グループが、JAXでニューラルネットワークをトレーニングするためのライブラリを開発および共有しています。例とハウツーガイドを備えたニューラルネットワークトレーニング用のフル機能のライブラリが必要な場合は、 Flaxを試してください。

さらに、DeepMindは、ニューラルネットワークモジュール用のHaiku、勾配処理と最適化用のOptax、RLアルゴリズム用の RLax、信頼性の高いコードとテスト用のchex など、JAX周辺のライブラリのエコシステムをオープンソース化しました。(ここでDeepMindのNeurIPS 2020 JAXエコシステムの講演を ご覧ください)

JAXを引用する

このリポジトリを引用するには:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.3.13},
  year = {2018},
}

上記のbibtexエントリでは、名前はアルファベット順で、バージョン番号はjax / version.pyからのものであり、年はプロジェクトのオープンソースリリースに対応しています。

自動微分とXLAへのコンパイルのみをサポートするJAXの初期バージョンは、SysML2018に掲載された論文で説明されています。現在、JAXのアイデアと機能をより包括的で最新の論文でカバーするよう取り組んでいます。

リファレンスドキュメント

JAX APIの詳細については、 リファレンスドキュメントを参照してください。

JAX開発者として開始するには、 開発者向けドキュメントを参照してください。