mesh-transformer-jax - JAXとHaikuの並列変圧器のモデル化

(Model parallel transformers in JAX and Haiku)

Created at: 2021-03-14 07:31:13
Language: Python
License: Apache-2.0

目次

  1. メッシュトランス JAX
    1. 最新情報
  2. 事前トレーニング済みモデル
    1. GPT-J-6B
      1. リンクス
      2. 確認
      3. ライセンス
      4. モデルの詳細
      5. ゼロショット評価
  3. アーキテクチャと使用法
    1. 微調整
    2. JAX の依存関係
  4. 藤堂

メッシュトランス JAX

トランスフォーマーのモデル並列処理にJAXの/演算子を使用した俳句ライブラリ。

xmap
pjit

並列処理スキームは元のメガトロン-LMに似ており、効率的です 高速2DメッシュネットワークによるTPUで。ZeRoスタイルを実装した実験モデルバージョンもあります シャーディング

このライブラリは、TPUv3で最大約40Bのパラメータを拡張できるように設計されています。 並列処理戦略を使用する必要があります。それについては、GPT-NeoXDeepSpeedなどの他の実装を参照してください。

研究の将来の方向性の1つは、このコードベースをswarm-jaxと統合して、パイプラインの並列処理でさらなるスケーラビリティを実現することです。

最新情報

12-07-21: 微調整のガイドを追加

事前トレーニング済みモデル

GPT-J-6B

The Pile でトレーニングされた 60 億個のパラメーター、自己回帰テキスト生成モデル。

リンクス

スリムな重み(bf16の重みのみ、推論用、9GB)

フルウェイト(オプティマイザパラメータを含む、61GB)

コラボデモ

ウェブデモ

アランのブログ投稿

確認

このプロジェクトは、EleutherAIの支援を受けてTPUリサーチクラウドが寛大に提供してくれたコンピューティングなしでは不可能でした。

Cloud TPU VM アルファ版への早期アクセスを提供してくれた Google の Cloud TPU チームに感謝します。 (公開中!)

何らかの方法で助けてくれたすべての人に感謝します(アルファベット順に記載されています):

ライセンス

GPT-J-6Bの重量は、Apacheライセンスのバージョン2.0の下でライセンスされています。

モデルの詳細

ハイパーパラメータ 価値
n_parameters 6,053,381,344
n_layers 28*
d_model 4,096
d_ff 16,384
n_heads 16
d_head 256
n_ctx 2,048
n_vocab 50,257 (GPT-2/3と同じトークナイザー)
位置エンコーディング ロータリー位置エンコーディング (RoPE)
RoPE 寸法 64

*
各層は、1つのフィードフォワードブロックと1つの自己注意ブロックで構成されます

モデルは 28 個のレイヤーで構成され、モデル寸法は 4096、フィードフォワード寸法は 16384 です。モデル ディメンションは 16 個のヘッドに分割され、各ヘッドのディメンションは 256 です。ロータリーポジションエンコーディング(RoPE)が64に適用されました 各ヘッドの寸法。モデルは、トークン化ボキャブラリ 50257 でトレーニングされ、同じ BPE のセットを使用します。 GPT-2/GPT-3.

ゼロショット評価

モデルは、パフォーマンスによって大まかにソートされ、使用できない場合はFLOPによってソートされます。

モデル 重み トレーニング FLOP ランバダPPL ↓ ランバダアック↑ ウィノグランデ ↑ ヘラスワグ↑ ピカ ↑ データセットのサイズ (GB)
チャンス 0 ~たくさん ~0% 50% 25% 25% 0
GPT-3-アダ‡ ----- 9.95 51.6% 52.9% 43.4% 70.5% -----
GPT-2-1.5B ----- 10.63 51.21% 59.4% 50.9% 70.8% 40
GPTNeo-1.3B‡ 3.0e21 7.50 57.2% 55.0% 48.9% 71.1% 825
メガトロン-2.5B* 2.4e21 ----- 61.7% ----- ----- ----- 174
GPTNeo-2.7B‡ 6.8e21 5.63 62.2% 56.5% 55.8% 73.0% 825
GPT-3-1.3B*‡ 2.4e21 5.44 63.6% 58.7% 54.7% 75.1% ~800
GPT-3-バベッジ‡ ----- 5.58 62.4% 59.0% 54.5% 75.5% -----
メガトロン-8.3B* 7.8e21 ----- 66.5% ----- ----- ----- 174
GPT-3-2.7B*‡ 4.8e21 4.60 67.1% 62.3% 62.8% 75.6% ~800
メガトロン-11B† 1.0e22 ----- ----- ----- ----- ----- 161
GPT-J-6B 1.5e22 3.99 69.7% 65.3% 66.1% 76.5% 825
GPT-3-6.7B*‡ 1.2e22 4.00 70.3% 64.5% 67.4% 78.0% ~800
GPT-3-キュリー‡ ----- 4.00 69.3% 65.6% 68.5% 77.9% -----
GPT-3-13B*‡ 2.3e22 3.56 72.5% 67.9% 70.9% 78.5% ~800
GPT-3-175B*‡ 3.1e23 3.00 76.2% 70.2% 78.9% 81.0% ~800
GPT-3-ダヴィンチ‡ ----- 3.0 75% 72% 78% 80% -----
ホリネズミ 230B* 6.31E+23 ----- 74.50% 70.10% 79.20% 81.80% 1344
MT-NLG 530B*‡ ----- ----- 76.6% 73.0% 80.2% 82.0% -----

