From ce6709fc81548516a2fc99c6ad2c66241308753f Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Fri, 19 Nov 2021 00:03:16 +0900
Subject: [PATCH 1/8] experiment: multi-dimensional reward

---
 handyrl/environment.py    |  6 ----
 handyrl/envs/geister.py   | 13 +++++----
 handyrl/envs/tictactoe.py | 11 ++++----
 handyrl/evaluation.py     | 32 +++++++++++++--------
 handyrl/generation.py     |  7 +++--
 handyrl/train.py          | 58 +++++++++++++++++++--------------------
 6 files changed, 65 insertions(+), 62 deletions(-)

diff --git a/handyrl/environment.py b/handyrl/environment.py
index f470e816..5e1ba522 100755
--- a/handyrl/environment.py
+++ b/handyrl/environment.py
@@ -89,12 +89,6 @@ def terminal(self):
     def reward(self):
         return {}
 
-    #
-    # Should be defined in all games
-    #
-    def outcome(self):
-        raise NotImplementedError()
-
     #
     # Should be defined in all games
     #
diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py
index 86741444..9c3104c5 100755
--- a/handyrl/envs/geister.py
+++ b/handyrl/envs/geister.py
@@ -163,8 +163,9 @@ def forward(self, x, hidden):
         h_p = torch.cat([h_p_move, h_p_set], -1)
         h_v = self.head_v(h)
         h_r = self.head_r(h)
+        h_v = torch.cat([torch.tanh(h_v), h_r], -1)
 
-        return {'policy': h_p, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden}
+        return {'policy': h_p, 'value': h_v, 'hidden': hidden}
 
 
 class Environment(BaseEnvironment):
@@ -432,17 +433,17 @@ def terminal(self):
         return self.win_color is not None
 
     def reward(self):
-        # return immediate rewards
-        return {p: -0.01 for p in self.players()}
+        # immediate rewards
+        turn_rewards = [-0.01, -0.01]
 
-    def outcome(self):
-        # return terminal outcomes
+        # terminal reward
         outcomes = [0, 0]
         if self.win_color == self.BLACK:
             outcomes = [1, -1]
         elif self.win_color == self.WHITE:
             outcomes = [-1, 1]
-        return {p: outcomes[idx] for idx, p in enumerate(self.players())}
+
+        return {p: [outcomes[idx], turn_rewards[idx]] for idx, p in enumerate(self.players())}
 
     def legal(self, action):
         if self.turn_count < 0:
diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py
index 8ad950d2..65f84e18 100755
--- a/handyrl/envs/tictactoe.py
+++ b/handyrl/envs/tictactoe.py
@@ -137,14 +137,13 @@ def terminal(self):
         # check whether the state is terminal
         return self.win_color != 0 or len(self.record) == 3 * 3
 
-    def outcome(self):
-        # terminal outcome
-        outcomes = [0, 0]
+    def reward(self):
+        rewards = [0, 0]
         if self.win_color > 0:
-            outcomes = [1, -1]
+            rewards = [1, -1]
         if self.win_color < 0:
-            outcomes = [-1, 1]
-        return {p: outcomes[idx] for idx, p in enumerate(self.players())}
+            rewards = [-1, 1]
+        return {p: rewards[idx] for idx, p in enumerate(self.players())}
 
     def legal_actions(self, _=None):
         # legal action list
diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py
index 51d3c6c5..dedf8463 100755
--- a/handyrl/evaluation.py
+++ b/handyrl/evaluation.py
@@ -7,6 +7,8 @@
 import time
 import multiprocessing as mp
 
+import numpy as np
+
 from .environment import prepare_env, make_env
 from .connection import send_recv, accept_socket_connections, connect_socket_connection
 from .agent import RandomAgent, RuleBasedAgent, Agent, EnsembleAgent, SoftAgent
@@ -40,8 +42,8 @@ def run(self):
             command, args = self.conn.recv()
             if command == 'quit':
                 break
-            elif command == 'outcome':
-                print('outcome = %f' % args[0])
+            elif command == 'reward':
+                print('reward = %s' % args[0])
             elif hasattr(self.agent, command):
                 if command == 'action' or command == 'observe':
                     view(self.env)
@@ -80,8 +82,10 @@ def exec_match(env, agents, critic, show=False, game_args={}):
     ''' match with shared game environment '''
     if env.reset(game_args):
         return None
-    for agent in agents.values():
+    total_rewards = {}
+    for p, agent in agents.items():
         agent.reset(env, show=show)
+        total_rewards[p] = 0
     while not env.terminal():
         if show:
             view(env)
