diff --git a/handyrl/train.py b/handyrl/train.py
index 75c0a31b..ec021c5b 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -187,7 +187,7 @@ def forward_prediction(model, hidden, batch, args):
     return outputs
 
 
-def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
+def compose_losses(outputs, advantages, targets, batch, args, outputs_reg=None):
     """Caluculate loss value
 
     Returns:
@@ -200,7 +200,16 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     losses = {}
     dcnt = tmasks.sum().item()
 
-    losses['p'] = (-log_selected_policies * total_advantages).mul(tmasks).sum()
+    total_advantages = sum(advantages.values())
+    log_policies = F.log_softmax(outputs['policy'], dim=-1)
+    if args.get('nash', False):
+        eta, clip = 1, 1e4
+        log_pi_ratio_reg = log_policies.detach() - F.log_softmax(outputs_reg['policy'].detach(), -1)
+        nash_advantages = -eta * log_pi_ratio_reg + F.one_hot(batch['action'], log_policies.size(-1)).squeeze(-2) / batch['selected_prob'] * outputs['rho'] * total_advantages
+        losses['p'] = -log_policies.mul(torch.clamp(nash_advantages, -clip, clip)).mul(tmasks).sum()
+    else:
+        losses['p'] = -log_policies.gather(-1, batch['action']).mul(outputs['clipped_rho']).mul(total_advantages).mul(tmasks).sum()
+
     if 'value' in outputs:
         losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
     if 'return' in outputs:
@@ -216,24 +225,34 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     return losses, dcnt
 
 
-def compute_loss(batch, model, hidden, args):
+def compute_loss(batch, model, hidden, args, reg_model=None):
     outputs = forward_prediction(model, hidden, batch, args)
+    outputs_reg = None
+    if reg_model is not None:
+        with torch.no_grad():
+            outputs_reg = forward_prediction(reg_model, hidden, batch, args)
     if args['burn_in_steps'] > 0:
         batch = map_r(batch, lambda v: v[:, args['burn_in_steps']:] if v.size(1) > 1 else v)
         outputs = map_r(outputs, lambda v: v[:, args['burn_in_steps']:])
+        if outputs_reg is not None:
+            outputs_reg = map_r(outputs_reg, lambda v: v[:, args['burn_in_steps']:])
 
     actions = batch['action']
     emasks = batch['episode_mask']
+    tmasks = batch['turn_mask']
     clip_rho_threshold, clip_c_threshold = 1.0, 1.0
 
-    log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
-    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
+    log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)).mul(tmasks)
+    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions).mul(tmasks)
 
     # thresholds of importance sampling
-    log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
+    log_rhos = (log_selected_t_policies.detach() - log_selected_b_policies).sum(-2, keepdim=True)
     rhos = torch.exp(log_rhos)
     clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)
     cs = torch.clamp(rhos, 0, clip_c_threshold)
+    outputs['rho'] = rhos
+    outputs['clipped_rho'] = clipped_rhos
+
     outputs_nograd = {k: o.detach() for k, o in outputs.items()}
 
     if 'value' in outputs_nograd:
@@ -257,10 +276,7 @@ def compute_loss(batch, model, hidden, args):
     _, advantages['value'] = compute_target(args['policy_target'], *value_args)
     _, advantages['return'] = compute_target(args['policy_target'], *return_args)
 
-    # compute policy advantage
-    total_advantages = clipped_rhos * sum(advantages.values())
-
-    return compose_losses(outputs, log_selected_t_policies, total_advantages, targets, batch, args)
+    return compose_losses(outputs, advantages, targets, batch, args, outputs_reg=outputs_reg)
 
 
 class Batcher:
@@ -316,10 +332,11 @@ def __init__(self, args, model):
         self.args = args
         self.gpu = torch.cuda.device_count()
         self.model = model
-        self.default_lr = 3e-8
+        self.reg_model = model if args.get('nash', False) else None
+        self.default_lr = 3e-6
         self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps']
         self.params = list(self.model.parameters())
-        lr = self.default_lr * self.data_cnt_ema
+        lr = self.default_lr * (self.data_cnt_ema ** 0.5)
         self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None
         self.steps = 0
         self.batcher = Batcher(self.args, self.episodes)
@@ -330,6 +347,7 @@ def __init__(self, args, model):
         self.trained_model = self.wrapped_model
         if self.gpu > 1:
             self.trained_model = nn.DataParallel(self.wrapped_model)
+            self.reg_model = nn.DataParallel(self.reg_model) if self.reg_model is not None else None
 
     def update(self):
         self.update_flag = True
@@ -344,6 +362,8 @@ def train(self):
         batch_cnt, data_cnt, loss_sum = 0, 0, {}
         if self.gpu > 0:
             self.trained_model.cuda()
+            if self.reg_model is not None:
+                self.reg_model.cuda()
         self.trained_model.train()
 
         while data_cnt == 0 or not self.update_flag:
@@ -355,11 +375,11 @@ def train(self):
                 batch = to_gpu(batch)
                 hidden = to_gpu(hidden)
 
-            losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args)
+            losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args, reg_model=self.reg_model)
 
             self.optimizer.zero_grad()
             losses['total'].backward()
-            nn.utils.clip_grad_norm_(self.params, 4.0)
+            nn.utils.clip_grad_norm_(self.params, 1e4)
             self.optimizer.step()
 
             batch_cnt += 1
@@ -373,7 +393,9 @@ def train(self):
         self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2
 
         for param_group in self.optimizer.param_groups:
-            param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5)
+            param_group['lr'] = self.default_lr * (self.data_cnt_ema ** 0.5) / (1 + self.steps * 1e-5)
+        if self.reg_model is not None:
+            self.reg_model = copy.deepcopy(self.model)
         self.model.cpu()
         self.model.eval()
         return copy.deepcopy(self.model)
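
For reference, a minimal standalone sketch of the two policy-loss branches that the new compose_losses selects between, with the 'nash' branch written out term by term. Tensor names and shapes below (B, T, A, policy_logits, reg_logits, and so on) are invented for illustration and do not reproduce HandyRL's exact batch layout (which carries an extra player dimension); only the formulas and the hard-coded eta/clip values come from the patch.

import torch
import torch.nn.functional as F

B, T, A = 4, 8, 6                                          # batch, steps, actions (arbitrary sizes)
policy_logits = torch.randn(B, T, A, requires_grad=True)   # trained model policy head
reg_logits = torch.randn(B, T, A)                          # regularization model policy head (pi_reg)
actions = torch.randint(A, (B, T, 1))                      # sampled actions
selected_prob = torch.rand(B, T, 1).clamp(1e-16, 1)        # behaviour prob of the sampled action
rho = torch.rand(B, T, 1) * 2                              # unclipped importance ratio pi / mu
clipped_rho = rho.clamp(max=1.0)
total_advantages = torch.randn(B, T, 1)                    # sum(advantages.values())
tmasks = torch.ones(B, T, 1)                               # turn mask

log_policies = F.log_softmax(policy_logits, dim=-1)

# 'nash' branch: all-action update, penalized by log(pi / pi_reg) against the regularization policy
eta, clip = 1, 1e4
log_pi_ratio_reg = log_policies.detach() - F.log_softmax(reg_logits, dim=-1)
one_hot = F.one_hot(actions, A).squeeze(-2)                # (B, T, A), 1 at the sampled action
nash_advantages = -eta * log_pi_ratio_reg + one_hot / selected_prob * rho * total_advantages
loss_nash = -log_policies.mul(torch.clamp(nash_advantages, -clip, clip)).mul(tmasks).sum()

# default branch: clipped-importance-weighted policy gradient on the taken action only
loss_pg = -log_policies.gather(-1, actions).mul(clipped_rho).mul(total_advantages).mul(tmasks).sum()

print(loss_nash.item(), loss_pg.item())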