-
Notifications
You must be signed in to change notification settings - Fork 34
/
gobigger_env.py
54 lines (44 loc) · 1.68 KB
/
gobigger_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import sys
import gym
import time
from gobigger.server import Server
from gobigger.render import EnvRender
import copy
class GoBiggerEnv(gym.Env):
    """Gym-style environment that drives a GoBigger ``Server``.

    Every call to :meth:`step` advances the server ``step_mul`` times and
    reports, for each leaderboard entry, how much its score changed since the
    previous observation.
    """

    def __init__(self, server_cfg=None, step_mul=2, **kwargs):
        """Store the configuration and create the underlying server.

        Args:
            server_cfg: Configuration passed unchanged to ``Server(cfg=...)``.
            step_mul: Number of server steps performed per :meth:`step` call.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        self.server_cfg = server_cfg
        self.step_mul = step_mul
        # Leaderboard scores from the previous observation; filled in by
        # ``reset()`` and refreshed by every ``step()``.
        self.last_total_score = None
        self.init_server()

    def step(self, actions):
        """Advance the server ``step_mul`` times and return the transition.

        ``actions`` is applied on the first server step only; the remaining
        ``step_mul - 1`` server steps are issued with ``actions=None``.

        Args:
            actions: Passed unchanged to ``Server.step`` on the first sub-step.

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``obs`` is
            ``[global_state, player_states]``, ``reward`` is the list of
            per-entry leaderboard score changes since the last observation,
            ``done`` is the value returned by the final ``Server.step`` call,
            and ``info`` is the third element of ``Server.obs()``.

        Raises:
            ValueError: If ``step_mul`` is smaller than 1, because no server
                step would run and there would be no ``done`` value to return.
            RuntimeError: If called before :meth:`reset` has recorded the
                baseline scores.
        """
        sub_steps = range(self.step_mul)
        if not sub_steps:
            raise ValueError(
                "step_mul must be at least 1, got {!r}".format(self.step_mul))
        previous_score = getattr(self, 'last_total_score', None)
        if previous_score is None:
            raise RuntimeError("Please call `reset()` before `step()`")

        for sub_step in sub_steps:
            # Only the first sub-step carries the caller's actions.
            done = self.server.step(actions=actions if sub_step == 0 else None)

        obs, total_score, info = self._observe()
        assert len(previous_score) == len(total_score)
        reward = [score - previous_score[i]
                  for i, score in enumerate(total_score)]
        self.last_total_score = total_score
        return obs, reward, done, info

    def reset(self):
        """Reset the server and return the initial observation.

        Also records the current leaderboard scores as the baseline against
        which the first :meth:`step` computes its rewards.

        Returns:
            ``[global_state, player_states]`` as reported by ``Server.obs()``.
        """
        self.server.reset()
        obs, self.last_total_score, _ = self._observe()
        return obs

    def close(self):
        """Close the underlying server."""
        self.server.close()

    def seed(self, seed):
        """Forward ``seed`` to the underlying server."""
        self.server.seed(seed)

    def get_team_infos(self):
        """Return whatever ``Server.get_team_infos()`` reports."""
        # The server is created by ``init_server()`` (invoked from
        # ``__init__``), not by ``reset()``, so that is what the hint names.
        assert hasattr(self, 'server'), "Please call `init_server()` first"
        return self.server.get_team_infos()

    def init_server(self):
        """Create the ``Server`` instance from ``server_cfg``."""
        self.server = Server(cfg=self.server_cfg)

    def _observe(self):
        """Read the server observation and extract the leaderboard scores.

        Returns:
            A tuple ``(obs, scores, info)`` where ``obs`` is
            ``[global_state, player_states]``, ``scores`` is the list of
            leaderboard values at indices ``0..n-1``, and ``info`` is the
            third element of ``Server.obs()``.
        """
        global_state, player_states, info = self.server.obs()
        leaderboard = global_state['leaderboard']
        # Looked up by index on purpose: this works whether the leaderboard is
        # a sequence or a mapping keyed by 0..n-1, and fixes the score order.
        # NOTE(review): confirm which container type the server returns.
        scores = [leaderboard[i] for i in range(len(leaderboard))]
        return [global_state, player_states], scores, info