diff --git a/Pyrado/pyrado/algorithms/meta/arpl.py b/Pyrado/pyrado/algorithms/meta/arpl.py
index dcae9949453..57df3d9e630 100644
--- a/Pyrado/pyrado/algorithms/meta/arpl.py
+++ b/Pyrado/pyrado/algorithms/meta/arpl.py
@@ -37,10 +37,8 @@
 )
 from pyrado.environment_wrappers.state_augmentation import StateAugmentationWrapper
 from pyrado.environments.sim_base import SimEnv
-from pyrado.exploration.stochastic_action import StochasticActionExplStrat
 from pyrado.logger.step import StepLogger
 from pyrado.policies.base import Policy
-from pyrado.sampling.parallel_rollout_sampler import ParallelRolloutSampler
 from pyrado.sampling.sequences import *


@@ -61,11 +59,7 @@ def __init__(
         env: Union[SimEnv, StateAugmentationWrapper],
         subrtn: Algorithm,
         policy: Policy,
-        expl_strat: StochasticActionExplStrat,
         max_iter: int,
-        num_rollouts: int = None,
-        steps_num: int = None,
-        apply_dynamics_noise: bool = False,
         logger: StepLogger = None,
     ):
         """
@@ -75,11 +69,7 @@ def __init__(
         :param env: the environment in which the agent should be trained
         :param subrtn: algorithm which performs the policy / value-function optimization
         :param policy: policy to be updated
-        :param expl_strat: the exploration strategy
         :param max_iter: the maximum number of iterations
-        :param num_rollouts: the number of rollouts to be performed for each update step
-        :param steps_num: the number of steps to be performed for each update step
-        :param apply_dynamics_noise: whether adversarially generated dynamics noise should be applied
         :param logger: logger for every step of the algorithm, if `None` the default logger will be created
         """
         assert isinstance(subrtn, Algorithm)
@@ -107,7 +97,7 @@ def wrap_env(
         halfspan: float = 0.25,
         proc_eps: float = 0.01,
         proc_phi: float = 0.05,
-        torch_observation = None,
+        torch_observation=None,
         obs_eps: float = 0.01,
         obs_phi: float = 0.05,
     ):
@@ -127,7 +117,9 @@ def wrap_env(
         """
         # Initialize adversarial wrappers in the correct order
         if dynamics:
-            assert isinstance(env, StateAugmentationWrapper), pyrado.TypeErr(env, given_name='env', expected_type=StateAugmentationWrapper)
+            assert isinstance(env, StateAugmentationWrapper), pyrado.TypeErr(
+                env, given_name="env", expected_type=StateAugmentationWrapper
+            )
             env = AdversarialDynamicsWrapper(env, policy, dyn_eps, dyn_phi, halfspan)
         if process:
             env = AdversarialStateWrapper(env, policy, proc_eps, proc_phi, torch_observation=torch_observation)
diff --git a/Pyrado/pyrado/environment_wrappers/adversarial.py b/Pyrado/pyrado/environment_wrappers/adversarial.py
index 4e270d82850..dcf717ddd61 100644
--- a/Pyrado/pyrado/environment_wrappers/adversarial.py
+++ b/Pyrado/pyrado/environment_wrappers/adversarial.py
@@ -31,17 +31,17 @@
 from typing import Callable, Optional

 import numpy as np
-from numpy.core import numeric
-import pyrado
-from pyrado.algorithms.base import Algorithm
-from pyrado.environments.base import Env
-from pyrado.policies.base import Policy
 import torch as to
 from init_args_serializer import Serializable
+from numpy.core import numeric

+import pyrado
+from pyrado.algorithms.base import Algorithm
 from pyrado.environment_wrappers.base import EnvWrapper
 from pyrado.environment_wrappers.state_augmentation import StateAugmentationWrapper
 from pyrado.environment_wrappers.utils import inner_env, typed_env
+from pyrado.environments.base import Env
+from pyrado.policies.base import Policy


 class AdversarialWrapper(EnvWrapper, ABC):
@@ -107,7 +107,9 @@ def get_arpl_grad(self, state):


 class AdversarialStateWrapper(AdversarialWrapper, Serializable):
     """ " Wrapper to apply adversarial perturbations to the state (used in ARPL)"""
-    def __init__(self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_observation:Optional[Callable]=None):
+    def __init__(
+        self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_observation: Optional[Callable] = None
+    ):
         """
         Constructor
@@ -119,10 +121,9 @@ def __init__(self, wrapped_env: Env, policy: Policy, eps: numeric, phi, torch_ob
         Serializable._init(self, locals())
         AdversarialWrapper.__init__(self, wrapped_env, policy, eps, phi)
         if not torch_observation:
-            raise pyrado.TypeErr(msg='The observation must be passed as torch')
+            raise pyrado.TypeErr(msg="The observation must be passed as torch")
         self.torch_obs = torch_observation

-
     def step(self, act: np.ndarray) -> tuple:
         obs, reward, done, info = self.wrapped_env.step(act)
         saw = typed_env(self.wrapped_env, StateAugmentationWrapper)
diff --git a/Pyrado/scripts/training/qq-su_arpl.py b/Pyrado/scripts/training/qq-su_arpl.py
index 6a256f3ebc5..aef4e964bd2 100755
--- a/Pyrado/scripts/training/qq-su_arpl.py
+++ b/Pyrado/scripts/training/qq-su_arpl.py
@@ -17,14 +17,10 @@
 from pyrado.utils.argparser import get_argparser
 from pyrado.utils.data_types import EnvSpec

+
 def torch_observation(state: to.tensor) -> to.tensor:
-    return to.stack([
-        to.sin(state[0]),
-        to.cos(state[0]),
-        to.sin(state[1]),
-        to.cos(state[1]),
-        state[2],
-        state[3]])
+    return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])
+

 if __name__ == "__main__":
     # Parse command line arguments
@@ -47,8 +43,6 @@ def torch_observation(state: to.tensor) -> to.tensor:
     policy_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh)  # FNN
     policy = FNNPolicy(spec=env.spec, **policy_hparam)

-
-
     env = ARPL.wrap_env(
         env,
         policy,
@@ -62,7 +56,7 @@ def torch_observation(state: to.tensor) -> to.tensor:
         obs_eps=0.05,
         proc_phi=0.1,
         proc_eps=0.03,
-        torch_observation=torch_observation
+        torch_observation=torch_observation,
     )

     # Critic
@@ -94,10 +88,9 @@ def torch_observation(state: to.tensor) -> to.tensor:
     )
     algo_hparam = dict(
         max_iter=500,
-        steps_num=23 * env.max_steps,
     )
     subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
-    algo = ARPL(ex_dir, env, subrtn, policy, subrtn.expl_strat, **algo_hparam)
+    algo = ARPL(ex_dir, env, subrtn, policy, **algo_hparam)

     # Save the hyper-parameters
     save_dicts_to_yaml(
diff --git a/Pyrado/tests/algorithms/test_algorithms.py b/Pyrado/tests/algorithms/test_algorithms.py
index a38d19c8866..ee35f22601c 100644
--- a/Pyrado/tests/algorithms/test_algorithms.py
+++ b/Pyrado/tests/algorithms/test_algorithms.py
@@ -502,17 +502,12 @@ def test_arpl_wrappers(env):
     env = StateAugmentationWrapper(env, domain_param=None)
     assert len(inner_env(env).domain_param) == env.obs_space.flat_dim - env.offset
     env.reset()
-    env.step(0.0)[0][env.offset:]
+    env.step(0.0)[0][env.offset :]


 def _qqsu_torch_observation(state: to.tensor) -> to.tensor:
-    return to.stack([
-        to.sin(state[0]),
-        to.cos(state[0]),
-        to.sin(state[1]),
-        to.cos(state[1]),
-        state[2],
-        state[3]])
+    return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])
+

 @pytest.mark.parametrize("env", ["default_qqsu"], ids=["qqsu"], indirect=True)
 def test_arpl(ex_dir, env):
@@ -535,7 +530,7 @@ def test_arpl(ex_dir, env):
         obs_eps=0.05,
         proc_phi=0.1,
         proc_eps=0.03,
-        torch_observation=_qqsu_torch_observation
+        torch_observation=_qqsu_torch_observation,
     )

     vfcn_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh)  # FNN
@@ -564,13 +559,13 @@ def test_arpl(ex_dir, env):
     )
     algo_hparam = dict(
         max_iter=2,
-        steps_num=3 * env.max_steps,
     )
     subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
-    algo = ARPL(ex_dir, env, subrtn, policy, subrtn.expl_strat, **algo_hparam)
+    algo = ARPL(ex_dir, env, subrtn, policy, **algo_hparam)

     algo.train(snapshot_mode="best")

+
 @pytest.mark.parametrize("env", ["default_qqsu", "default_bob"], ids=["qqsu", "bob"], indirect=True)
 def test_arpl_observation(env):
     policy_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.tanh)  # FNN
diff --git a/Pyrado/tests/algorithms/test_meta.py b/Pyrado/tests/algorithms/test_meta.py
index b2754e42a7b..febad66a24a 100644
--- a/Pyrado/tests/algorithms/test_meta.py
+++ b/Pyrado/tests/algorithms/test_meta.py
@@ -409,18 +409,9 @@ def test_arpl(ex_dir, env: SimEnv):
     )
     arpl_hparam = dict(
         max_iter=2,
-        steps_num=23 * env.max_steps,
-        halfspan=0.05,
-        dyn_eps=0.07,
-        dyn_phi=0.25,
-        obs_phi=0.1,
-        obs_eps=0.05,
-        proc_phi=0.1,
-        proc_eps=0.03,
-        torch_observation=True,
     )
     ppo = PPO(ex_dir, env, policy, critic, **algo_hparam)
-    algo = ARPL(ex_dir, env, ppo, policy, ppo.expl_strat, **arpl_hparam)
+    algo = ARPL(ex_dir, env, ppo, policy, **arpl_hparam)

     algo.train(snapshot_mode="best")
     assert algo.curr_iter == algo.max_iter
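
For reference, a minimal usage sketch of the slimmed-down ARPL API introduced by this patch, mirroring the updated call sites in Pyrado/scripts/training/qq-su_arpl.py. It is not part of the diff: the experiment directory, environment, policy, critic, PPO import, and sub-routine hyper-parameters are assumed to be set up as in that script, and the perturbation magnitudes are illustrative values, not prescribed by the change.

# Usage sketch only, not part of the patch. Assumes `ex_dir`, `env` (a SimEnv wrapped in a
# StateAugmentationWrapper), `policy`, `critic`, `subrtn_hparam`, and the `PPO` import are
# prepared as in Pyrado/scripts/training/qq-su_arpl.py; perturbation values are illustrative.
import torch as to

from pyrado.algorithms.meta.arpl import ARPL


def torch_observation(state: to.Tensor) -> to.Tensor:
    # Map the state to the observation the policy sees (angles as sin/cos, velocities unchanged)
    return to.stack([to.sin(state[0]), to.cos(state[0]), to.sin(state[1]), to.cos(state[1]), state[2], state[3]])


# All adversarial settings are configured through ARPL.wrap_env; torch_observation must be
# provided whenever the state perturbation (AdversarialStateWrapper, `process=True`) is enabled.
env = ARPL.wrap_env(
    env,
    policy,
    dynamics=True,
    process=True,
    obs_eps=0.05,
    obs_phi=0.1,
    proc_eps=0.03,
    proc_phi=0.1,
    torch_observation=torch_observation,
)

subrtn = PPO(ex_dir, env, policy, critic, **subrtn_hparam)
# The constructor no longer takes expl_strat, num_rollouts, steps_num, or apply_dynamics_noise
algo = ARPL(ex_dir, env, subrtn, policy, max_iter=500)
algo.train(snapshot_mode="best")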