-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Copy Beholder into TensorBoard repo (#613)
- Loading branch information
Showing
21 changed files
with
1,921 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Description: | ||
# TensorBoard plugin for tensors and tensor variance for an entire graph. | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
licenses(["notice"]) # Apache 2.0 | ||
exports_files(["LICENSE"]) | ||
|
||
py_library( | ||
name = "file_system_tools", | ||
data = ["resources"], | ||
srcs = ["file_system_tools.py"], | ||
srcs_version = "PY2AND3", | ||
) | ||
|
||
py_library( | ||
name = "im_util", | ||
data = ["resources"], | ||
srcs = ["im_util.py"], | ||
deps = [":file_system_tools"], | ||
srcs_version = "PY2AND3", | ||
) | ||
|
||
py_library( | ||
name = "visualizer", | ||
srcs = ["visualizer.py", "shared_config.py"], | ||
deps = [ | ||
":im_util", | ||
":file_system_tools", | ||
], | ||
srcs_version = "PY2AND3", | ||
) | ||
|
||
py_library( | ||
name = "video_writing", | ||
srcs = ["video_writing.py"], | ||
deps = [ | ||
":im_util" | ||
], | ||
srcs_version = "PY2AND3" | ||
) | ||
|
||
py_library( | ||
name = "beholder", | ||
data = ["resources"], | ||
srcs = ["beholder.py", "shared_config.py"], | ||
deps = [ | ||
":im_util", | ||
":visualizer", | ||
":file_system_tools", | ||
":video_writing", | ||
"//tensorboard/backend/event_processing:plugin_asset_util", | ||
], | ||
srcs_version = "PY2AND3", | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import time | ||
|
||
import tensorflow as tf | ||
|
||
from tensorboard.plugins.beholder import im_util | ||
from tensorboard.plugins.beholder.file_system_tools import read_pickle,\ | ||
write_pickle, write_file | ||
from tensorboard.plugins.beholder.shared_config import PLUGIN_NAME, TAG_NAME,\ | ||
SUMMARY_FILENAME, DEFAULT_CONFIG, CONFIG_FILENAME | ||
from tensorboard.plugins.beholder import video_writing | ||
from tensorboard.plugins.beholder.visualizer import Visualizer | ||
|
||
class Beholder(object): | ||
|
||
def __init__(self, session, logdir): | ||
self.video_writer = None | ||
|
||
self.PLUGIN_LOGDIR = logdir + '/plugins/' + PLUGIN_NAME | ||
self.SESSION = session | ||
|
||
self.frame_placeholder = None | ||
self.summary_op = None | ||
|
||
self.last_image_shape = [] | ||
self.last_update_time = time.time() | ||
self.config_last_modified_time = -1 | ||
self.previous_config = dict(DEFAULT_CONFIG) | ||
|
||
if not tf.gfile.Exists(self.PLUGIN_LOGDIR + '/config.pkl'): | ||
tf.gfile.MakeDirs(self.PLUGIN_LOGDIR) | ||
write_pickle(DEFAULT_CONFIG, '{}/{}'.format(self.PLUGIN_LOGDIR, | ||
CONFIG_FILENAME)) | ||
|
||
self.visualizer = Visualizer(self.PLUGIN_LOGDIR) | ||
|
||
|
||
def _get_config(self): | ||
'''Reads the config file from disk or creates a new one.''' | ||
filename = '{}/{}'.format(self.PLUGIN_LOGDIR, CONFIG_FILENAME) | ||
modified_time = os.path.getmtime(filename) | ||
|
||
if modified_time != self.config_last_modified_time: | ||
config = read_pickle(filename, default=self.previous_config) | ||
self.previous_config = config | ||
else: | ||
config = self.previous_config | ||
|
||
self.config_last_modified_time = modified_time | ||
return config | ||
|
||
|
||
def _write_summary(self, frame): | ||
'''Writes the frame to disk as a tensor summary.''' | ||
summary = self.SESSION.run(self.summary_op, feed_dict={ | ||
self.frame_placeholder: frame | ||
}) | ||
path = '{}/{}'.format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) | ||
write_file(summary, path) | ||
|
||
|
||
|
||
def _get_final_image(self, config, arrays=None, frame=None): | ||
if config['values'] == 'frames': | ||
if frame is None: | ||
final_image = im_util.get_image_relative_to_script('frame-missing.png') | ||
else: | ||
frame = frame() if callable(frame) else frame | ||
final_image = im_util.scale_image_for_display(frame) | ||
|
||
elif config['values'] == 'arrays': | ||
if arrays is None: | ||
final_image = im_util.get_image_relative_to_script('arrays-missing.png') | ||
# TODO: hack to clear the info. Should be cleaner. | ||
self.visualizer._save_section_info([], []) | ||
else: | ||
final_image = self.visualizer.build_frame(arrays) | ||
|
||
elif config['values'] == 'trainable_variables': | ||
arrays = [self.SESSION.run(x) for x in tf.trainable_variables()] | ||
final_image = self.visualizer.build_frame(arrays) | ||
|
||
return final_image | ||
|
||
|
||
def _enough_time_has_passed(self, FPS): | ||
'''For limiting how often frames are computed.''' | ||
if FPS == 0: | ||
return False | ||
else: | ||
earliest_time = self.last_update_time + (1.0 / FPS) | ||
return time.time() >= earliest_time | ||
|
||
|
||
def _update_frame(self, arrays, frame, config): | ||
final_image = self._get_final_image(config, arrays, frame) | ||
|
||
if self.summary_op is None or self.last_image_shape != final_image.shape: | ||
self.frame_placeholder = tf.placeholder(tf.uint8, final_image.shape) | ||
self.summary_op = tf.summary.tensor_summary(TAG_NAME, | ||
self.frame_placeholder) | ||
self._write_summary(final_image) | ||
self.last_image_shape = final_image.shape | ||
|
||
return final_image | ||
|
||
|
||
def _update_recording(self, frame, config): | ||
'''Adds a frame to the video using ffmpeg if possible. If not, writes | ||
individual frames as png files in a directory. | ||
''' | ||
# pylint: disable=redefined-variable-type | ||
is_recording = config['is_recording'] | ||
filename = self.PLUGIN_LOGDIR + '/video-{}.mp4'.format(time.time()) | ||
|
||
if is_recording: | ||
if self.video_writer is None or frame.shape != self.video_writer.size: | ||
try: | ||
self.video_writer = video_writing.FFMPEG_VideoWriter(filename, | ||
frame.shape, | ||
15) | ||
except OSError: | ||
message = ('Either ffmpeg is not installed, or something else went ' | ||
'wrong. Saving individual frames to disk instead.') | ||
print(message) | ||
self.video_writer = video_writing.PNGWriter(self.PLUGIN_LOGDIR, | ||
frame.shape) | ||
self.video_writer.write_frame(frame) | ||
elif not is_recording and self.video_writer is not None: | ||
self.video_writer.close() | ||
self.video_writer = None | ||
|
||
|
||
# TODO: blanket try and except for production? I don't someone's script to die | ||
# after weeks of running because of a visualization. | ||
def update(self, arrays=None, frame=None): | ||
'''Creates a frame and writes it to disk. | ||
Args: | ||
arrays: a list of np arrays. Use the "custom" option in the client. | ||
frame: a 2D np array. This way the plugin can be used for video of any | ||
kind, not just the visualization that comes with the plugin. | ||
frame can also be a function, which only is evaluated when the | ||
"frame" option is selected by the client. | ||
''' | ||
new_config = self._get_config() | ||
|
||
if self._enough_time_has_passed(self.previous_config['FPS']): | ||
self.visualizer.update(new_config) | ||
self.last_update_time = time.time() | ||
final_image = self._update_frame(arrays, frame, new_config) | ||
self._update_recording(final_image, new_config) | ||
|
||
|
||
############################################################################## | ||
|
||
@staticmethod | ||
def gradient_helper(optimizer, loss, var_list=None): | ||
'''A helper to get the gradients out at each step. | ||
Args: | ||
optimizer: the optimizer op. | ||
loss: the op that computes your loss value. | ||
Returns: the gradient tensors and the train_step op. | ||
''' | ||
if var_list is None: | ||
var_list = tf.trainable_variables() | ||
|
||
grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list) | ||
grads = [pair[0] for pair in grads_and_vars] | ||
|
||
return grads, optimizer.apply_gradients(grads_and_vars) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
load("@org_tensorflow_tensorboard//tensorboard/defs:web.bzl", "ts_web_library") | ||
|
||
ts_web_library( | ||
name = "dashboard", | ||
srcs = [ | ||
"beholder-dashboard.html", | ||
"beholder-video.html", | ||
"beholder-info.html", | ||
], | ||
path = "/beholder", | ||
deps = [ | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_backend", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_card_heading", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_categorization_utils", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_color_scale", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_dashboard_common", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_imports:d3", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_imports:lodash", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_imports:polymer", | ||
"@org_tensorflow_tensorboard//tensorboard/components/tf_runs_selector", | ||
"@org_polymer_paper_radio_group", | ||
"@org_polymer_paper_button", | ||
"@org_polymer_paper_dialog", | ||
"@org_polymer_paper_icon_button", | ||
"@org_polymer_paper_slider", | ||
"@org_polymer_paper_spinner", | ||
"@org_polymer_paper_tooltip", | ||
], | ||
) |
Oops, something went wrong.