OpenAI Gymフレームワークに従って、カスタム環境を作成しました。 step
、reset
、action
、およびreward
関数を含みます。このカスタム環境でOpenAIベースラインを実行することを目指しています。ただし、これに先立って、OpenAIジムに環境を登録する必要があります。 OpenAIジムにカスタム環境を登録する方法を教えてください。また、OpenAIベースラインコードを変更してこれを組み込む必要がありますか?
baselinesリポジトリを変更する必要はありません。
これは最小限の例です。必要なすべての関数(step
、reset
、...)を備えたmyenv.py
があるとします。クラス環境の名前はMyEnv
であり、classic_control
フォルダーに追加する必要があります。必ず
myenv.py
ファイルをgym/gym/envs/classic_control
に配置__init__.py
に追加(同じフォルダにあります)
from gym.envs.classic_control.myenv import MyEnv
追加して、環境をgym/gym/envs/__init__.py
に登録します
gym.envs.register(
id='MyEnv-v0',
entry_point='gym.envs.classic_control:MyEnv',
max_episode_steps=1000,
)
登録時に、reward_threshold
およびkwargs
を追加することもできます(クラスが引数を取る場合)。gym/gym/envs/__init__.py
ではなく、実行するスクリプト(TRPO、PPOなど)に直接環境を登録することもできます。
[〜#〜]編集[〜#〜]
これは [〜#〜] lqr [〜#〜] 環境を作成するための最小限の例です。
以下のコードをlqr_env.py
に保存して、gymのclassic_controlフォルダーに配置します。
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
class LqrEnv(gym.Env):
def __init__(self, size, init_state, state_bound):
self.init_state = init_state
self.size = size
self.action_space = spaces.Box(low=-state_bound, high=state_bound, shape=(size,))
self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=(size,))
self._seed()
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _step(self,u):
costs = np.sum(u**2) + np.sum(self.state**2)
self.state = np.clip(self.state + u, self.observation_space.low, self.observation_space.high)
return self._get_obs(), -costs, False, {}
def _reset(self):
high = self.init_state*np.ones((self.size,))
self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None
return self._get_obs()
def _get_obs(self):
return self.state
from gym.envs.classic_control.lqr_env import LqrEnv
を__init__.py
に追加します(classic_controlでも)。
スクリプトで、環境を作成するときに、
gym.envs.register(
id='Lqr-v0',
entry_point='gym.envs.classic_control:LqrEnv',
max_episode_steps=150,
kwargs={'size' : 1, 'init_state' : 10., 'state_bound' : np.inf},
)
env = gym.make('Lqr-v0')