@@ -98,19 +102,22 @@ def exec_match(env, agents, critic, show=False, game_args={}):
             return None
         if show:
             view_transition(env)
-    outcome = env.outcome()
+        for p, reward in env.reward().items():
+            total_rewards[p] += np.array(reward).reshape(-1)
     if show:
-        print('final outcome = %s' % outcome)
-    return outcome
+        print('total rewards = %s' % total_rewards)
+    return total_rewards
 
 
 def exec_network_match(env, network_agents, critic, show=False, game_args={}):
     ''' match with divided game environment '''
     if env.reset(game_args):
         return None
+    total_rewards = {}
     for p, agent in network_agents.items():
         info = env.diff_info(p)
         agent.update(info, True)
+        total_rewards[p] = 0
     while not env.terminal():
         if show:
             view(env)
@@ -126,13 +133,14 @@ def exec_network_match(env, network_agents, critic, show=False, game_args={}):
                 agent.observe(p)
         if env.step(actions):
             return None
+        for p, reward in env.reward().items():
+            total_rewards[p] += np.array(reward).reshape(-1)
     for p, agent in network_agents.items():
         info = env.diff_info(p)
         agent.update(info, False)
-    outcome = env.outcome()
     for p, agent in network_agents.items():
-        agent.outcome(outcome[p])
-    return outcome
+        agent.reward(total_rewards[p])
+    return total_rewards
 
 
 def build_agent(raw, env):
@@ -163,11 +171,11 @@ def execute(self, models, args):
             else:
                 agents[p] = Agent(model, self.args['observation'])
 
-        outcome = exec_match(self.env, agents, None)
-        if outcome is None:
+        total_rewards = exec_match(self.env, agents, None)
+        if total_rewards is None:
             print('None episode in evaluation!')
             return None
-        return {'args': args, 'result': outcome, 'opponent': opponent}
+        return {'args': args, 'total_reward': total_rewards, 'opponent': opponent}
 
 
 def wp_func(results):
diff --git a/handyrl/generation.py b/handyrl/generation.py
index 63b7e553..b40ea10a 100755
--- a/handyrl/generation.py
+++ b/handyrl/generation.py
@@ -21,8 +21,10 @@ def generate(self, models, args):
         # episode generation
         moments = []
         hidden = {}
+        total_rewards = {}
         for player in self.env.players():
             hidden[player] = models[player].init_hidden()
+            total_rewards[player] = 0
 
         err = self.env.reset()
         if err:
@@ -63,6 +65,7 @@ def generate(self, models, args):
             reward = self.env.reward()
             for player in self.env.players():
                 moment['reward'][player] = reward.get(player, None)
+                total_rewards[player] += np.array(reward.get(player, 0)).reshape(-1)
 
             moment['turn'] = turn_players
             moments.append(moment)
@@ -73,12 +76,12 @@ def generate(self, models, args):
         for player in self.env.players():
             ret = 0
             for i, m in reversed(list(enumerate(moments))):
-                ret = (m['reward'][player] or 0) + self.args['gamma'] * ret
+                ret = np.array(m['reward'][player] or 0) + np.array(self.args['gamma']) * ret
                 moments[i]['return'][player] = ret
 
         episode = {
             'args': args, 'steps': len(moments),
-            'outcome': self.env.outcome(),
+            'total_reward': total_rewards,
             'moment': [
                 bz2.compress(pickle.dumps(moments[i:i+self.args['compress_steps']]))
                 for i in range(0, len(moments), self.args['compress_steps'])
diff --git a/handyrl/train.py b/handyrl/train.py
index 35815031..1fbcec64 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -57,8 +57,10 @@ def replace_none(a, b):
         if not args['turn_based_training']:  # solo training
             players = [random.choice(players)]
 
-        obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))  # template for padding
-        p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]])  # template for padding
+        # template for padding
+        obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))
+        p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]])
+        v_zeros = np.zeros_like(moments[0]['value'][moments[0]['turn'][0]])
 
         # data that is chainge by training configuration
         if args['turn_based_training'] and not args['observation']:
@@ -77,10 +79,9 @@ def replace_none(a, b):
             obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o))
 
         # datum that is not changed by training configuration
