From 439be2b6b85c417e1dfb62ffd38cf5bdfea6d97c Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 22 Jan 2022 09:13:55 +0900 Subject: [PATCH 1/5] feature: update configuration for turn-based --- config.yaml | 3 ++- handyrl/train.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/config.yaml b/config.yaml index c141869e..d795c334 100755 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,8 @@ env_args: #env: 'handyrl.envs.parallel_tictactoe' # specify by path train_args: - turn_based_training: True + turn_based_training: True # for turn-based games + zero_sum_debiasing: True # for 2p zero-sum games observation: False gamma: 0.8 forward_steps: 16 diff --git a/handyrl/train.py b/handyrl/train.py index b933ebaa..541f745a 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -165,7 +165,10 @@ def forward_prediction(model, hidden, batch, args): o = o.view(*batch['turn_mask'].size()[:2], -1, o.size(-1)) if k == 'policy': # gather turn player's policies - outputs[k] = o.mul(batch['turn_mask']).sum(2, keepdim=True) - batch['action_mask'] + outputs[k] = o.mul(batch['turn_mask']) + if args['turn_based_training']: + outputs[k] = outputs[k].sum(2, keepdim=True) # gather turn player's policies + outputs[k] = outputs[k] - batch['action_mask'] else: # mask valid target values and cumulative rewards outputs[k] = o.mul(batch['observation_mask']) From 1c38ca6aedc54da92644c12c10a0ace8aa33eb84 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 22 Jan 2022 09:44:53 +0900 Subject: [PATCH 2/5] fix: use zero_sum_averaging --- handyrl/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index 541f745a..50e4c09e 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -224,7 +224,8 @@ def compute_loss(batch, model, hidden, args): if 'value' in outputs_nograd: values_nograd = outputs_nograd['value'] - if args['turn_based_training'] and values_nograd.size(2) == 2: # two player zerosum game + if args['zero_sum_averaging']: # two player zerosum game + assert values_nograd.size(2) == 2 values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2) values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8) outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks) From 0c66a5c7f853eddd498ddca90e633175778c48cf Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 22 Jan 2022 09:46:17 +0900 Subject: [PATCH 3/5] feature: 2p zero-sum setting should be False in default config --- config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.yaml b/config.yaml index d795c334..ba03957c 100755 --- a/config.yaml +++ b/config.yaml @@ -6,8 +6,8 @@ env_args: #env: 'handyrl.envs.parallel_tictactoe' # specify by path train_args: - turn_based_training: True # for turn-based games - zero_sum_debiasing: True # for 2p zero-sum games + turn_based_training: False # for turn-based games + zero_sum_averaging: False # for 2p zero-sum games observation: False gamma: 0.8 forward_steps: 16 From 203c63b2d47dd9215668e05e975f5a64eb6bff6a Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 22 Jan 2022 13:30:38 +0900 Subject: [PATCH 4/5] fix: create dual-player batch when zero_sum_averaging is True --- handyrl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index 50e4c09e..2d299383 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -54,7 +54,7 @@ def replace_none(a, b): moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep['moment']], []) moments = moments_[ep['start'] - ep['base']:ep['end'] - ep['base']] players = list(moments[0]['observation'].keys()) - if not args['turn_based_training']: # solo training + if not (args['turn_based_training'] or args['zero_sum_averaging']): # solo training players = [random.choice(players)] obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o)) # template for padding From c246d3b1156190c2c2c63f826c87e092da8f34cc Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 1 Feb 2022 20:05:10 +0900 Subject: [PATCH 5/5] fix: debug comment --- handyrl/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index 844891c7..594fe3bf 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -174,7 +174,6 @@ def forward_prediction(model, hidden, batch, args): for k, o in outputs.items(): if k == 'policy': o = o.mul(batch['turn_mask']) - print(o.shape, batch['turn_mask'].shape, batch['action_mask'].shape) if o.size(2) > 1 and batch_shape[2] == 1: # turn-alternating batch o = o.sum(2, keepdim=True) # gather turn player's policies outputs[k] = o - batch['action_mask']