# train.py
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import tensorboard
import random

import wimblepong
from policy import Policy
from agent import Agent

# Learning parameters
GAMMA = 0.99
EPS_PER_ITERATION = 10
EPOCHS = 5
MAX_TIMESTEPS = 500
LEARNING_RATE = 1e-4
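# Each iteration collects EPS_PER_ITERATION episodes (capped at MAX_TIMESTEPS
# steps each), then reuses that batch for EPOCHS optimisation passes; GAMMA is
# the discount factor applied in compute_advantages below.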

def collect_data():
    global global_n
    state_history, action_history, action_prob_history, reward_history = [], [], [], []
    for _ in range(EPS_PER_ITERATION):
        obs, opp_obs = env.reset()
        prev_obs, prev2_obs, prev3_obs = None, None, None
        for t in range(MAX_TIMESTEPS):
            # Build the state from the current frame and the three previous
            # ones (pre_process is defined in policy.py), then shift the
            # frame buffer forward by one step.
            state = policy.pre_process(obs, prev_obs, prev2_obs, prev3_obs)
            prev_obs, prev2_obs, prev3_obs = obs, prev_obs, prev2_obs
            with torch.no_grad():
                action, action_prob = policy.get_action(state)
            opp_act = opponent.get_action()
            # opp_act = opponent.get_action(opp_obs)
            (obs, opp_obs), (reward, _), done, _ = env.step((action, opp_act))
            state_history.append(state)
            action_history.append(action)
            action_prob_history.append(action_prob)
            reward_history.append(reward)
            if done:
                break
        # The episode ran for t + 1 steps: log its total return and length.
        writer.add_scalar('reward/episode_reward', sum(reward_history[-(t + 1):]), global_n)
        writer.add_scalar('reward/episode_length', t + 1, global_n)
        global_n += 1
    return [state_history, action_history, action_prob_history, reward_history]

def compute_advantages(reward_history):
    R = 0
    discounted_rewards = torch.zeros(len(reward_history))
    # Walk backwards through the rewards, accumulating discounted returns.
    for i, r in enumerate(reward_history[::-1]):
        # A non-zero reward means a point was scored and the rally restarts,
        # so reset the accumulated return at that boundary.
        if r != 0:
            R = 0
        R = r + GAMMA * R
        discounted_rewards[-(i + 1)] = R
    # Normalise to zero mean and unit variance for training stability.
    discounted_rewards -= discounted_rewards.mean()
    discounted_rewards /= discounted_rewards.std() + 1.0e-10
    return discounted_rewards
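
# Worked example (a sketch, using GAMMA = 0.99): the rewards [0, 0, 1] produce
# raw discounted returns [0.99**2, 0.99, 1.0] = [0.9801, 0.99, 1.0] before
# normalisation, while a +1 or -1 reward in the middle of a batch resets the
# accumulator so returns never bleed across point boundaries.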

def update_policy(data):
    state_history = data.pop(0)
    action_history = data.pop(0)
    action_prob_history = data.pop(0)
    advantage_history = data.pop(0)
    # Several passes over the collected data; sampling every index without
    # replacement amounts to shuffling the full batch once per epoch.
    for _ in range(EPOCHS):
        n_batch = len(action_history)
        idxs = random.sample(range(len(action_history)), n_batch)
        state_batch = torch.cat([state_history[idx] for idx in idxs], 0)
        action_batch = torch.LongTensor([action_history[idx] for idx in idxs])
        action_prob_batch = torch.FloatTensor([action_prob_history[idx] for idx in idxs])
        advantage_batch = torch.FloatTensor([advantage_history[idx] for idx in idxs])
        opt.zero_grad()
        loss = policy.get_loss(state_batch, action_batch, action_prob_batch, advantage_batch)
        writer.add_scalar('loss/loss', -loss, global_n)
        loss.backward()
        opt.step()
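
# policy.get_loss is defined in policy.py (not shown here). Because the old
# action probabilities are recorded under torch.no_grad() and fed back in,
# the loss is presumably an importance-ratio surrogate. A minimal sketch of a
# PPO-style clipped objective it might resemble; the function name, the
# clip_eps value, and the exact form are assumptions, not the repository's
# confirmed implementation. This reference function is never called.
def _ppo_clip_loss_sketch(new_probs, old_probs, advantages, clip_eps=0.2):
    # Probability ratio between the current policy and the one that
    # collected the data.
    ratio = new_probs / (old_probs + 1e-10)
    # Clipped surrogate: take the pessimistic minimum of the two terms.
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    surrogate = torch.min(ratio * advantages, clipped * advantages)
    # Negate because the optimiser minimises (note the training loop logs
    # -loss as the objective).
    return -surrogate.mean()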

if __name__ == "__main__":
    writer = tensorboard.SummaryWriter()
    global_n = 0
    # env = gym.make("CartPole-v0")
    env = gym.make("WimblepongVisualMultiplayer-v0")
    policy = Policy()
    opponent = wimblepong.SimpleAi(env, 2)
    # opponent = Agent(env)
    # opponent.load_model('results/model.mdl')
    opt = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)
    # Resume training from a previously saved checkpoint.
    policy.load_state_dict(torch.load('results/model_best_400.mdl'))
    i = 0
    try:
        while True:
            data = collect_data()
            reward_history = data.pop(3)
            advantage_history = compute_advantages(reward_history)
            data.append(advantage_history)
            writer.add_scalar('reward/average_reward', sum(reward_history) / EPS_PER_ITERATION, i)
            writer.add_scalar('reward/average_ep_len', len(reward_history) / EPS_PER_ITERATION, i)
            writer.add_scalar('loss/advantage', advantage_history.mean(), global_n)
            update_policy(data)
            if i % 100 == 0:
                torch.save(policy.state_dict(), f'results/model_simple_{i}.mdl')
            i += 1
    finally:
        # The loop only exits via an exception or Ctrl-C; close the
        # environment cleanly in either case.
        env.close()
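
# Usage (a sketch, assuming the wimblepong package is installed and the
# results/ directory with the checkpoint exists next to this script):
#   python train.py
#   tensorboard --logdir runs    # SummaryWriter logs to ./runs by default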