From 6ffa9d30eae39eb4658afbc07b8e91b9f524fb9f Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Wed, 6 Jul 2022 14:33:18 +0900
Subject: [PATCH] feature: compute rho, c by joint probability

---
 handyrl/train.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/handyrl/train.py b/handyrl/train.py
index 75c0a31b..32fea68b 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -223,14 +223,13 @@ def compute_loss(batch, model, hidden, args):
         outputs = map_r(outputs, lambda v: v[:, args['burn_in_steps']:])
 
     actions = batch['action']
-    emasks = batch['episode_mask']
     clip_rho_threshold, clip_c_threshold = 1.0, 1.0
 
-    log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
-    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
+    log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * batch['turn_mask']
+    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * batch['turn_mask']
 
     # thresholds of importance sampling
-    log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
+    log_rhos = (log_selected_t_policies.detach() - log_selected_b_policies).sum(-2, keepdim=True)
     rhos = torch.exp(log_rhos)
     clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)
     cs = torch.clamp(rhos, 0, clip_c_threshold)
@@ -241,7 +240,7 @@ def compute_loss(batch, model, hidden, args):
     if args['turn_based_training'] and values_nograd.size(2) == 2:  # two player zerosum game
         values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2)
         values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8)
-    outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)
+    outputs_nograd['value'] = values_nograd * batch['episode_mask'] + batch['outcome'] * (1 - batch['episode_mask'])
 
     # compute targets and advantage
     targets = {}
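
Note: below is a minimal standalone sketch of what the new log_rhos line computes, not the patched function itself. The tensor shape [batch, time, player, 1] and the random inputs are assumptions for illustration; the clip thresholds follow the patch.

import torch

# Per-player log importance ratios are summed over the player axis (dim=-2),
# i.e. the per-player probability ratios are multiplied, so rho and c are
# computed from the joint probability of all players' actions per step.
clip_rho_threshold, clip_c_threshold = 1.0, 1.0

# Assumed shape: [batch, time, player, 1]
log_selected_t_policies = torch.randn(2, 5, 2, 1)  # target (current) policy log-probs
log_selected_b_policies = torch.randn(2, 5, 2, 1)  # behaviour policy log-probs

# Before the patch: one ratio per player.
log_rhos_per_player = log_selected_t_policies - log_selected_b_policies

# After the patch: a single joint ratio per step, shared by all players.
log_rhos = log_rhos_per_player.sum(-2, keepdim=True)
rhos = torch.exp(log_rhos)

clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)  # rho-bar for the value target
cs = torch.clamp(rhos, 0, clip_c_threshold)              # c-bar for trace cutting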