このリポジトリは、次のペーパーからFlashAttentionの公式実装を提供します。
FlashAttention:IO認識による高速でメモリ効率の高い正確な注意
Tri Dao、Daniel Y. Fu、Stefano Ermon、Atri Rudra、ChristopherRéPaper
:https ://arxiv.org/abs/2205.14135
コンパイルするには(CUDA 11、NVCC、およびAmpere GPUが必要):
python setup.py install
インターフェース:
src/flash_attention.py
PyTorchの標準的な注意に対してベンチマークを実行するには:
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
FlashAttentionは現在以下をサポートしています。
暫定ロードマップ:
シーケンスの長さに応じて、さまざまなGPUでPyTorchの標準的な注意に対してFlashAttentionを使用することで期待されるスピードアップ(フォワードパスとバックワードパスの組み合わせ)とメモリの節約を示します(スピードアップはメモリ帯域幅に依存します-低速のGPUメモリでよりスピードアップが見られます)。
これらのパラメーターを使用してFlashAttentionの高速化を表示します(BERTベースと同様)。
グラフは、128〜4096のシーケンス長を示しています(A100で標準のアテンションがメモリを使い果たした場合)が、FlashAttentionはシーケンス長64Kまでスケールアップできます。
通常、128〜4Kのシーケンス長で2〜4倍のスピードアップが見られ、カーネルを融合するため、ドロップアウトとマスキングを使用するとさらにスピードアップが見られます。512や1Kなどの言語モデルで一般的なシーケンス長では、ドロップアウトとマスキングを使用すると最大4倍のスピードアップが見られます。
このグラフにメモリの節約を示します(ドロップアウトまたはマスキングを使用するかどうかに関係なく、メモリフットプリントは同じであることに注意してください)。メモリの節約はシーケンスの長さに比例します。標準のアテンションのメモリはシーケンスの長さが2次であるのに対し、FlashAttentionのメモリはシーケンスの長さが線形であるためです。シーケンス長2Kで10倍、4Kで20倍のメモリ節約が見られます。その結果、FlashAttentionははるかに長いシーケンス長にスケーリングできます。
RTX 3090には、12個のアテンションヘッドを備えたバッチサイズ12を使用します。メモリの節約はA100と同じなので、ここではスピードアップのみを示します。
GDDR6Xのメモリ帯域幅はA100HBMよりも低いため(〜900 GB / s対〜1.5 TB / s)、GTX 3090ではわずかに高速化(2.5〜4.5x)が見られます。
私たちの実装では、開始点としてApexの FMHAコードを使用しています。
FMHAの実装について詳細に説明してくれたKoYoung -Junと、CUDAに関する質問に対する思慮深い回答に感謝します。
このコードベースを使用する場合、またはその他の方法で私たちの作業が価値があると思われる場合は、以下を引用してください。
@article{dao2022flashattention, title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, journal={arXiv preprint arXiv:2205.14135}, year={2022} }