(IMPORTANT) Deep Nash algorithm #323

Open · wants to merge 3 commits into base: develop
52 changes: 37 additions & 15 deletions handyrl/train.py
@@ -187,7 +187,7 @@ def forward_prediction(model, hidden, batch, args):
return outputs


def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
def compose_losses(outputs, advantages, targets, batch, args, outputs_reg=None):
"""Caluculate loss value

Returns:
@@ -200,7 +200,16 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
losses = {}
dcnt = tmasks.sum().item()

losses['p'] = (-log_selected_policies * total_advantages).mul(tmasks).sum()
total_advantages = sum(advantages.values())
log_policies = F.log_softmax(outputs['policy'], dim=-1)
if args.get('nash', False):
eta, clip = 1, 1e4
log_pi_ratio_reg = log_policies.detach() - F.log_softmax(outputs_reg['policy'].detach(), -1)
nash_advantages = -eta * log_pi_ratio_reg + F.one_hot(batch['action'], log_policies.size(-1)).squeeze(-2) / batch['selected_prob'] * outputs['rho'] * total_advantages
losses['p'] = -log_policies.mul(torch.clamp(nash_advantages, -clip, clip)).mul(tmasks).sum()
else:
losses['p'] = -log_policies.gather(-1, batch['action']).mul(outputs['clipped_rho']).mul(total_advantages).mul(tmasks).sum()

if 'value' in outputs:
losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
if 'return' in outputs:
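
For readers unfamiliar with the regularized-Nash style update added above: the per-action advantage combines a penalty toward a frozen regularizer policy with the usual importance-weighted advantage on the taken action, and is clipped before weighting the log-probabilities. Below is a minimal standalone sketch; the shapes and tensor names are assumptions for illustration, not part of the patch.

```python
import torch
import torch.nn.functional as F

# Toy shapes (assumption): batch B, time T, actions A
B, T, A = 2, 3, 4
logits = torch.randn(B, T, A, requires_grad=True)      # current policy logits
logits_reg = torch.randn(B, T, A)                       # frozen regularizer policy logits
action = torch.randint(A, (B, T, 1))                    # sampled actions
selected_prob = torch.rand(B, T, 1).clamp(1e-3, 1.0)    # behavior prob mu(a_t|s_t)
rho = torch.rand(B, T, 1)                               # importance ratio pi/mu
advantages = torch.randn(B, T, 1)                       # value/return advantage
tmask = torch.ones(B, T, 1)                             # turn mask

eta, clip = 1.0, 1e4
log_pi = F.log_softmax(logits, dim=-1)
log_pi_reg = F.log_softmax(logits_reg, dim=-1)

# Regularized-Nash advantage per action:
#   -eta * (log pi - log pi_reg)  +  1{a = a_t} / mu(a_t) * rho * A_t
nash_adv = -eta * (log_pi.detach() - log_pi_reg) \
    + F.one_hot(action.squeeze(-1), A) / selected_prob * rho * advantages
loss_p = -(log_pi * nash_adv.clamp(-clip, clip) * tmask).sum()
loss_p.backward()
```

With eta = 1 and a very large clip value as in the hunk, the first term pulls the policy toward the regularizer while the second carries the off-policy advantage signal for the selected action.
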
@@ -216,24 +225,34 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
return losses, dcnt


def compute_loss(batch, model, hidden, args):
def compute_loss(batch, model, hidden, args, reg_model=None):
outputs = forward_prediction(model, hidden, batch, args)
outputs_reg = None
if reg_model is not None:
with torch.no_grad():
outputs_reg = forward_prediction(reg_model, hidden, batch, args)
if args['burn_in_steps'] > 0:
batch = map_r(batch, lambda v: v[:, args['burn_in_steps']:] if v.size(1) > 1 else v)
outputs = map_r(outputs, lambda v: v[:, args['burn_in_steps']:])
if outputs_reg is not None:
outputs_reg = map_r(outputs_reg, lambda v: v[:, args['burn_in_steps']:])

actions = batch['action']
emasks = batch['episode_mask']
tmasks = batch['turn_mask']
clip_rho_threshold, clip_c_threshold = 1.0, 1.0

log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)).mul(tmasks)
log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions).mul(tmasks)

# thresholds of importance sampling
log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
log_rhos = (log_selected_t_policies.detach() - log_selected_b_policies).sum(-2, keepdim=True)
rhos = torch.exp(log_rhos)
clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)
cs = torch.clamp(rhos, 0, clip_c_threshold)
outputs['rho'] = rhos
outputs['clipped_rho'] = clipped_rhos

outputs_nograd = {k: o.detach() for k, o in outputs.items()}

if 'value' in outputs_nograd:
@@ -257,10 +276,7 @@ def compute_loss(batch, model, hidden, args):
_, advantages['value'] = compute_target(args['policy_target'], *value_args)
_, advantages['return'] = compute_target(args['policy_target'], *return_args)

# compute policy advantage
total_advantages = clipped_rhos * sum(advantages.values())

return compose_losses(outputs, log_selected_t_policies, total_advantages, targets, batch, args)
return compose_losses(outputs, advantages, targets, batch, args, outputs_reg=outputs_reg)


class Batcher:
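
The importance-sampling change in compute_loss above sums the log-ratios over the player axis before exponentiating, so each turn gets one joint ratio. A small sketch, under the assumption of a (batch, time, player, 1) tensor layout:

```python
import torch

# Assumed layout (B, T, P, 1): batch, time, players per turn, trailing singleton
log_b = torch.log(torch.rand(2, 3, 2, 1).clamp(1e-16, 1.0))   # behavior log-probs
log_t = torch.log(torch.rand(2, 3, 2, 1).clamp(1e-16, 1.0))   # target log-probs

# Sum over the player axis first -> one joint importance ratio per turn
log_rhos = (log_t - log_b).sum(-2, keepdim=True)               # shape (B, T, 1, 1)
rhos = log_rhos.exp()
clipped_rhos = rhos.clamp(0.0, 1.0)                            # rho threshold = 1
cs = rhos.clamp(0.0, 1.0)                                      # c threshold = 1
```
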
@@ -316,10 +332,11 @@ def __init__(self, args, model):
self.args = args
self.gpu = torch.cuda.device_count()
self.model = model
self.default_lr = 3e-8
self.reg_model = model if args.get('nash', False) else None
self.default_lr = 3e-6
self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps']
self.params = list(self.model.parameters())
lr = self.default_lr * self.data_cnt_ema
lr = self.default_lr * (self.data_cnt_ema ** 0.5)
self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None
self.steps = 0
self.batcher = Batcher(self.args, self.episodes)
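
On the learning-rate change above: the base rate is raised from 3e-8 to 3e-6 and scaled by the square root of the averaged data count instead of linearly, which makes the effective rate less sensitive to batch size. Purely illustrative arithmetic, with assumed (not repository-default) settings:

```python
# Illustrative only: assumed batch_size=256, forward_steps=32
data_cnt_ema = 256 * 32                      # 8192 samples per update on average

old_lr = 3e-8 * data_cnt_ema                 # linear scaling  -> ~2.46e-4
new_lr = 3e-6 * (data_cnt_ema ** 0.5)        # sqrt scaling    -> ~2.72e-4

print(f"old={old_lr:.2e} new={new_lr:.2e}")
```
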
@@ -330,6 +347,7 @@ def __init__(self, args, model):
self.trained_model = self.wrapped_model
if self.gpu > 1:
self.trained_model = nn.DataParallel(self.wrapped_model)
self.reg_model = nn.DataParallel(self.reg_model) if self.reg_model is not None else None

def update(self):
self.update_flag = True
@@ -344,6 +362,8 @@ def train(self):
batch_cnt, data_cnt, loss_sum = 0, 0, {}
if self.gpu > 0:
self.trained_model.cuda()
if self.reg_model is not None:
self.reg_model.cuda()
self.trained_model.train()

while data_cnt == 0 or not self.update_flag:
@@ -355,11 +375,11 @@ def train(self):
batch = to_gpu(batch)
hidden = to_gpu(hidden)

losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args)
losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args, reg_model=self.reg_model)

self.optimizer.zero_grad()
losses['total'].backward()
nn.utils.clip_grad_norm_(self.params, 4.0)
nn.utils.clip_grad_norm_(self.params, 1e4)
self.optimizer.step()

batch_cnt += 1
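
The gradient-norm threshold above goes from 4.0 to 1e4, which in practice means clipping only engages for extreme gradients. A toy check of that behavior (illustrative parameters only):

```python
import torch
import torch.nn as nn

# Toy parameter with an artificially large gradient
p = torch.zeros(10, requires_grad=True)
p.grad = torch.randn(10) * 1e3

total_norm = nn.utils.clip_grad_norm_([p], max_norm=1e4)
print(total_norm)   # gradients are rescaled only when total_norm exceeds 1e4
```
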
@@ -373,7 +393,9 @@ def train(self):

self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5)
param_group['lr'] = self.default_lr * (self.data_cnt_ema ** 0.5) / (1 + self.steps * 1e-5)
if self.reg_model is not None:
self.reg_model = copy.deepcopy(self.model)
self.model.cpu()
self.model.eval()
return copy.deepcopy(self.model)
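
Finally, the regularizer model added above is refreshed to a deep copy of the learner after each training pass, so the regularization target trails the current policy by one iteration. A minimal sketch of that trailing-snapshot pattern, using a hypothetical toy model:

```python
import copy
import torch.nn as nn

model = nn.Linear(4, 2)                   # stand-in for the learner
reg_model = copy.deepcopy(model)          # frozen target for the Nash regularization term

for iteration in range(3):                # one pass per call to train()
    # ... gradient updates to `model`, regularized toward `reg_model` ...
    reg_model = copy.deepcopy(model)      # refresh the snapshot after the pass
```
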