diff --git a/config.yaml b/config.yaml index 2bf65d85..029b1bb1 100755 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,8 @@ env_args: #env: 'handyrl.envs.parallel_tictactoe' # specify by path train_args: - turn_based_training: True + turn_based_training: False # for turn-based games + zero_sum_averaging: False # for 2p zero-sum games observation: False gamma: 0.8 forward_steps: 16 diff --git a/handyrl/train.py b/handyrl/train.py index 07dbf81f..594fe3bf 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -53,7 +53,7 @@ def replace_none(a, b): moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep['moment']], []) moments = moments_[ep['start'] - ep['base']:ep['end'] - ep['base']] players = list(moments[0]['observation'].keys()) - if not args['turn_based_training']: # solo training + if not (args['turn_based_training'] or args['zero_sum_averaging']): # solo training players = [random.choice(players)] obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o)) # template for padding @@ -236,7 +236,8 @@ def compute_loss(batch, model, hidden, args): if 'value' in outputs_nograd: values_nograd = outputs_nograd['value'] - if args['turn_based_training'] and values_nograd.size(2) == 2: # two player zerosum game + if args['zero_sum_averaging']: # two-player zero-sum game + assert values_nograd.size(2) == 2 values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2) values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8) outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)