(Experiment) (Another Style) feature: generalized policy setting log #254

Open
Wants to merge 38 commits into base: develop

Commits (38)
ecf5112  experiment: generalized policy gradient (YuriCat, Jan 16, 2022)
f33745a  feature: remove action mask (YuriCat, Jan 16, 2022)
c239a43  feature: generalized policy for RNN (YuriCat, Jan 17, 2022)
30960aa  fix: generalized action dimension (YuriCat, Jan 17, 2022)
a4129a6  Merge experiment/generalized_policy (YuriCat, Jan 17, 2022)
cce3666  feature: remove action mask for RNN (YuriCat, Jan 17, 2022)
a53da19  feature: log_selected_prob based training (YuriCat, Jan 17, 2022)
f25b8ed  feature: action dimension (YuriCat, Jan 17, 2022)
61a272f  Merge feature/generalized_policy_setting (YuriCat, Jan 17, 2022)
ea55fe9  fix: padded log_selected_prob (YuriCat, Jan 17, 2022)
f05d45f  fix: prob in agent.py (YuriCat, Jan 17, 2022)
0487fa2  Merge branch 'feature/generalized_policy_setting' into feature/genera… (YuriCat, Jan 17, 2022)
d448cb5  fix: duplicate substitution (YuriCat, Jan 21, 2022)
93068d1  Merge experiment/generalized_policy (YuriCat, Jan 21, 2022)
9279a02  Merge branch 'feature/generalized_policy_setting' into feature/genera… (YuriCat, Jan 21, 2022)
3f47985  Merge develop (YuriCat, Jan 25, 2022)
bd9902f  Merge experiment/generalized_policy (YuriCat, Jan 25, 2022)
691bbfb  Merge feature/generalized_policy_setting (YuriCat, Jan 25, 2022)
4d5b38e  fix: small codefix (YuriCat, Jan 25, 2022)
e9b0589  Merge experiment/generalized_policy (YuriCat, Jan 25, 2022)
6e61054  Merge branch 'feature/generalized_policy_setting' into feature/genera… (YuriCat, Jan 25, 2022)
42298f7  Merge develop (YuriCat, Jan 27, 2022)
fb2a0b2  Merge experiment/generalized_policy (YuriCat, Jan 29, 2022)
f515d24  Merge feature/generalized_policy_setting (YuriCat, Jan 29, 2022)
8084510  Merge develop (YuriCat, Jan 31, 2022)
22cb711  Merge develop (YuriCat, Jan 31, 2022)
15b7789  Merge experiment/generalized_policy (YuriCat, Jan 31, 2022)
e92e4d5  fix: there is no action mask (YuriCat, Jan 31, 2022)
97c3694  Merge feature/generalized_policy_setting (YuriCat, Jan 31, 2022)
3bc5f86  chore: remove unused imports from model.py (YuriCat, Feb 7, 2022)
7018da4  Merge branch 'feature/generalized_policy_setting' into feature/genera… (YuriCat, Feb 7, 2022)
e7e30a9  Merge branch 'develop' into experiment/generalized_policy (YuriCat, Feb 11, 2022)
05ad675  Merge develo (YuriCat, Mar 23, 2022)
656bdfb  Merge develop (YuriCat, Mar 23, 2022)
4b88f3c  Merge develop (YuriCat, Mar 23, 2022)
9cfbb46  Merge develop (YuriCat, Apr 27, 2022)
88491d4  Merge branch 'experiment/generalized_policy' into feature/generalized… (YuriCat, Apr 27, 2022)
171c3ea  Merge branch 'feature/generalized_policy_setting' into feature/genera… (YuriCat, Apr 27, 2022)

28 changes: 11 additions & 17 deletions handyrl/agent.py
@@ -33,18 +33,18 @@ def action(self, env, player, show=False):
return random.choice(env.legal_actions(player))


def print_outputs(env, prob, v):
def print_outputs(env, action, prob, v):
if hasattr(env, 'print_outputs'):
env.print_outputs(prob, v)
env.print_outputs(action, prob, v)
else:
if v is not None:
print('v = %f' % v)
if prob is not None:
print('p = %s' % (prob * 1000).astype(int))
if action is not None:
print('a = %d prob = %f' % (action, prob))


class Agent:
def __init__(self, model, temperature=0.0, observation=True):
def __init__(self, model, temperature=1e-6, observation=True):
# model might be a neural net, or some planning algorithm such as game tree search
self.model = model
self.hidden = None
@@ -55,28 +55,22 @@ def reset(self, env, show=False):
self.hidden = self.model.init_hidden()

def plan(self, obs):
outputs = self.model.inference(obs, self.hidden)
outputs = self.model.inference(obs, self.hidden, temperature=self.temperature)
self.hidden = outputs.pop('hidden', None)
return outputs

def action(self, env, player, show=False):
obs = env.observation(player)
outputs = self.plan(obs)
actions = env.legal_actions(player)
p = outputs['policy']

action = outputs['action']
prob = np.exp(outputs['log_selected_prob'])
v = outputs.get('value', None)
mask = np.ones_like(p)
mask[actions] = 0
p = p - mask * 1e32

if show:
print_outputs(env, softmax(p), v)
print_outputs(env, action, prob, v)

if self.temperature == 0:
ap_list = sorted([(a, p[a]) for a in actions], key=lambda x: -x[1])
return ap_list[0][0]
else:
return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]
return action

def observe(self, env, player, show=False):
v = None
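With the generalized setting, Agent.action simply returns the action sampled inside the model and reports its probability via exp(log_selected_prob); the old temperature == 0 greedy branch is replaced by a near-zero default temperature (1e-6). A minimal NumPy sketch (standalone, illustrative names, not HandyRL code) of why dividing the logits by a tiny temperature behaves like argmax:

import numpy as np

def softmax_with_temperature(logits, temperature):
    z = logits / temperature
    z = z - z.max()          # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = np.array([1.0, 2.5, 0.3])
print(softmax_with_temperature(logits, 1.0))   # soft distribution over all actions
print(softmax_with_temperature(logits, 1e-6))  # ~one-hot on the best action (index 1)
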
18 changes: 16 additions & 2 deletions handyrl/envs/geister.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

from ..environment import BaseEnvironment

@@ -147,7 +148,7 @@ def __init__(self):
def init_hidden(self, batch_size=[]):
return self.body.init_hidden(self.input_size[1:], batch_size)

def forward(self, x, hidden):
def forward(self, x, hidden, action=None, temperature=1.0):
b, s = x['board'], x['scalar']
h_s = s.view(*s.size(), 1, 1).repeat(1, 1, 6, 6)
h = torch.cat([h_s, b], -3)
@@ -163,7 +164,16 @@ def forward(self, x, hidden):
h_v = self.head_v(h)
h_r = self.head_r(h)

return {'policy': h_p, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden}
log_prob = F.log_softmax(h_p / temperature, -1)
prob = torch.exp(log_prob)
entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)

if action is None:
prob = torch.exp(log_prob)
action = prob.multinomial(num_samples=1, replacement=True)
log_selected_prob = log_prob.gather(-1, action)

return {'action': action, 'log_selected_prob': log_selected_prob, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden, 'entropy': entropy}


class Environment(BaseEnvironment):
@@ -357,6 +367,10 @@ def _set(self, layout):

def play(self, action, _=None):
# state transition
if not self.legal(action):
self.win_color = self.opponent(self.color)
return

if self.turn_count < 0:
layout = action - 4 * 6 * 6
return self._set(layout)
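Both example environments now implement the same generalized policy head: the network applies a temperature to its logits, samples an action when none is given (generation), or scores the supplied action (training), and additionally returns the entropy used for regularization. A self-contained sketch of that pattern (assumed shapes, not the actual geister model):

import torch
import torch.nn.functional as F
import torch.distributions as dist

def generalized_policy_head(logits, action=None, temperature=1.0):
    log_prob = F.log_softmax(logits / temperature, -1)
    entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)
    if action is None:
        # generation: sample an action from the temperature-adjusted policy
        action = torch.exp(log_prob).multinomial(num_samples=1, replacement=True)
    # training: look up the log-probability of the chosen (or supplied) action
    log_selected_prob = log_prob.gather(-1, action)
    return action, log_selected_prob, entropy

logits = torch.randn(2, 4 * 6 * 6)   # batch of 2, geister-sized move space
action, log_selected_prob, entropy = generalized_policy_head(logits)
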
18 changes: 15 additions & 3 deletions handyrl/envs/tictactoe.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

from ..environment import BaseEnvironment

@@ -59,14 +60,22 @@ def __init__(self):
self.head_p = Head((filters, 3, 3), 2, 9)
self.head_v = Head((filters, 3, 3), 1, 1)

def forward(self, x, hidden=None):
def forward(self, x, hidden=None, action=None, temperature=1.0):
h = F.relu(self.conv(x))
for block in self.blocks:
h = F.relu(block(h))
h_p = self.head_p(h)
h_v = self.head_v(h)

return {'policy': h_p, 'value': torch.tanh(h_v)}
log_prob = F.log_softmax(h_p / temperature, -1)
entropy = dist.Categorical(logits=log_prob).entropy().unsqueeze(-1)

if action is None:
prob = torch.exp(log_prob)
action = prob.multinomial(num_samples=1, replacement=True)
log_selected_prob = log_prob.gather(-1, action)

return {'action': action, 'log_selected_prob': log_selected_prob, 'value': torch.tanh(h_v), 'entropy': entropy}


class Environment(BaseEnvironment):
@@ -104,7 +113,10 @@ def play(self, action, _=None):
# state transition function
# action is integer (0 ~ 8)
x, y = action // 3, action % 3
self.board[x, y] = self.color
if self.board[x, y] != 0: # illegal action
self.win_color = -self.color
else:
self.board[x, y] = self.color

# check winning condition
win = self.board[x, :].sum() == 3 * self.color \
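Because the action mask has been removed, illegal actions can actually be sampled, so the environments resolve them themselves: an illegal move immediately hands the win to the opponent instead of being filtered out beforehand. A tiny illustrative sketch of that convention (not the real Environment class):

import numpy as np

class TinyBoard:
    def __init__(self):
        self.board = np.zeros((3, 3), dtype=np.int64)
        self.color = 1       # 1 = first player, -1 = second player
        self.win_color = 0   # 0 means the game is still running

    def play(self, action):
        x, y = action // 3, action % 3
        if self.board[x, y] != 0:         # illegal: the square is already occupied
            self.win_color = -self.color  # the opponent wins immediately
            return
        self.board[x, y] = self.color
        self.color = -self.color

env = TinyBoard()
env.play(4)          # legal move by the first player
env.play(4)          # illegal move by the second player
assert env.win_color == 1
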
15 changes: 4 additions & 11 deletions handyrl/generation.py
@@ -29,7 +29,7 @@ def generate(self, models, args):
return None

while not self.env.terminal():
moment_keys = ['observation', 'selected_prob', 'action_mask', 'action', 'value', 'reward', 'return']
moment_keys = ['observation', 'log_selected_prob', 'action', 'value', 'reward', 'return']
moment = {key: {p: None for p in self.env.players()} for key in moment_keys}

turn_players = self.env.turns()
@@ -42,6 +42,7 @@

obs = self.env.observation(player)
model = models[player]

outputs = model.inference(obs, hidden[player])
hidden[player] = outputs.get('hidden', None)
v = outputs.get('value', None)
@@ -50,16 +51,8 @@
moment['value'][player] = v

if player in turn_players:
p_ = outputs['policy']
legal_actions = self.env.legal_actions(player)
action_mask = np.ones_like(p_) * 1e32
action_mask[legal_actions] = 0
p = softmax(p_ - action_mask)
action = random.choices(legal_actions, weights=p[legal_actions])[0]

moment['selected_prob'][player] = p[action]
moment['action_mask'][player] = action_mask
moment['action'][player] = action
moment['action'][player] = outputs['action']
moment['log_selected_prob'][player] = outputs['log_selected_prob']

err = self.env.step(moment['action'])
if err:
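The generator no longer rebuilds a masked softmax on its side; it stores exactly what the model emitted, so the behaviour policy is represented only by the sampled action and its log-probability. A sketch (illustrative values and shapes) of one per-step record in a two-player game where only player 0 acts:

import numpy as np

moment = {
    'observation':       {0: np.zeros((3, 3, 3), dtype=np.float32), 1: None},
    'action':            {0: np.array([4]), 1: None},       # sampled inside the model
    'log_selected_prob': {0: np.array([-1.21]), 1: None},   # log pi_b(a | s), kept for off-policy correction
    'value':             {0: np.array([0.1]), 1: None},
    'reward':            {0: None, 1: None},
    'return':            {0: None, 1: None},
}
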
14 changes: 0 additions & 14 deletions handyrl/model.py
@@ -58,17 +58,3 @@ def inference(self, x, hidden, **kwargs):
ht = map_r(hidden, lambda h: torch.from_numpy(np.array(h)).contiguous().unsqueeze(0) if h is not None else None)
outputs = self.forward(xt, ht, **kwargs)
return map_r(outputs, lambda o: o.detach().numpy().squeeze(0) if o is not None else None)


# simple model

class RandomModel(nn.Module):
def __init__(self, model, x):
super().__init__()
wrapped_model = ModelWrapper(model)
hidden = wrapped_model.init_hidden()
outputs = wrapped_model.inference(x, hidden)
self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items() if key != 'hidden'}

def inference(self, *args):
return self.output_dict
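ModelWrapper.inference forwards its keyword arguments straight to forward(), which is how Agent.plan can now pass temperature (and training can pass action) without changing the wrapper itself. A simplified sketch of that forwarding, omitting hidden-state handling and map_r (assumed names, not the full wrapper):

import numpy as np
import torch
import torch.nn as nn

class MinimalWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def inference(self, x, hidden=None, **kwargs):
        xt = torch.from_numpy(np.asarray(x)).float().unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            outputs = self.model(xt, hidden, **kwargs)             # e.g. temperature=1e-6
        return {k: v.detach().numpy().squeeze(0)
                for k, v in outputs.items() if v is not None and k != 'hidden'}
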
46 changes: 22 additions & 24 deletions handyrl/train.py
@@ -18,7 +18,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import torch.optim as optim
import psutil

@@ -59,19 +58,18 @@ def replace_none(a, b):

# template for padding
obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))
amask_zeros = np.zeros_like(moments[0]['action_mask'][moments[0]['turn'][0]])
log_prob_zeros = np.zeros_like(moments[0]['log_selected_prob'][moments[0]['turn'][0]])
act_zeros = np.zeros_like(moments[0]['action'][moments[0]['turn'][0]])

# data that is changed by training configuration
if args['turn_based_training'] and not args['observation']:
obs = [[m['observation'][m['turn'][0]]] for m in moments]
prob = np.array([[[m['selected_prob'][m['turn'][0]]]] for m in moments])
act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis]
amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments])
log_prob = np.array([[m['log_selected_prob'][m['turn'][0]]] for m in moments])
act = np.array([[m['action'][m['turn'][0]]] for m in moments])
else:
obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments]
prob = np.array([[[replace_none(m['selected_prob'][player], 1.0)] for player in players] for m in moments])
act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis]
amask = np.array([[replace_none(m['action_mask'][player], amask_zeros + 1e32) for player in players] for m in moments])
log_prob = np.array([[replace_none(m['log_selected_prob'][player], log_prob_zeros) for player in players] for m in moments])
act = np.array([[replace_none(m['action'][player], act_zeros) for player in players] for m in moments])

# reshape observation
obs = rotate(rotate(obs)) # (T, P, ..., ...) -> (P, ..., T, ...) -> (..., T, P, ...)
@@ -84,7 +82,7 @@ def replace_none(a, b):
oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)

emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask
tmask = np.array([[[m['selected_prob'][player] is not None] for player in players] for m in moments], dtype=np.float32)
tmask = np.array([[[m['log_selected_prob'][player] is not None] for player in players] for m in moments], dtype=np.float32)
omask = np.array([[[m['observation'][player] is not None] for player in players] for m in moments], dtype=np.float32)

progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total']
@@ -95,32 +93,30 @@ def replace_none(a, b):
pad_len_b = args['burn_in_steps'] - (ep['train_start'] - ep['start'])
pad_len_a = batch_steps - len(tmask) - pad_len_b
obs = map_r(obs, lambda o: np.pad(o, [(pad_len_b, pad_len_a)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0))
prob = np.pad(prob, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1)
log_prob = np.pad(log_prob, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
v = np.concatenate([np.pad(v, [(pad_len_b, 0), (0, 0), (0, 0)], 'constant', constant_values=0), np.tile(oc, [pad_len_a, 1, 1])])
act = np.pad(act, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
rew = np.pad(rew, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
ret = np.pad(ret, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
emask = np.pad(emask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
tmask = np.pad(tmask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
omask = np.pad(omask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
amask = np.pad(amask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1e32)
progress = np.pad(progress, [(pad_len_b, pad_len_a), (0, 0)], 'constant', constant_values=1)

obss.append(obs)
datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))
datum.append((log_prob, v, act, oc, rew, ret, emask, tmask, omask, progress))

obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
log_prob, v, act, oc, rew, ret, emask, tmask, omask, progress = [to_torch(np.array(val)) for val in zip(*datum)]

return {
'observation': obs,
'selected_prob': prob,
'log_selected_prob': log_prob,
'value': v,
'action': act, 'outcome': oc,
'reward': rew, 'return': ret,
'episode_mask': emask,
'turn_mask': tmask, 'observation_mask': omask,
'action_mask': amask,
'progress': progress,
}

@@ -143,13 +139,15 @@ def forward_prediction(model, hidden, batch, args):
if hidden is None:
# feed-forward neural network
obs = map_r(observations, lambda o: o.flatten(0, 2)) # (..., B * T * P or 1, ...)
outputs = model(obs, None)
action = batch['action'].flatten(0, 2)
outputs = model(obs, None, action=action)
outputs = map_r(outputs, lambda o: o.unflatten(0, batch_shape)) # (..., B, T, P or 1, ...)
else:
# sequential computation with RNN
outputs = {}
for t in range(batch_shape[1]):
obs = map_r(observations, lambda o: o[:, t].flatten(0, 1)) # (..., B * P or 1, ...)
action = batch['action'][:, t].flatten(0, 1)
omask_ = batch['observation_mask'][:, t]
omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (h.dim() - 2))))
hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...)
@@ -160,11 +158,11 @@ def forward_prediction(model, hidden, batch, args):
if t < args['burn_in_steps']:
model.eval()
with torch.no_grad():
outputs_ = model(obs, hidden_)
outputs_ = model(obs, hidden_, action=action)
else:
if not model.training:
model.train()
outputs_ = model(obs, hidden_)
outputs_ = model(obs, hidden_, action=action)
outputs_ = map_r(outputs_, lambda o: o.unflatten(0, (batch_shape[0], batch_shape[2]))) # (..., B, P or 1, ...)
for k, o in outputs_.items():
if k == 'hidden':
@@ -175,11 +173,11 @@ def forward_prediction(model, hidden, batch, args):
outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None}

for k, o in outputs.items():
if k == 'policy':
if k in ['action', 'log_selected_prob', 'entropy']:
o = o.mul(batch['turn_mask'])
if o.size(2) > 1 and batch_shape[2] == 1: # turn-alternating batch
o = o.sum(2, keepdim=True) # gather turn player's policies
outputs[k] = o - batch['action_mask']
outputs[k] = o
else:
# mask valid target values and cumulative rewards
outputs[k] = o.mul(batch['observation_mask'])
@@ -206,11 +204,11 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
if 'return' in outputs:
losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum()

entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1))
entropy = outputs['entropy'].mul(tmasks)
losses['ent'] = entropy.sum()

base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
entropy_loss = entropy.mul(1 - batch['progress'].unsqueeze(-2) * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
losses['total'] = base_loss + entropy_loss

return losses, dcnt
@@ -226,8 +224,8 @@ def compute_loss(batch, model, hidden, args):
emasks = batch['episode_mask']
clip_rho_threshold, clip_c_threshold = 1.0, 1.0

log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
log_selected_b_policies = batch['log_selected_prob'] * emasks
log_selected_t_policies = outputs['log_selected_prob'] * emasks

# thresholds of importance sampling
log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
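With both the behaviour and target log-probabilities available directly, the importance-sampling ratios for the off-policy corrections are formed by a subtraction and an exponentiation, then clipped, instead of re-normalising a masked policy head and clamping tiny probabilities. A standalone sketch of that step (toy tensors; the thresholds match the 1.0 values in train.py):

import torch

log_pi_b = torch.tensor([[-1.2], [-0.7]])   # stored at generation time (behaviour policy)
log_pi_t = torch.tensor([[-1.0], [-0.9]])   # recomputed by the training model (target policy)
clip_rho_threshold, clip_c_threshold = 1.0, 1.0

log_rhos = log_pi_t.detach() - log_pi_b     # log importance ratios
rhos = torch.exp(log_rhos)
clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)  # rho-bar for value targets
cs = torch.clamp(rhos, 0, clip_c_threshold)              # c-bar for trace propagation
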
7 changes: 1 addition & 6 deletions handyrl/worker.py
@@ -20,7 +20,7 @@
from .connection import connect_socket_connection, accept_socket_connections
from .evaluation import Evaluator
from .generation import Generator
from .model import ModelWrapper, RandomModel
from .model import ModelWrapper


class Worker:
@@ -52,11 +52,6 @@ def _gather_models(self, model_ids):
else:
# get model from server
model = pickle.loads(send_recv(self.conn, ('model', model_id)))
if model_id == 0:
# use random model
self.env.reset()
obs = self.env.observation(self.env.players()[0])
model = RandomModel(model, obs)
model_pool[model_id] = ModelWrapper(model)
# update latest model
if model_id > self.latest_model[0]: