diff --git a/handyrl/agent.py b/handyrl/agent.py
index 86d2c08e..c44ac302 100755
--- a/handyrl/agent.py
+++ b/handyrl/agent.py
@@ -33,18 +33,18 @@ def action(self, env, player, show=False):
         return random.choice(env.legal_actions(player))
 
 
-def print_outputs(env, prob, v):
+def print_outputs(env, action, prob, v):
     if hasattr(env, 'print_outputs'):
-        env.print_outputs(prob, v)
+        env.print_outputs(action, prob, v)
     else:
         if v is not None:
             print('v = %f' % v)
-        if prob is not None:
-            print('p = %s' % (prob * 1000).astype(int))
+        if action is not None:
+            print('a = %d prob = %f' % (action, prob))
 
 
 class Agent:
-    def __init__(self, model, temperature=0.0, observation=True):
+    def __init__(self, model, temperature=1e-6, observation=True):
         # model might be a neural net, or some planning algorithm such as game tree search
         self.model = model
         self.hidden = None
@@ -55,28 +55,22 @@ def reset(self, env, show=False):
         self.hidden = self.model.init_hidden()
 
     def plan(self, obs):
-        outputs = self.model.inference(obs, self.hidden)
+        outputs = self.model.inference(obs, self.hidden, temperature=self.temperature)
         self.hidden = outputs.pop('hidden', None)
         return outputs
 
     def action(self, env, player, show=False):
         obs = env.observation(player)
         outputs = self.plan(obs)
-        actions = env.legal_actions(player)
-        p = outputs['policy']
+
+        action = outputs['action']
+        prob = np.exp(outputs['log_selected_prob'])
         v = outputs.get('value', None)
-        mask = np.ones_like(p)
-        mask[actions] = 0
-        p = p - mask * 1e32
 
         if show:
-            print_outputs(env, softmax(p), v)
+            print_outputs(env, action, prob, v)
 
-        if self.temperature == 0:
-            ap_list = sorted([(a, p[a]) for a in actions], key=lambda x: -x[1])
-            return ap_list[0][0]
-        else:
-            return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]
+        return action
 
     def observe(self, env, player, show=False):
         v = None
diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py
index a82bd6af..0b8ca337 100755
--- a/handyrl/envs/geister.py
+++ b/handyrl/envs/geister.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributions as dist
 
 from ..environment import BaseEnvironment
 
@@ -147,7 +148,7 @@ def __init__(self):
     def init_hidden(self, batch_size=[]):
         return self.body.init_hidden(self.input_size[1:], batch_size)
 
-    def forward(self, x, hidden):
+    def forward(self, x, hidden, action=None, temperature=1.0):
         b, s = x['board'], x['scalar']
         h_s = s.view(*s.size(), 1, 1).repeat(1, 1, 6, 6)
         h = torch.cat([h_s, b], -3)
@@ -163,7 +164,16 @@ def forward(self, x, hidden):
         h_v = self.head_v(h)
         h_r = self.head_r(h)
 
-        return {'policy': h_p, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden}
+        log_prob = F.log_softmax(h_p / temperature, -1)
+        prob = torch.exp(log_prob)
+        entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)
+
+        if action is None:
+            prob = torch.exp(log_prob)
+            action = prob.multinomial(num_samples=1, replacement=True)
+        log_selected_prob = log_prob.gather(-1, action)
+
+        return {'action': action, 'log_selected_prob': log_selected_prob, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden, 'entropy': entropy}
 
 
 class Environment(BaseEnvironment):
@@ -357,6 +367,10 @@ def _set(self, layout):
 
     def play(self, action, _=None):
         # state transition
+        if not self.legal(action):
+            self.win_color = self.opponent(self.color)
+            return
+
         if self.turn_count < 0:
             layout = action - 4 * 6 * 6
             return self._set(layout)
diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py
index 2c27809c..ceffc155 100755
--- a/handyrl/envs/tictactoe.py
+++ b/handyrl/envs/tictactoe.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributions as dist
 
 from ..environment import BaseEnvironment
 
@@ -59,14 +60,22 @@ def __init__(self):
         self.head_p = Head((filters, 3, 3), 2, 9)
         self.head_v = Head((filters, 3, 3), 1, 1)
 
-    def forward(self, x, hidden=None):
+    def forward(self, x, hidden=None, action=None, temperature=1.0):
         h = F.relu(self.conv(x))
         for block in self.blocks:
             h = F.relu(block(h))
         h_p = self.head_p(h)
         h_v = self.head_v(h)
 
-        return {'policy': h_p, 'value': torch.tanh(h_v)}
+        log_prob = F.log_softmax(h_p / temperature, -1)
+        entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)
+
+        if action is None:
+            prob = torch.exp(log_prob)
+            action = prob.multinomial(num_samples=1, replacement=True)
+        log_selected_prob = log_prob.gather(-1, action)
+
+        return {'action': action, 'log_selected_prob': log_selected_prob, 'value': torch.tanh(h_v), 'entropy': entropy}
 
 
 class Environment(BaseEnvironment):
@@ -104,7 +113,10 @@ def play(self, action, _=None):
         # state transition function
         # action is integer (0 ~ 8)
         x, y = action // 3, action % 3
-        self.board[x, y] = self.color
+        if self.board[x, y] != 0:  # illegal action
+            self.win_color = -self.color
+        else:
+            self.board[x, y] = self.color
 
         # check winning condition
         win = self.board[x, :].sum() == 3 * self.color \
diff --git a/handyrl/generation.py b/handyrl/generation.py
index 8bca1c98..bb2f9fb3 100755
--- a/handyrl/generation.py
+++ b/handyrl/generation.py
@@ -29,7 +29,7 @@ def generate(self, models, args):
             return None
 
         while not self.env.terminal():
-            moment_keys = ['observation', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return']
+            moment_keys = ['observation', 'log_selected_prob', 'action', 'value', 'reward', 'return']
             moment = {key: {p: None for p in self.env.players()} for key in moment_keys}
 
             turn_players = self.env.turns()
@@ -42,6 +42,7 @@ def generate(self, models, args):
 
                 obs = self.env.observation(player)
                 model = models[player]
+
                 outputs = model.inference(obs, hidden[player])
                 hidden[player] = outputs.get('hidden', None)
                 v = outputs.get('value', None)
@@ -50,16 +51,8 @@ def generate(self, models, args):
                 moment['value'][player] = v
 
                 if player in turn_players:
-                    p_ = outputs['policy']
-                    legal_actions = self.env.legal_actions(player)
-                    action_mask = np.ones_like(p_) * 1e32
-                    action_mask[legal_actions] = 0
-                    p = softmax(p_ - action_mask)
-                    action = random.choices(legal_actions, weights=p[legal_actions])[0]
-
-                    moment['selected_prob'][player] = p[action]
-                    moment['action_mask'][player] = action_mask
-                    moment['action'][player] = action
+                    moment['action'][player] = outputs['action']
+                    moment['log_selected_prob'][player] = outputs['log_selected_prob']
 
             err = self.env.step(moment['action'])
             if err:
diff --git a/handyrl/model.py b/handyrl/model.py
index b4dd2a7a..44ab9c60 100755
--- a/handyrl/model.py
+++ b/handyrl/model.py
@@ -58,17 +58,3 @@ def inference(self, x, hidden, **kwargs):
         ht = map_r(hidden, lambda h: torch.from_numpy(np.array(h)).contiguous().unsqueeze(0) if h is not None else None)
         outputs = self.forward(xt, ht, **kwargs)
         return map_r(outputs, lambda o: o.detach().numpy().squeeze(0) if o is not None else None)
-
-
-# simple model
-
-class RandomModel(nn.Module):
-    def __init__(self, model, x):
-        super().__init__()
-        wrapped_model = ModelWrapper(model)
-        hidden = wrapped_model.init_hidden()
-        outputs = wrapped_model.inference(x, hidden)
-        self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items() if key != 'hidden'}
-
-    def inference(self, *args):
-        return self.output_dict
diff --git a/handyrl/train.py b/handyrl/train.py
index 05fb2b9d..4b6e5280 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.distributions as dist
 import torch.optim as optim
 import psutil
 
@@ -59,19 +58,18 @@ def replace_none(a, b):
 
         # template for padding
         obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))
-        amask_zeros = np.zeros_like(moments[0]['action_mask'][moments[0]['turn'][0]])
+        log_prob_zeros = np.zeros_like(moments[0]['log_selected_prob'][moments[0]['turn'][0]])
+        act_zeros = np.zeros_like(moments[0]['action'][moments[0]['turn'][0]])
 
         # data that is changed by training configuration
         if args['turn_based_training'] and not args['observation']:
             obs = [[m['observation'][m['turn'][0]]] for m in moments]
-            prob = np.array([[[m['selected_prob'][m['turn'][0]]]] for m in moments])
-            act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis]
-            amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments])
+            log_prob = np.array([[m['log_selected_prob'][m['turn'][0]]] for m in moments])
+            act = np.array([[m['action'][m['turn'][0]]] for m in moments])
         else:
             obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments]
-            prob = np.array([[[replace_none(m['selected_prob'][player], 1.0)] for player in players] for m in moments])
-            act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis]
-            amask = np.array([[replace_none(m['action_mask'][player], amask_zeros + 1e32) for player in players] for m in moments])
+            log_prob = np.array([[replace_none(m['log_selected_prob'][player], log_prob_zeros) for player in players] for m in moments])
+            act = np.array([[replace_none(m['action'][player], act_zeros) for player in players] for m in moments])
 
         # reshape observation
         obs = rotate(rotate(obs))  # (T, P, ..., ...) -> (P, ..., T, ...) -> (..., T, P, ...)
@@ -84,7 +82,7 @@ def replace_none(a, b):
         oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)
 
         emask = np.ones((len(moments), 1, 1), dtype=np.float32)  # episode mask
-        tmask = np.array([[[m['selected_prob'][player] is not None] for player in players] for m in moments], dtype=np.float32)
+        tmask = np.array([[[m['log_selected_prob'][player] is not None] for player in players] for m in moments], dtype=np.float32)
         omask = np.array([[[m['observation'][player] is not None] for player in players] for m in moments], dtype=np.float32)
 
         progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total']
@@ -95,7 +93,7 @@ def replace_none(a, b):
             pad_len_b = args['burn_in_steps'] - (ep['train_start'] - ep['start'])
             pad_len_a = batch_steps - len(tmask) - pad_len_b
             obs = map_r(obs, lambda o: np.pad(o, [(pad_len_b, pad_len_a)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0))
-            prob = np.pad(prob, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1)
+            log_prob = np.pad(log_prob, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             v = np.concatenate([np.pad(v, [(pad_len_b, 0), (0, 0), (0, 0)], 'constant', constant_values=0), np.tile(oc, [pad_len_a, 1, 1])])
             act = np.pad(act, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             rew = np.pad(rew, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
@@ -103,24 +101,22 @@ def replace_none(a, b):
             emask = np.pad(emask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             tmask = np.pad(tmask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             omask = np.pad(omask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
-            amask = np.pad(amask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1e32)
             progress = np.pad(progress, [(pad_len_b, pad_len_a), (0, 0)], 'constant', constant_values=1)
         obss.append(obs)
-        datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))
+        datum.append((log_prob, v, act, oc, rew, ret, emask, tmask, omask, progress))
     obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
-    prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
+    log_prob, v, act, oc, rew, ret, emask, tmask, omask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
     return {
         'observation': obs,
-        'selected_prob': prob,
+        'log_selected_prob': log_prob,
        'value': v,
         'action': act,
         'outcome': oc,
         'reward': rew,
         'return': ret,
         'episode_mask': emask,
         'turn_mask': tmask,
         'observation_mask': omask,
-        'action_mask': amask,
         'progress': progress,
     }
 
@@ -143,13 +139,15 @@ def forward_prediction(model, hidden, batch, args):
     if hidden is None:
         # feed-forward neural network
         obs = map_r(observations, lambda o: o.flatten(0, 2))  # (..., B * T * P or 1, ...)
-        outputs = model(obs, None)
+        action = batch['action'].flatten(0, 2)
+        outputs = model(obs, None, action=action)
         outputs = map_r(outputs, lambda o: o.unflatten(0, batch_shape))  # (..., B, T, P or 1, ...)
     else:
         # sequential computation with RNN
         outputs = {}
         for t in range(batch_shape[1]):
             obs = map_r(observations, lambda o: o[:, t].flatten(0, 1))  # (..., B * P or 1, ...)
+            action = batch['action'][:, t].flatten(0, 1)
             omask_ = batch['observation_mask'][:, t]
             omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (h.dim() - 2))))
             hidden_ = bimap_r(hidden, omask, lambda h, m: h * m)  # (..., B, P, ...)
@@ -160,11 +158,11 @@ def forward_prediction(model, hidden, batch, args):
             if t < args['burn_in_steps']:
                 model.eval()
                 with torch.no_grad():
-                    outputs_ = model(obs, hidden_)
+                    outputs_ = model(obs, hidden_, action=action)
             else:
                 if not model.training:
                     model.train()
-                outputs_ = model(obs, hidden_)
+                outputs_ = model(obs, hidden_, action=action)
             outputs_ = map_r(outputs_, lambda o: o.unflatten(0, (batch_shape[0], batch_shape[2])))  # (..., B, P or 1, ...)
             for k, o in outputs_.items():
                 if k == 'hidden':
@@ -175,11 +173,11 @@ def forward_prediction(model, hidden, batch, args):
         outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None}
 
     for k, o in outputs.items():
-        if k == 'policy':
+        if k in ['action', 'log_selected_prob', 'entropy']:
             o = o.mul(batch['turn_mask'])
             if o.size(2) > 1 and batch_shape[2] == 1:  # turn-alternating batch
                 o = o.sum(2, keepdim=True)  # gather turn player's policies
-            outputs[k] = o - batch['action_mask']
+            outputs[k] = o
         else:
             # mask valid target values and cumulative rewards
             outputs[k] = o.mul(batch['observation_mask'])
@@ -206,11 +204,11 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     if 'return' in outputs:
         losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum()
 
-    entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1))
+    entropy = outputs['entropy'].mul(tmasks)
     losses['ent'] = entropy.sum()
 
     base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
-    entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
+    entropy_loss = entropy.mul(1 - batch['progress'].unsqueeze(-2) * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
     losses['total'] = base_loss + entropy_loss
 
     return losses, dcnt
@@ -226,8 +224,8 @@ def compute_loss(batch, model, hidden, args):
     emasks = batch['episode_mask']
     clip_rho_threshold, clip_c_threshold = 1.0, 1.0
 
-    log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
-    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
+    log_selected_b_policies = batch['log_selected_prob'] * emasks
+    log_selected_t_policies = outputs['log_selected_prob'] * emasks
 
     # thresholds of importance sampling
     log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
diff --git a/handyrl/worker.py b/handyrl/worker.py
index 0cf47b63..0b2e08a0 100755
--- a/handyrl/worker.py
+++ b/handyrl/worker.py
@@ -20,7 +20,7 @@
 from .connection import connect_socket_connection, accept_socket_connections
 from .evaluation import Evaluator
 from .generation import Generator
-from .model import ModelWrapper, RandomModel
+from .model import ModelWrapper
 
 
 class Worker:
@@ -52,11 +52,6 @@ def _gather_models(self, model_ids):
             else:
                 # get model from server
                 model = pickle.loads(send_recv(self.conn, ('model', model_id)))
-                if model_id == 0:
-                    # use random model
-                    self.env.reset()
-                    obs = self.env.observation(self.env.players()[0])
-                    model = RandomModel(model, obs)
                 model_pool[model_id] = ModelWrapper(model)
                 # update latest model
                 if model_id > self.latest_model[0]:
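
For reference, here is a minimal, self-contained sketch of the pattern this patch moves into the models' forward(): a temperature-scaled log-softmax, optional action sampling, the log-probability of the selected action, and the policy entropy. It is not part of the patch and the function and variable names below are illustrative only; the trainer then consumes the stored behavior log-probabilities directly (log_rhos = log_selected_t_policies.detach() - log_selected_b_policies, as in train.py above).

    import torch
    import torch.nn.functional as F
    import torch.distributions as dist

    def sample_or_evaluate(logits, action=None, temperature=1.0):
        # Temperature-scaled policy; log_softmax keeps the computation numerically stable.
        log_prob = F.log_softmax(logits / temperature, -1)
        entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)

        if action is None:
            # Generation path: sample an action from the scaled policy.
            prob = torch.exp(log_prob)
            action = prob.multinomial(num_samples=1, replacement=True)
        # Training path: evaluate the log-probability of the given (stored) action.
        log_selected_prob = log_prob.gather(-1, action)
        return action, log_selected_prob, entropy

    # Example: near-greedy selection over a batch of Tic-Tac-Toe-sized policies.
    logits = torch.randn(4, 9)
    action, log_p, ent = sample_or_evaluate(logits, temperature=1e-6)

Unlike the previous masked-softmax sampling, nothing in this path prevents an illegal action from being drawn; the patch compensates by treating an illegal move as an immediate loss in the environments' play() methods.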