-
Notifications
You must be signed in to change notification settings - Fork 2
/
visualizer.py
146 lines (118 loc) · 3.58 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Training progress visualizer
"""
import time
from abc import ABC, abstractmethod
def get_training_visualizer(visualizer_type):
    """
    A factory wrapper to generate training progress
    visualizers.
    :param visualizer_type: (str) the type of the visualizer to create;
        'none' builds a no-op visualizer, 'streamlit' builds the
        Streamlit dashboard visualizer
    :return: TrainingVisualizer
    :raises RuntimeError: if visualizer_type is not one of the known types
    """
    if visualizer_type == 'none':
        return DummyTrainingVisualizer()
    if visualizer_type == 'streamlit':
        return StreamlitTrainingVisualizer()
    # Report the value the caller actually passed so the failure is
    # diagnosable.
    raise RuntimeError(
        'Visualizer of type {0} not found'.format(visualizer_type))
class TrainingVisualizer(ABC):
    """
    Abstract interface shared by every training progress visualizer.

    Concrete subclasses must implement all three hooks below; the base
    class itself cannot be instantiated.
    """

    @abstractmethod
    def log_loss(self, loss):
        """
        Pushes loss history into the visualization.
        :param loss: a list of loss history
        :return: None
        """

    @abstractmethod
    def log_reward(self, reward):
        """
        Pushes reward history into the visualization.
        :param reward: a list of reward history
        :return: None
        """

    @abstractmethod
    def get_ui_feedback(self):
        """
        Reads back configuration chosen in the UI.
        :return: None
        """
class DummyTrainingVisualizer(TrainingVisualizer):
    """
    No-op visualizer for runs that need no progress display.

    Every hook accepts its argument(s), discards them and returns None.
    """

    def log_loss(self, loss):
        """
        Ignores the given loss history.
        :param loss: a list of loss history
        :return: None
        """

    def log_reward(self, reward):
        """
        Ignores the given reward history.
        :param reward: a list of reward history
        :return: None
        """

    def get_ui_feedback(self):
        """
        Provides no UI configuration.
        :return: None
        """
class StreamlitTrainingVisualizer(TrainingVisualizer):
    """
    Training visualizer that renders progress in a Streamlit dashboard.

    The sidebar holds sliders for the training configuration, which
    get_ui_feedback reads back. The main page shows a loss chart and an
    average-reward chart; log_loss / log_reward append rows to them.
    """
    def __init__(self):
        """
        Initializes the streamlit dashboard with elements
        """
        print('Initializing stream lit visualizer')
        # Imported lazily so streamlit is only needed when this visualizer
        # is actually selected.
        # pylint: disable=import-outside-toplevel
        import streamlit as st
        self.loss_history = []
        self.reward_history = []
        # pylint: disable=no-value-for-parameter
        st.sidebar.text('Cart Pole DQN Training (TensorFlow 2.0)')
        # Cosmetic start-up animation: fills the sidebar progress bar from
        # 1% to 100% in 100 steps of 0.02 s (about 2 seconds in total).
        start_bar = st.sidebar.progress(0)
        for percent_complete in range(100):
            time.sleep(0.02)
            start_bar.progress(percent_complete + 1)
        # Slider positional args are (label, min_value, max_value, value).
        self.update_freq = st.sidebar.slider('Update frequency', 0, 500, 120)
        # The epsilon slider works in whole percent; scale it to [0, 1].
        self.epsilon = float(st.sidebar.slider('Epsilon', 0, 100, 10)) / 100.0
        self.eval_eps = st.sidebar.slider('Eval episodes', 0, 100, 10)
        st.text('Training loss history')
        self.loss_chart = st.line_chart(self.loss_history)
        st.text('Average reward history')
        self.reward_chart = st.line_chart(self.reward_history)

    def log_loss(self, loss):
        """
        Adds a loss history to the chart
        :param loss: a list of loss history
        :return: None
        """
        self.loss_chart.add_rows(loss)

    def log_reward(self, reward):
        """
        Adds a reward history to the chart
        :param reward: a list of reward history
        :return: None
        """
        self.reward_chart.add_rows(reward)

    def get_ui_feedback(self):
        """
        Gets the user defined config from the UI
        :return: (dict) config with keys 'update_freq', 'epsilon' (a float
            in [0, 1]) and 'eval_eps', taken from the sidebar sliders read
            in __init__
        """
        return {
            'update_freq': self.update_freq,
            'epsilon': self.epsilon,
            'eval_eps': self.eval_eps,
        }