クイックスタート | 変換 | インストールガイド | ニューラルネットライブラリ | ログの変更 | リファレンスドキュメント
JAXはAutogradとXLAであり、高性能の機械学習研究のために統合されています。
Autogradの更新バージョンにより、JAXはネイティブのPython関数とNumPy関数を自動的に区別できます。ループ、分岐、再帰、クロージャを介して区別でき、導関数の導関数の導関数を取ることができます。フォワードモード微分だけでなく、リバースモード微分(バックプロパゲーション)もサポートしてgrad
おり、この2つは任意の順序で任意に構成できます。
新機能は、JAXがXLAを使用
してGPUおよびTPUでNumPyプログラムをコンパイルおよび実行することです。コンパイルはデフォルトで内部で行われ、ライブラリ呼び出しはジャストインタイムでコンパイルおよび実行されます。ただし、JAXでは、1関数APIを使用して、独自のPython関数をXLA最適化カーネルにジャストインタイムでコンパイルすることもできます
jit
。コンパイルと自動微分は任意に構成できるため、Pythonを離れることなく、高度なアルゴリズムを表現して最大のパフォーマンスを得ることができます。を使用して、複数のGPUまたはTPUコアを一度にプログラムpmap
し、全体を区別することもできます。
もう少し掘り下げてみると、JAXは実際には
構成可能な関数変換のための拡張可能なシステムであることがわかります。grad
とは両方とも
jit
そのような変換のインスタンスです。その他は
vmap
、自動ベクトル化および
pmap
複数のアクセラレータの単一プログラム複数データ(SPMD)並列プログラミング用であり、今後さらに増える予定です。
これは研究プロジェクトであり、公式のGoogle製品ではありません。バグや 鋭いエッジが予想されます。試してみて、バグを報告し、あなたの考えを私たちに知らせてください!
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.tanh(outputs) # inputs to the next layer
return outputs # no activation on last layer
def loss(params, inputs, targets):
preds = predict(params, inputs)
return jnp.sum((preds - targets)**2)
grad_loss = jit(grad(loss)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
Google Cloud GPUに接続された、ブラウザでノートブックを使用してすぐにジャンプします。スターターノートブックは次のとおりです。
grad微分、
jitコンパイル、および
vmapベクトル化のためのNumPy
JAXはCloudTPUで実行されるようになりました。プレビューを試すには、CloudTPUColabsを参照してください。
JAXをさらに深く掘り下げるには、次のようにします。
ニューラルネットワークの構築
や一次確率的最適化などのミニライブラリ
jax.example_libraries
、または例を確認することもできます。stax
optimizers
JAXは、その中核として、数値関数を変換するための拡張可能なシステムです。主な関心のある4つの変換は、、、、、
gradおよびです 。
jit
vmap
pmap
grad
JAXにはAutogradとほぼ同じAPIがあります。最も一般的な関数は
grad
、リバースモードグラデーション用です。
from jax import grad
import jax.numpy as jnp
def tanh(x): # Define a function
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.4199743
で任意の順序に区別できます
grad。
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
より高度なautodiffの場合は
jax.vjp
、逆モードのJacobian製品と
jax.jvp
、順モードのJacobianベクトル製品に使用できます。この2つは、相互に、および他のJAX変換を使用して任意に構成できます。完全なヘッセ行列を効率的に計算する関数を作成するためにそれらを構成する1つの方法は次のとおりです。
from jax import jit, jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
Autogradと同様に、Pythonコントロール構造で差別化を自由に使用できます。
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
詳細については、自動微分に関するリファレンスドキュメント とJAXAutodiffクックブックを参照してください 。
jit
XLAを使用して、デコレータまたは高階関数としてjit
使用される関数をエンドツーエンドでコンパイルできます
。
@jit
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
jitと、
gradおよびその他のJAX変換を好きなように組み合わせることができます。
を使用
jitすると、関数が使用できるPython制御フローの種類に制約が課せられます。詳細については、GotchasNotebookを参照してください 。
vmap
vmap
ベクトル化マップです。配列軸に沿って関数をマッピングするというおなじみのセマンティクスがありますが、ループを外側に保持する代わりに、パフォーマンスを向上させるためにループを関数のプリミティブ操作にプッシュダウンします。
を使用
vmapすると、コードでバッチディメンションを持ち歩く必要がなくなります。たとえば、次の単純なバッチ処理されていないニューラルネットワーク予測関数について考えてみます。
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
activations = jnp.tanh(outputs) # inputs to the next layer
return outputs # no activation on last layer
jnp.dot(activations, W)代わりに、の左側のバッチディメンションを許可するように記述することがよくあります
activationsが、この特定の予測関数は、単一の入力ベクトルにのみ適用されるように記述されています。この関数を一度に入力のバッチに適用したい場合は、意味的には次のように書くことができます。
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
ただし、一度に1つの例をネットワーク経由でプッシュするのは時間がかかります。計算をベクトル化することをお勧めします。これにより、すべてのレイヤーで、行列-ベクトルの乗算ではなく、行列-行列の乗算を実行します。
vmap関数は私たちのためにその変換を行います。つまり、私たちが書く場合
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
次に、
vmap関数は関数内の外側のループをプッシュし、マシンは、手動でバッチ処理を行ったかのように、行列と行列の乗算を実行することになります。
なしで単純なニューラルネットワークを手動でバッチ処理するのは簡単です
vmapが、それ以外の場合、手動のベクトル化は非現実的または不可能な場合があります。例ごとの勾配を効率的に計算するという問題を取り上げます。つまり、パラメーターの固定セットに対して、バッチ内の各例で個別に評価された損失関数の勾配を計算する必要があります。を使用
vmapすると、簡単です。
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
もちろん、、、およびその他のJAX変換を使用
vmapして任意に構成できます。、、、およびの高速ジャコビアン行列とヘッセ行列の計算には、順方向モードと逆方向モードの両方の自動微分 を使用します。
jit
grad
vmap
jax.jacfwd
jax.jacrev
jax.hessian
pmap
複数のGPUなど、複数のアクセラレータの並列プログラミングには、を使用します
pmap
。
pmap高速並列集合通信操作を含む、単一プログラム複数データ(SPMD)プログラムを作成します。適用する
pmapとは、作成した関数がXLAによってコンパイルされ(と同様に
jit)、デバイス間で並行して複製および実行されることを意味します。
8GPUマシンでの例を次に示します。
from jax import random, pmap
import jax.numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
純粋なマップを表現することに加えて 、デバイス間で高速の集合通信操作を使用できます。
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
より洗練された通信パターンのために関数をネストpmap
することもできます。
すべてが構成されているため、並列計算によって自由に区別できます。
from jax import grad
@pmap
def f(x):
y = jnp.sin(x)
@pmap
def g(z):
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
関数を逆モードで微分する
pmap場合(たとえば、を使用
grad)、計算の逆方向パスは順方向パスと同じように並列化されます。
詳細については、SPMDクックブック とSPMDMNIST分類子のゼロからの例 を参照してください。
例と説明を含む現在の落とし穴のより徹底的な調査については、GotchasNotebookを読むことを強くお勧めします。いくつかの傑出したもの:
is
Exception: Can't lift Traced...
Exception: Different traces at same level
x[i] += yサポートされていませんが、機能的な代替手段があります。の下で
jit、これらの機能的な代替手段は、バッファをインプレースで自動的に再利用します。
jax.laxパッケージに含まれています。
float32)の値を適用します。 倍精度 (64ビットなど)を有効にするには、起動時に変数を設定する(または環境変数
float64を設定する)必要があります。TPUでは、JAXは、やなどの「matmulのような」操作で内部一時変数を除くすべてにデフォルトで32ビット値を使用します。これらの操作には、真の32ビットをシミュレートするために使用できるパラメーターがありますが、実行時間が遅くなる可能性があります。
jax_enable_x64
JAX_ENABLE_X64=True
jax.numpy.dot
lax.conv
precision
np.add(1, np.array([2], np.float32)).dtypeでは
float64なくです
float32。
jitます。何か問題が発生すると、常に大きなエラーが発生します。の パラメータ、 の ような 構造化された制御フロープリミティブを使用 するか、より小さなサブ関数で使用する必要がある場合があります。
jit
static_argnums
lax.scan
jit
jaxlibJAXは純粋なPythonで記述されていますが、パッケージとしてインストールする必要があるXLAに依存します。次の手順を使用して、を使用してバイナリパッケージをインストールする
pipか、ソースからJAXをビルドします。
jaxlibLinux(Ubuntu 16.04以降)およびmacOS(10.12以降)プラットフォームでのインストールまたはビルドをサポートしています。Windowsユーザーは、 Windows SubsystemforLinuxを介してCPUおよびGPUでJAXを使用できます。初期のネイティブWindowsサポートがいくつかありますが、まだやや未成熟であるため、バイナリリリースはなく 、ソースからビルドする必要があります。
ラップトップでローカル開発を行うのに役立つ可能性のあるCPUのみのバージョンのJAXをインストールするには、次のコマンドを実行できます。
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
Linuxでは、多くの場合、最初にホイール
pipをサポートするバージョンに更新する 必要があります。
manylinux2014
CPUとNVidiaGPUの両方をサポートするJAXをインストールする場合、 CUDAと CuDNNがまだインストールされていない場合は、最初にインストールする必要があります。他の一般的なディープラーニングシステムとは異なり、JAXは
pipパッケージの一部としてCUDAまたはCuDNNをバンドルしていません。
JAXは、Linux専用のビルド済みのCUDA互換ホイール、CUDA 11.1以降、およびCuDNN8.0.5以降を提供します。オペレーティングシステム、CUDA、およびCuDNNの他の組み合わせも可能ですが、ソースから構築する必要があります。
次に、実行します
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jaxlibのバージョンは、使用する既存のCUDAインストールのバージョンに対応している必要があります。jaxlibの特定のCUDAおよびCuDNNバージョンを明示的に指定できます。
pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
次のコマンドでCUDAバージョンを見つけることができます:
nvcc --version
一部のGPU機能は、CUDAのインストールがであると想定しています。XX
/usr/local/cuda-X.XはCUDAのバージョン番号(例
cuda-11.1)に置き換える必要があります。CUDAがシステムの他の場所にインストールされている場合は、シンボリックリンクを作成できます。
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
ビルド済みのホイールでエラーや問題が発生した場合は、課題追跡システムでお知らせください。
JAXは、 GoogleCloudTPU用のビルド済みホイールも提供します 。JAXを適切なバージョンの
jaxlibおよびと一緒にインストールする
libtpuには、クラウドTPUVMで次を実行できます。
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Colab TPUランタイムにはJAXがプリインストールされていますが、JAXをインポートする前に、次のコードを実行してTPUを初期化する必要があります。
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
ColabTPUランタイムはCloudTPUVMよりも古いTPUアーキテクチャを使用するため、
jax[tpu]Colabへのインストールは避ける必要があります。何らかの理由でColabTPUランタイムでjaxおよびjaxlibライブラリを更新する場合は、上記のCPUの指示に従ってください(つまり、インストールし
jax[cpu]ます)。
ソースからのJAXの構築を参照してください。
複数のGoogle研究グループが、JAXでニューラルネットワークをトレーニングするためのライブラリを開発および共有しています。例とハウツーガイドを備えたニューラルネットワークトレーニング用のフル機能のライブラリが必要な場合は、 Flaxを試してください。
さらに、DeepMindは、ニューラルネットワークモジュール用のHaiku、勾配処理と最適化用のOptax、RLアルゴリズム用の RLax、信頼性の高いコードとテスト用のchex など、JAX周辺のライブラリのエコシステムをオープンソース化しました。(ここでDeepMindのNeurIPS 2020 JAXエコシステムの講演を ご覧ください)
このリポジトリを引用するには:
@software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, url = {http://github.com/google/jax}, version = {0.3.13}, year = {2018}, }
上記のbibtexエントリでは、名前はアルファベット順で、バージョン番号はjax / version.pyからのものであり、年はプロジェクトのオープンソースリリースに対応しています。
自動微分とXLAへのコンパイルのみをサポートするJAXの初期バージョンは、SysML2018に掲載された論文で説明されています。現在、JAXのアイデアと機能をより包括的で最新の論文でカバーするよう取り組んでいます。
JAX APIの詳細については、 リファレンスドキュメントを参照してください。
JAX開発者として開始するには、 開発者向けドキュメントを参照してください。