diff --git a/handyrl/train.py b/handyrl/train.py
index 971a0890..49ad7fe1 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -58,8 +58,16 @@ def replace_none(a, b):
             players = [random.choice(players)]
 
         # template for padding
-        obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))
-        amask_zeros = np.zeros_like(moments[0]['action_mask'][moments[0]['turn'][0]])
+        def find_nonzero(ms, key):
+            """Return the first value under `key` that is not None, scanning every moment in `ms`."""
+            for m in ms:
+                for val in m[key].values():
+                    if val is not None:
+                        return val
+            # Fail loudly: an implicit None here would be handed to np.zeros_like and give a bogus template.
+            raise ValueError('no non-None %s found in moments' % key)
+        obs_zeros = map_r(find_nonzero(moments, 'observation'), lambda o: np.zeros_like(o))
+        amask_zeros = np.zeros_like(find_nonzero(moments, 'action_mask'))
 
         # data that is changed by training configuration
         if args['turn_based_training'] and not args['observation']: