flash-attention - このリポジトリは、次のペーパーからFlashAttentionの公式実装を提供します。

(NULL)

Created at: 2022-05-20 05:22:06
Language: C++
License: Apache-2.0

FlashAttention

このリポジトリは、次のペーパーからFlashAttentionの公式実装を提供します。

FlashAttention:IO認識による高速でメモリ効率の高い正確な注意
Tri Dao、Daniel Y. Fu、Stefano Ermon、Atri Rudra、ChristopherRéPaper
https ://arxiv.org/abs/2205.14135 FlashAttention

アルファリリース(0.1)。

コンパイルするには(CUDA 11、NVCC、およびAmpere GPUが必要):

python setup.py install

インターフェース:

src/flash_attention.py

PyTorchの標準的な注意に対してベンチマークを実行するには:

PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py

FlashAttentionは現在以下をサポートしています。

  1. アンペアGPU(A100、RTX 3090など)。
  2. fp16。
  3. ヘッド寸法16、32、64。

暫定ロードマップ:

  1. [2022年6月]パッケージをpip-installableにします。
  2. [2022年6月]SM86GPUをサポート(例:RTX 3080、3090)[終わり]。
  3. [2022年6月]Cutlassを使用するようにリファクタリングします。
  4. [2022年6月]SM75GPU(T4など)をサポートします。
  5. [2022年6月]bf16をサポートします。
  6. [2022年7月]サポートヘッドの寸法128。
  7. [2022年7月]SM70GPU(V100)をサポートします。
  8. [2022年8月]ヒューズロータリー埋め込み。
  9. [2022年8月]アテンションリニアバイアス(例:ALiBi)をサポートします。

スピードアップとメモリ節約

シーケンスの長さに応じて、さまざまなGPUでPyTorchの標準的な注意に対してFlashAttentionを使用することで期待されるスピードアップ(フォワードパスとバックワードパスの組み合わせ)とメモリの節約を示します(スピードアップはメモリ帯域幅に依存します-低速のGPUメモリでよりスピードアップが見られます)。

A100

これらのパラメーターを使用してFlashAttentionの高速化を表示します(BERTベースと同様)。

  • バッチサイズ8
  • ヘッド寸法64
  • 12の注意ヘッド

グラフは、128〜4096のシーケンス長を示しています(A100で標準のアテンションがメモリを使い果たした場合)が、FlashAttentionはシーケンス長64Kまでスケールアップできます。

スピードアップ

FlashAttentionの高速化

通常、128〜4Kのシーケンス長で2〜4倍のスピードアップが見られ、カーネルを融合するため、ドロップアウトとマスキングを使用するとさらにスピードアップが見られます。512や1Kなどの言語モデルで一般的なシーケンス長では、ドロップアウトとマスキングを使用すると最大4倍のスピードアップが見られます。

メモリー

FlashAttentionメモリ

このグラフにメモリの節約を示します(ドロップアウトまたはマスキングを使用するかどうかに関係なく、メモリフットプリントは同じであることに注意してください)。メモリの節約はシーケンスの長さに比例します。標準のアテンションのメモリはシーケンスの長さが2次であるのに対し、FlashAttentionのメモリはシーケンスの長さが線形であるためです。シーケンス長2Kで10倍、4Kで20倍のメモリ節約が見られます。その結果、FlashAttentionははるかに長いシーケンス長にスケーリングできます。

RTX 3090

RTX 3090には、12個のアテンションヘッドを備えたバッチサイズ12を使用します。メモリの節約はA100と同じなので、ここではスピードアップのみを示します。

FlashAttentionスピードアップGTX3090

GDDR6Xのメモリ帯域幅はA100HBMよりも低いため(〜900 GB / s対〜1.5 TB / s)、GTX 3090ではわずかに高速化(2.5〜4.5x)が見られます。

謝辞

私たちの実装では、開始点としてApexの FMHAコードを使用しています。

FMHAの実装について詳細に説明してくれたKoYoung -Junと、CUDAに関する質問に対する思慮深い回答に感謝します。

引用

このコードベースを使用する場合、またはその他の方法で私たちの作業が価値があると思われる場合は、以下を引用してください。

@article{dao2022flashattention,
  title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  journal={arXiv preprint arXiv:2205.14135},
  year={2022}
}