Skip to content

Commit

Permalink
Merge pull request #345 from YuriCat/fix/reward_return_fillna
Browse files Browse the repository at this point in the history
fix: fill 0 for reward, return, value in make_batch()
  • Loading branch information
YuriCat committed Sep 22, 2023
2 parents 826776e + 98b595d commit bf8587d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def replace_none(a, b):
obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o))

# datum that is not changed by training configuration
v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
v = np.array([[replace_none(m['value'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
rew = np.array([[replace_none(m['reward'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
ret = np.array([[replace_none(m['return'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)

emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask
Expand Down

0 comments on commit bf8587d

Please sign in to comment.