私は、「Deep Q-Learning」を使用してモデルを構築しようとしていますが、アクションが多数あります(2908)。標準のDQNを使用していくつかの限られた成功を収めた後:( https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf )、私は、アクションスペースが大きすぎて、効果的な探索を行うことができませんでした。
次に、このペーパーを発見しました。 https://arxiv.org/pdf/1512.07679.pdf ここで、アクタークリティカルモデルとポリシーグラディエントを使用しているため、次のようになりました https: //arxiv.org/pdf/1602.01783.pdf ここでは、ポリシーグラディエントを使用して、DQN全体よりもはるかに優れた結果を得ています。
Kerasでポリシーの勾配を実装しているサイトをいくつか見つけました https://yanpanlau.github.io/2016/10/11/Torcs-Keras.html および https ://oshearesearch.com/index.php/2016/06/14/kerlym-a-deep-reinforcement-learning-toolbox-in-keras/ しかし、それらがどのように実装されているか混乱しています。前者(および私が論文を読んだとき)では、アクターネットワークに入力と出力のペアを提供する代わりに、すべての重みに勾配を提供し、ネットワークを使用して更新しますが、後者では入出力ペアを計算するだけです。
混乱しましたか?入出力ペアを提供して標準の 'fit'を使用してネットワークをトレーニングしているはずですか、それとも特別なことをする必要がありますか?後者の場合、Theanoバックエンドでどうすればよいですか? (上記の例ではTensorFlowを使用しています)。
エージェントには、基本的に、状態を各アクションの確率であるポリシーにマップする関数であるポリシーが必要です。したがって、エージェントはそのポリシーに従ってアクションを選択します。
つまり、policy = f(state)
ポリシーグラディエントには損失関数はありません。代わりに、期待される報酬のリターンを最大化しようとします。そして、log(action_prob)*アドバンテージの勾配を計算する必要があります
私はこのようなものを想定しています
2つの機能が必要です
Model.compile(...)-> model.fit(X、y)のように、典型的な分類問題のように実装するのは簡単ではないことをすでに知っています。
しかしながら、
Kerasを十分に活用するには、カスタムの損失関数と勾配の定義に慣れている必要があります。これは基本的に前者の作者がとったアプローチと同じです。
Keras関数APIとkeras.backendのドキュメントをもっと読むべきです
さらに、多くの種類のポリシーの勾配があります。
あなたが遭遇した一見矛盾する実装はどちらも有効な実装です。これらは、ポリシーグラディエントを実装する2つの同等の方法です。
バニラの実装では、ポリシーネットワークの勾配を計算します。報酬を与え、勾配の方向に重みを直接更新します。これには、Mo Kによって説明されている手順を実行する必要があります。
2番目のオプションは、おそらくkeras/tensorflowのようなautodiffフレームワークのより便利な実装です。アイデアは、教師あり学習のような入出力(状態アクション)関数を実装することですが、損失関数では、勾配はポリシー勾配と同じです。ソフトマックスポリシーの場合、これは単に「真のアクション」を予測し、(クロスエントロピー)損失に観測されたリターン/アドバンテージを掛けることを意味します。 Aleksis Pirinenはこれについていくつかの有用なメモを持っています[1]。
Kerasのオプション2の修正された損失関数は次のようになります。
import keras.backend as K
def policy_gradient_loss(Returns):
def modified_crossentropy(action,action_probs):
cost = K.categorical_crossentropy(action,action_probs,from_logits=False,axis=1 * Returns)
return K.mean(cost)
return modified_crossentropy
ここで、「action」はエピソードの真のアクション(y)、action_probsは予測確率(y *)です。これは、別のスタックオーバーフローの質問[2]に基づいています。
参照