mctx - JAXでのモンテカルロ木検索

(Monte Carlo tree search in JAX)

Created at: 2022-03-02 01:26:43
Language: Python
License: Apache-2.0

Mctx: MCTS-in-JAX

MctxはJAXネイティブのライブラリです AlphaZero、MuZero、Gumbel MuZeroなどのモンテカルロ木探索(MCTS)アルゴリズムの実装。計算用 高速化、実装はJITコンパイルを完全にサポートします。検索アルゴリズム Mctxでは、入力のバッチに対して並列に定義され、操作されます。これ アクセラレータを最大限に活用し、アルゴリズムを機能させることができます ディープニューラルネットワークによってパラメータ化された大規模な学習環境モデルを使用します。

取り付け

PyPIからMctxの最新リリースバージョンをインストールするには、次の方法で次の方法でインストールできます。

pip install mctx

または、GitHub から最新の開発バージョンをインストールすることもできます。

pip install git+https://github.com/deepmind/mctx.git

モチベーション

学習と検索はAIの黎明期から重要なトピックでした 研究。リッチ・サットンの言葉を借りれば:

学ぶべきことの1つは、汎用の大きな力です[...] メソッド、計算の増加に伴ってスケーリングし続けるメソッドの 利用可能な計算は非常に大きくなります。と思われる2つの方法 このように任意にスケーリングするのは、検索学習です。

最近、検索アルゴリズムは学習済みモデルとうまく組み合わされています ディープニューラルネットワークによってパラメータ化され、最も強力なもののいくつかをもたらします これまでの一般的な強化学習アルゴリズム(例:MuZero)。 ただし、検索アルゴリズムをディープニューラルネットワークと組み合わせて使用する 効率的な実装が必要で、通常は高速コンパイルで記述されています 言語;これは、使いやすさとハッキング可能性を犠牲にする可能性があります。 特にC ++に精通していない研究者のために。次に、これは制限します この重要なトピックに関する採用とさらなる研究。

このライブラリを通じて、世界中の研究者が貢献できるように支援したいと考えています。 このようなエキサイティングな研究分野です。コアのJAXネイティブ実装を提供します MCTSなどの検索アルゴリズムは、 検索ベースを調査したい研究者のためのパフォーマンスと使いやすさ Python のアルゴリズム。Mctx が提供する検索メソッドは次のとおりです。 研究者がさまざまなアイデアを探求できるように、大幅に構成可能 このスペースは、次世代の検索ベースのエージェントに貢献します。

強化学習で検索

強化学習では、エージェントはスカラー報酬信号を最大化するために環境との対話を学習する必要があります。各ステップで エージェントはアクションを選択し、オブザベーションと 報酬。エージェントがアクションを選択するために使用するメカニズムを呼び出すことができます エージェントのポリシー

古典的には、ポリシーは関数近似子によって直接パラメータ化されます( REINFORCE)、または学習した推定値のセットを調べることによってポリシーが推測されます 各アクションの値(Qラーニングのように)。または、検索により、 各州でポリシーまたは値をオンザフライで構築することにより、アクションを選択します の学習済みモデルを使用して検索することにより、現在の状態にローカルな関数 環境。

考えられるすべての将来の行動方針に対する網羅的な検索は計算上です 自明ではない環境では禁止されているため、検索アルゴリズムが必要です これにより、有限の計算予算を最大限に活用できます。通常は事前確率 は、検索ツリーのどのノードを展開するかをガイドするために必要であり(構築するツリーの幅を狭めるため)、値関数は次の目的で使用されます。 エピソードに到達しないツリー内の不完全なパスの値を推定する 終了 (検索ツリーの深さを減らすため)。

クイックスタート

Mctxは、低レベルの汎用機能と高レベルの具象を提供します ポリシー: および .

search
muzero_policy
gumbel_muzero_policy

ユーザーは、いくつかの学習済みコンポーネントを提供して、 MuZeroで使用される表現、ダイナミクス、予測。 Mctx ライブラリのコンテキストでは、ルート状態の表現は次のようになります。 で指定されます。には、ポリシーネットワークからの、ルート状態の推定値、および環境モデルのルート状態を表すのに適した任意のものが含まれています。

RootFnOutput
RootFnOutput
prior_logits
value
embedding

ダイナミクス環境モデルは、. 呼び出しには と 状態.呼び出しは、aと次の状態の埋め込みを持つタプルを返す必要があります。 には、遷移の と が含まれます。 そして新しい状態のために。

recurrent_fn
recurrent_fn(params, rng_key, action, embedding)
action
embedding
(recurrent_fn_output, new_embedding)
RecurrentFnOutput
RecurrentFnOutput
reward
discount
prior_logits
value

例/visualization_demo.pyでは、次のことができます。 ポリシーの呼び出しを参照してください。

policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn,
                                          num_simulations=32)

には、検索によって提案されたアクションが含まれます。それ アクションを環境に渡すことができます。ポリシーを改善するために、ポリシーのトレーニングに使用できるターゲットを含む 確率。

policy_output.action
policy_output.action_weights

を使用することをお勧めします。ガンベルミューゼロはポリシーを保証します アクション値が正しく評価された場合の改善。政策改善 例/policy_improvement_demo.pyで示されています。

gumbel_muzero_policy

サンプルプロジェクト

次のプロジェクトは、Mctx の使用方法を示しています。

あなたのプロジェクトについて教えてください。

マクテックスの引用

これは公式にサポートされている Google サービスではありません。MctxはDeepMind JAXエコシステムの一部です。Mctxを引用するには、DeepMind JAXエコシステムを使用してください 引用。