*
それぞれの著者によって報告された評価番号を表し、他のすべての番号は リリースされた状態でLM評価ハーネスを実行する 重み付けまたは API アクセスを使用します。微妙な実装の違いと異なるゼロショットタスクフレーミングにより、これらは 直接比較できない場合があります。詳細については、このブログ投稿を参照してください 細部。

メガトロン-11Bモデルは比較可能なメトリックを提供しておらず、リリースされた重みを使用するいくつかの実装は提供しません 生成品質と評価を再現。(1 2 3を参照) したがって、評価は試みられなかった。

これらのモデルは、テストセットの汚染の可能性を含むデータでトレーニングされています。OpenAI GPT-3 モデル 特定のテストセットのトレーニングデータの重複排除に失敗しましたが、GPT-Neoモデルとこのモデルは どのテストセットに対しても重複排除されていないThe Pileでトレーニングされています。

アーキテクチャと使用法

このリポジトリ内のほとんどのスクリプトは、TPU-VM アーキテクチャでは仮想マシンである TPU で実行するように設計されています。 任意のコードを実行できます。ほとんどのスクリプトは、TPUをスピンアップし、SSHをスピンアップして依存関係を設定するように設計されています ローカルディレクトリからコードをコピーし、Rayワーカーを起動します RPC 呼び出しを受け入れることができます。

TPUVMは、実行中のモデルのトレーニングステップと評価、チェックポイントの保存と読み込みを処理しますが、ドライバーはpython プログラムは、データの読み込みと一般的なオーケストレーション(チェックポイントを保存するタイミングなど)を処理します。

つまり、ほとんどのスクリプト(など)は、 RPC 待機時間とデータ転送コストを最小限に抑えるために、TPU と同じリージョン。その他のスクリプト (通常、 や などの引数を取らないもの) TPUVMで直接実行されることを期待してください。device_* スクリプトは v3-8 でのみ機能し、より大きなポッドでは機能しません。

train.py
eval_harness.py
--tpu
device_sample.py
device_serve.py
device_train.py

さらに、提供されたチェックポイントを変換する方法の例()があります(8 シャード (GPT-J-6B) の場合) を、GPU で実行している場合など、より小さな数に減らします。

resharding_example.py

微調整

モデルを微調整するには、TPU VM で実行します。TPU v3-8を使用すると、~5000 のレートで微調整できます。 トークン/秒、これは小規模から中規模のデータセットには十分なはずです。

device_train.py

徹底的な微調整手順については、ステップバイステップガイドをお読みください。

JAX の依存関係

このライブラリには、JAXバージョンに固有の要件があることに注意してください。具体的には、v1 モデル ( GPT-J 6B)が必要です。これは、 に依存します。これが行われない場合、あなたは得るでしょう 不可解なXMAPエラー

jax==0.2.12
jaxlib==0.1.68

ただし、v2モデルコード(一般に公開されている重みなし)を使用するには、最新のJAXバージョンを使用できます。

引用

このリポジトリを引用するには:

@misc{mesh-transformer-jax,
  author = {Wang, Ben},
  title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

GPT-J-6Bの重量を引用するには:

@misc{gpt-j,
  author = {Wang, Ben and Komatsuzaki, Aran},
  title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

このリポジトリまたは事前トレーニング済みの重みのいずれかを使用して何かクールなことを行う場合は、ぜひお聞かせください。 githubの問題を開くか、電子メール(プロファイル内)で連絡してください。

藤堂

  • [x]頭と破片を解きほぐす
  • [x] TPUのテスト/ベンチマーク
  • [x]グラデーションチェックポイントを実装する
  • [x]初期化を修正
  • [x]混合精度
  • [x] プリエンプティブルTPUを扱う
  • [x] 生成のテストと検証
  • [x]メモリ効率のために複製する代わりにシャードのアクティブ化を行う(v2の場合)
  • [x] ZeROスタイルのシャーディングをサポートします(v2で)