トランスフォーマーのモデル並列処理にJAXの/演算子を使用した俳句ライブラリ。
xmap
pjit
並列処理スキームは元のメガトロン-LMに似ており、効率的です 高速2DメッシュネットワークによるTPUで。ZeRoスタイルを実装した実験モデルバージョンもあります シャーディング。
このライブラリは、TPUv3で最大約40Bのパラメータを拡張できるように設計されています。 並列処理戦略を使用する必要があります。それについては、GPT-NeoXやDeepSpeedなどの他の実装を参照してください。
研究の将来の方向性の1つは、このコードベースをswarm-jaxと統合して、パイプラインの並列処理でさらなるスケーラビリティを実現することです。
12-07-21: 微調整のガイドを追加
The Pile でトレーニングされた 60 億個のパラメーター、自己回帰テキスト生成モデル。
このプロジェクトは、EleutherAIの支援を受けてTPUリサーチクラウドが寛大に提供してくれたコンピューティングなしでは不可能でした。
Cloud TPU VM アルファ版への早期アクセスを提供してくれた Google の Cloud TPU チームに感謝します。 (公開中!)
何らかの方法で助けてくれたすべての人に感謝します(アルファベット順に記載されています):
GPT-J-6Bの重量は、Apacheライセンスのバージョン2.0の下でライセンスされています。
ハイパーパラメータ | 価値 |
---|---|
n_parameters | 6,053,381,344 |
n_layers | 28* |
d_model | 4,096 |
d_ff | 16,384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2,048 |
n_vocab | 50,257 (GPT-2/3と同じトークナイザー) |
位置エンコーディング | ロータリー位置エンコーディング (RoPE) |
RoPE 寸法 | 64 |
*各層は、1つのフィードフォワードブロックと1つの自己注意ブロックで構成されます
モデルは 28 個のレイヤーで構成され、モデル寸法は 4096、フィードフォワード寸法は 16384 です。モデル ディメンションは 16 個のヘッドに分割され、各ヘッドのディメンションは 256 です。ロータリーポジションエンコーディング(RoPE)が64に適用されました 各ヘッドの寸法。モデルは、トークン化ボキャブラリ 50257 でトレーニングされ、同じ BPE のセットを使用します。 GPT-2/GPT-3.
モデルは、パフォーマンスによって大まかにソートされ、使用できない場合はFLOPによってソートされます。
モデル | 重み | トレーニング FLOP | ランバダPPL ↓ | ランバダアック↑ | ウィノグランデ ↑ | ヘラスワグ↑ | ピカ ↑ | データセットのサイズ (GB) |
---|---|---|---|---|---|---|---|---|
チャンス | 0 | ~たくさん | ~0% | 50% | 25% | 25% | 0 | |
GPT-3-アダ‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 | |
GPTNeo-1.3B‡ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 | |
メガトロン-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 | |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-バベッジ‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
メガトロン-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
メガトロン-11B† | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 | |
GPT-J-6B‡ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 | |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-キュリー‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-ダヴィンチ‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
ホリネズミ 230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
*それぞれの著者によって報告された評価番号を表し、他のすべての番号は リリースされた状態でLM評価ハーネスを実行する 重み付けまたは API アクセスを使用します。微妙な実装の違いと異なるゼロショットタスクフレーミングにより、これらは 直接比較できない場合があります。詳細については、このブログ投稿を参照してください 細部。
†メガトロン-11Bモデルは比較可能なメトリックを提供しておらず、リリースされた重みを使用するいくつかの実装は提供しません 生成品質と評価を再現。(1 2 3を参照) したがって、評価は試みられなかった。
‡これらのモデルは、テストセットの汚染の可能性を含むデータでトレーニングされています。OpenAI GPT-3 モデル 特定のテストセットのトレーニングデータの重複排除に失敗しましたが、GPT-Neoモデルとこのモデルは どのテストセットに対しても重複排除されていないThe Pileでトレーニングされています。
このリポジトリ内のほとんどのスクリプトは、TPU-VM アーキテクチャでは仮想マシンである TPU で実行するように設計されています。 任意のコードを実行できます。ほとんどのスクリプトは、TPUをスピンアップし、SSHをスピンアップして依存関係を設定するように設計されています ローカルディレクトリからコードをコピーし、Rayワーカーを起動します RPC 呼び出しを受け入れることができます。
TPUVMは、実行中のモデルのトレーニングステップと評価、チェックポイントの保存と読み込みを処理しますが、ドライバーはpython プログラムは、データの読み込みと一般的なオーケストレーション(チェックポイントを保存するタイミングなど)を処理します。
つまり、ほとんどのスクリプト(など)は、 RPC 待機時間とデータ転送コストを最小限に抑えるために、TPU と同じリージョン。その他のスクリプト (通常、 や などの引数を取らないもの) TPUVMで直接実行されることを期待してください。device_* スクリプトは v3-8 でのみ機能し、より大きなポッドでは機能しません。
train.py
eval_harness.py
--tpu
device_sample.py
device_serve.py
device_train.py
さらに、提供されたチェックポイントを変換する方法の例()があります(8 シャード (GPT-J-6B) の場合) を、GPU で実行している場合など、より小さな数に減らします。
resharding_example.py
モデルを微調整するには、TPU VM で実行します。TPU v3-8を使用すると、~5000 のレートで微調整できます。 トークン/秒、これは小規模から中規模のデータセットには十分なはずです。
device_train.py
徹底的な微調整手順については、ステップバイステップガイドをお読みください。
このライブラリには、JAXバージョンに固有の要件があることに注意してください。具体的には、v1 モデル ( GPT-J 6B)が必要です。これは、 に依存します。これが行われない場合、あなたは得るでしょう 不可解なXMAPエラー
jax==0.2.12
jaxlib==0.1.68
ただし、v2モデルコード(一般に公開されている重みなし)を使用するには、最新のJAXバージョンを使用できます。
このリポジトリを引用するには:
@misc{mesh-transformer-jax, author = {Wang, Ben}, title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}}, howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, year = 2021, month = May }
GPT-J-6Bの重量を引用するには:
@misc{gpt-j, author = {Wang, Ben and Komatsuzaki, Aran}, title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}}, howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}}, year = 2021, month = May }
このリポジトリまたは事前トレーニング済みの重みのいずれかを使用して何かクールなことを行う場合は、ぜひお聞かせください。 githubの問題を開くか、電子メール(プロファイル内)で連絡してください。