-        v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
-        rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
-        ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
-        oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)
+        v = np.array([[replace_none(m['value'][player], v_zeros) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
+        rew = np.array([[replace_none(m['reward'][player], v_zeros) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
+        ret = np.array([[replace_none(m['return'][player], v_zeros) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
 
         emask = np.ones((len(moments), 1, 1), dtype=np.float32)  # episode mask
         tmask = np.array([[[m['policy'][player] is not None] for player in players] for m in moments], dtype=np.float32)
@@ -93,7 +94,7 @@ def replace_none(a, b):
            pad_len = args['forward_steps'] - len(tmask)
            obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0))
            p = np.pad(p, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
-            v = np.concatenate([v, np.tile(oc, [pad_len, 1, 1])])
+            v = np.pad(v, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            act = np.pad(act, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            rew = np.pad(rew, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            ret = np.pad(ret, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
@@ -104,31 +105,32 @@ def replace_none(a, b):
             progress = np.pad(progress, [(0, pad_len), (0, 0)], 'constant', constant_values=1)
 
         obss.append(obs)
-        datum.append((p, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))
+        datum.append((p, v, act, rew, ret, emask, tmask, omask, amask, progress))
 
-    p, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = zip(*datum)
+    p, v, act, rew, ret, emask, tmask, omask, amask, progress = zip(*datum)
 
     obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
     p = to_torch(np.array(p))
     v = to_torch(np.array(v))
     act = to_torch(np.array(act))
-    oc = to_torch(np.array(oc))
     rew = to_torch(np.array(rew))
     ret = to_torch(np.array(ret))
     emask = to_torch(np.array(emask))
     tmask = to_torch(np.array(tmask))
     omask = to_torch(np.array(omask))
     amask = to_torch(np.array(amask))
+    gamma = to_torch(np.array(args['gamma']))
     progress = to_torch(np.array(progress))
 
     return {
         'observation': obs, 'policy': p, 'value': v,
-        'action': act, 'outcome': oc,
+        'action': act,
         'reward': rew, 'return': ret,
         'episode_mask': emask, 'turn_mask': tmask, 'observation_mask': omask,
         'action_mask': amask,
+        'gamma': gamma,
         'progress': progress,
     }
@@ -201,14 +203,12 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     losses['p'] = (-log_selected_policies * turn_advantages).sum()
     if 'value' in outputs:
-        losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
-    if 'return' in outputs:
-        losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum()
+        losses['v'] = F.smooth_l1_loss(outputs['value'], targets['value'], reduction='none').mul(omasks).sum()
 
     entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1))
     losses['ent'] = entropy.sum()
 
-    base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
+    base_loss = losses['p'] + losses.get('r', 0)
     entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
     losses['total'] = base_loss + entropy_loss
@@ -236,21 +236,18 @@ def compute_loss(batch, model, hidden, args):
         if args['turn_based_training'] and values_nograd.size(2) == 2:  # two player zerosum game
             values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2)
             values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8)
-        outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)
+
+        outputs_nograd['value'] = values_nograd * emasks + batch['return'] * (1 - emasks)
 
     # compute targets and advantage
     targets = {}
     advantages = {}
 
-    value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs
-    return_args = outputs_nograd.get('return', None), batch['return'], batch['reward'], args['lambda'], args['gamma'], clipped_rhos, cs
-
+    value_args = outputs_nograd.get('value', None), batch['return'], batch['reward'], args['lambda'], batch['gamma'], clipped_rhos, cs
     targets['value'], advantages['value'] = compute_target(args['value_target'], *value_args)
-    targets['return'], advantages['return'] = compute_target(args['value_target'], *return_args)
 
     if args['policy_target'] != args['value_target']:
         _, advantages['value'] = compute_target(args['policy_target'], *value_args)
-        _, advantages['return'] = compute_target(args['policy_target'], *return_args)
 
     # compute policy advantage
     total_advantages = clipped_rhos * sum(advantages.values())
@@ -294,7 +291,7 @@ def select_episode(self):
             st_block = st // self.args['compress_steps']
             ed_block = (ed - 1) // self.args['compress_steps'] + 1
             ep_minimum = {
-                'args': ep['args'], 'outcome': ep['outcome'],
+                'args': ep['args'],
                 'moment': ep['moment'][st_block:ed_block],
                 'base': st_block * self.args['compress_steps'],
                 'start': st, 'end': ed, 'total': ep['steps'],
@@ -480,9 +477,9 @@ def feed_episodes(self, episodes):
                 continue
             for p in episode['args']['player']:
                 model_id = episode['args']['model_id'][p]
-                outcome = episode['outcome'][p]
+                rewards = episode['total_reward'][p]
                 n, r, r2 = self.generation_results.get(model_id, (0, 0, 0))
-                self.generation_results[model_id] = n + 1, r + outcome, r2 + outcome ** 2
+                self.generation_results[model_id] = n + 1, r + rewards, r2 + rewards ** 2
 
         # store generated episodes
         self.trainer.episodes.extend([e for e in episodes if e is not None])
@@ -505,15 +502,15 @@ def feed_results(self, results):
                 continue
             for p in result['args']['player']:
                 model_id = result['args']['model_id'][p]
-                res = result['result'][p]
+                rewards = result['total_reward'][p]
                 n, r, r2 = self.results.get(model_id, (0, 0, 0))
-                self.results[model_id] = n + 1, r + res, r2 + res ** 2
+                self.results[model_id] = n + 1, r + rewards, r2 + rewards ** 2
 
                 if model_id not in self.results_per_opponent:
                     self.results_per_opponent[model_id] = {}
                 opponent = result['opponent']
                 n, r, r2 = self.results_per_opponent[model_id].get(opponent, (0, 0, 0))
-                self.results_per_opponent[model_id][opponent] = n + 1, r + res, r2 + res ** 2
+                self.results_per_opponent[model_id][opponent] = n + 1, r + rewards, r2 + rewards ** 2
 
     def update(self):
         # call update to every component
@@ -526,8 +523,9 @@ def update(self):
         def output_wp(name, results):
             n, r, r2 = results
             mean = r / (n + 1e-6)
+            std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5
             name_tag = ' (%s)' % name if name != '' else ''
-            print('win rate%s = %.3f (%.1f / %d)' % (name_tag, (mean + 1) / 2, (r + n) / 2, n))
+            print('eval reward%s = %.3f +- %.3f (/ %d)' % (name_tag, mean[0], std[0], n))
 
         if len(self.args.get('eval', {}).get('opponent', [])) <= 1:
             output_wp('', self.results[self.model_epoch])
@@ -537,12 +535,12 @@ def output_wp(name, results):
                 output_wp(key, self.results_per_opponent[self.model_epoch][key])
 
         if self.model_epoch not in self.generation_results:
-            print('generation stats = Nan (0)')
+            print('gen reward = Nan (0)')
         else:
             n, r, r2 = self.generation_results[self.model_epoch]
             mean = r / (n + 1e-6)
             std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5
-            print('generation stats = %.3f +- %.3f' % (mean, std))
+            print('gen reward = %.3f +- %.3f (/ %d)' % (mean[0], std[0], n))
 
         model, steps = self.trainer.update()
         if model is None:

From 4a2ab4baa561681fcf3ba162a03ee5b7cd971b96 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Wed, 24 Nov 2021 00:10:13 +0900
Subject: [PATCH 2/8] fix: train value loss

---
 handyrl/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/handyrl/train.py b/handyrl/train.py
index 1fbcec64..5597b65a 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -208,7 +208,7 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1))
     losses['ent'] = entropy.sum()
 
-    base_loss = losses['p'] + losses.get('r', 0)
+    base_loss = losses['p'] + losses.get('v', 0)
     entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
     losses['total'] = base_loss + entropy_loss

From 380d4e2b7ec8465c61b3f3aaa0f7957cf1b69019 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 17:26:10 +0900
Subject: [PATCH 3/8] fix: environment sample tictactoe

---
 handyrl/envs/tictactoe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py
index 915a4d1d..3c6861ba 100755
--- a/handyrl/envs/tictactoe.py
+++ b/handyrl/envs/tictactoe.py
@@ -170,11 +170,14 @@ def observation(self, player=None):
 if __name__ == '__main__':
     e = Environment()
     for _ in range(100):
+        total_rewards = {}
         e.reset()
         while not e.terminal():
             print(e)
             actions = e.legal_actions()
             print([e.action2str(a) for a in actions])
             e.play(random.choice(actions))
+            for p, r in e.reward().items():
+                total_rewards[p] = total_rewards.get(p, 0) + r
         print(e)
-        print(e.outcome())
+        print(total_rewards)

From 074cdf071362f3505778a7f990f7fd093f2d94f7 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 17:28:14 +0900
Subject: [PATCH 4/8] fix: remove outcome() from test

---
 tests/test_environment.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_environment.py b/tests/test_environment.py
index 14ff9f64..9e90ad4b 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -47,7 +47,6 @@ def test_environment_local(environment_path, env):
                     actions[player] = random.choice(e.legal_actions(player))
                 e.step(actions)
                 e.reward()
-            e.outcome()
         no_error_loop = True
     except Exception:
         traceback.print_exc()
@@ -81,7 +80,6 @@ def test_environment_network(environment_path, env):
                 info = e.diff_info(p)
                 e_.update(info, False)
             e.reward()
-            e.outcome()
         no_error_loop = True
     except Exception:
         traceback.print_exc()

From b3ae663f5396872c581bb62db867776eeaabfaed Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 17:34:15 +0900
Subject: [PATCH 5/8] fix: parallel tic-tac-toe environment sample for md-reward style

---
 handyrl/envs/parallel_tictactoe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/handyrl/envs/parallel_tictactoe.py b/handyrl/envs/parallel_tictactoe.py
index b70e3aea..60a5941a 100755
--- a/handyrl/envs/parallel_tictactoe.py
+++ b/handyrl/envs/parallel_tictactoe.py
@@ -61,6 +61,7 @@ def turns(self):
 if __name__ == '__main__':
     e = Environment()
     for _ in range(100):
+        total_rewards = {}
         e.reset()
         while not e.terminal():
             print(e)
@@ -70,5 +71,7 @@ def turns(self):
                 print([e.action2str(a) for a in actions])
                 action_map[p] = random.choice(actions)
             e.step(action_map)
+            for p, r in e.reward().items():
+                total_rewards[p] = total_rewards.get(p, 0) + r
         print(e)
-        print(e.outcome())
+        print(total_rewards)

From c4500b1ec95dea8104866a692c3ae4045d91b7b2 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 17:39:28 +0900
Subject: [PATCH 6/8] chore: stop calling terminal_rewards 'outcome' in Geister environment

---
 handyrl/envs/geister.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py
index 8a83975c..54b322c6 100755
--- a/handyrl/envs/geister.py
+++ b/handyrl/envs/geister.py
@@ -436,13 +436,13 @@ def reward(self):
         turn_rewards = [-0.01, -0.01]
 
         # terminal reward
-        outcomes = [0, 0]
+        terminal_rewards = [0, 0]
         if self.win_color == self.BLACK:
-            outcomes = [1, -1]
+            terminal_rewards = [1, -1]
         elif self.win_color == self.WHITE:
-            outcomes = [-1, 1]
+            terminal_rewards = [-1, 1]
 
-        return {p: [outcomes[idx], turn_rewards[idx]] for idx, p in enumerate(self.players())}
+        return {p: [terminal_rewards[idx], turn_rewards[idx]] for idx, p in enumerate(self.players())}
 
     def legal(self, action):
         if self.turn_count < 0:

From 3866c7433e52ec19c0c315c179609e0360f3eeb2 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 17:42:48 +0900
Subject: [PATCH 7/8] fix: geister environment sample for md-reward setting

---
 handyrl/envs/geister.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py
index 54b322c6..aee01f3b 100755
--- a/handyrl/envs/geister.py
+++ b/handyrl/envs/geister.py
@@ -542,11 +542,14 @@ def net(self):
 if __name__ == '__main__':
     e = Environment()
     for _ in range(100):
+        total_rewards = {}
         e.reset()
         while not e.terminal():
             print(e)
             actions = e.legal_actions()
             print([e.action2str(a, e.turn()) for a in actions])
             e.play(random.choice(actions))
+            for p, r in e.reward().items():
+                total_rewards[p] = total_rewards.get(p, 0) + np.array(r)
         print(e)
-        print(e.outcome())
+        print(total_rewards)

From 461993db71526a7bf262c9e1c68a71bbbda77fef Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Sun, 13 Feb 2022 18:01:09 +0900
Subject: [PATCH 8/8] fix: gamma outside of batches

---
 handyrl/train.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/handyrl/train.py b/handyrl/train.py
index 40523cc0..4a08f1ca 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -110,7 +110,6 @@ def replace_none(a, b):
 
     obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
     prob, v, act, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
-    gamma = to_torch(np.array(args['gamma']))
 
     return {
         'observation': obs,
@@ -120,7 +119,6 @@ def replace_none(a, b):
         'episode_mask': emask, 'turn_mask': tmask, 'observation_mask': omask,
         'action_mask': amask,
-        'gamma': gamma,
         'progress': progress,
     }
@@ -247,7 +245,8 @@ def compute_loss(batch, model, hidden, args):
     targets = {}
     advantages = {}
 
-    value_args = outputs_nograd.get('value', None), batch['return'], batch['reward'], args['lambda'], batch['gamma'], clipped_rhos, cs
+    gamma = torch.from_numpy(np.array(args['gamma'], dtype=np.float32)).to(batch['reward'].device)
+    value_args = outputs_nograd.get('value', None), batch['return'], batch['reward'], args['lambda'], gamma, clipped_rhos, cs
     targets['value'], advantages['value'] = compute_target(args['value_target'], *value_args)
 
     if args['policy_target'] != args['value_target']: