diff --git a/tensorboard/plugins/beholder/BUILD b/tensorboard/plugins/beholder/BUILD new file mode 100644 index 00000000000..43fe5c8fe40 --- /dev/null +++ b/tensorboard/plugins/beholder/BUILD @@ -0,0 +1,54 @@ +# Description: +# TensorBoard plugin for tensors and tensor variance for an entire graph. + +package(default_visibility = ["//visibility:public"]) +licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + +py_library( + name = "file_system_tools", + data = ["resources"], + srcs = ["file_system_tools.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "im_util", + data = ["resources"], + srcs = ["im_util.py"], + deps = [":file_system_tools"], + srcs_version = "PY2AND3", +) + +py_library( + name = "visualizer", + srcs = ["visualizer.py", "shared_config.py"], + deps = [ + ":im_util", + ":file_system_tools", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "video_writing", + srcs = ["video_writing.py"], + deps = [ + ":im_util" + ], + srcs_version = "PY2AND3" +) + +py_library( + name = "beholder", + data = ["resources"], + srcs = ["beholder.py", "shared_config.py"], + deps = [ + ":im_util", + ":visualizer", + ":file_system_tools", + ":video_writing", + "//tensorboard/backend/event_processing:plugin_asset_util", + ], + srcs_version = "PY2AND3", +) diff --git a/tensorboard/plugins/beholder/__init__.py b/tensorboard/plugins/beholder/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorboard/plugins/beholder/beholder.py b/tensorboard/plugins/beholder/beholder.py new file mode 100644 index 00000000000..88a3eec3bad --- /dev/null +++ b/tensorboard/plugins/beholder/beholder.py @@ -0,0 +1,178 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time + +import tensorflow as tf + +from tensorboard.plugins.beholder import im_util +from tensorboard.plugins.beholder.file_system_tools import read_pickle,\ + write_pickle, write_file +from tensorboard.plugins.beholder.shared_config import PLUGIN_NAME, TAG_NAME,\ + SUMMARY_FILENAME, DEFAULT_CONFIG, CONFIG_FILENAME +from tensorboard.plugins.beholder import video_writing +from tensorboard.plugins.beholder.visualizer import Visualizer + +class Beholder(object): + + def __init__(self, session, logdir): + self.video_writer = None + + self.PLUGIN_LOGDIR = logdir + '/plugins/' + PLUGIN_NAME + self.SESSION = session + + self.frame_placeholder = None + self.summary_op = None + + self.last_image_shape = [] + self.last_update_time = time.time() + self.config_last_modified_time = -1 + self.previous_config = dict(DEFAULT_CONFIG) + + if not tf.gfile.Exists(self.PLUGIN_LOGDIR + '/config.pkl'): + tf.gfile.MakeDirs(self.PLUGIN_LOGDIR) + write_pickle(DEFAULT_CONFIG, '{}/{}'.format(self.PLUGIN_LOGDIR, + CONFIG_FILENAME)) + + self.visualizer = Visualizer(self.PLUGIN_LOGDIR) + + + def _get_config(self): + '''Reads the config file from disk or creates a new one.''' + filename = '{}/{}'.format(self.PLUGIN_LOGDIR, CONFIG_FILENAME) + modified_time = os.path.getmtime(filename) + + if modified_time != self.config_last_modified_time: + config = read_pickle(filename, default=self.previous_config) + self.previous_config = config + else: + config = self.previous_config + + self.config_last_modified_time = modified_time + return config + + + def _write_summary(self, frame): + '''Writes the frame to disk as a tensor summary.''' + summary = self.SESSION.run(self.summary_op, feed_dict={ + self.frame_placeholder: frame + }) + path = '{}/{}'.format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) + write_file(summary, path) + + + + def _get_final_image(self, config, arrays=None, frame=None): + if config['values'] == 'frames': + if frame is None: + final_image = im_util.get_image_relative_to_script('frame-missing.png') + else: + frame = frame() if callable(frame) else frame + final_image = im_util.scale_image_for_display(frame) + + elif config['values'] == 'arrays': + if arrays is None: + final_image = im_util.get_image_relative_to_script('arrays-missing.png') + # TODO: hack to clear the info. Should be cleaner. + self.visualizer._save_section_info([], []) + else: + final_image = self.visualizer.build_frame(arrays) + + elif config['values'] == 'trainable_variables': + arrays = [self.SESSION.run(x) for x in tf.trainable_variables()] + final_image = self.visualizer.build_frame(arrays) + + return final_image + + + def _enough_time_has_passed(self, FPS): + '''For limiting how often frames are computed.''' + if FPS == 0: + return False + else: + earliest_time = self.last_update_time + (1.0 / FPS) + return time.time() >= earliest_time + + + def _update_frame(self, arrays, frame, config): + final_image = self._get_final_image(config, arrays, frame) + + if self.summary_op is None or self.last_image_shape != final_image.shape: + self.frame_placeholder = tf.placeholder(tf.uint8, final_image.shape) + self.summary_op = tf.summary.tensor_summary(TAG_NAME, + self.frame_placeholder) + self._write_summary(final_image) + self.last_image_shape = final_image.shape + + return final_image + + + def _update_recording(self, frame, config): + '''Adds a frame to the video using ffmpeg if possible. If not, writes + individual frames as png files in a directory. + ''' + # pylint: disable=redefined-variable-type + is_recording = config['is_recording'] + filename = self.PLUGIN_LOGDIR + '/video-{}.mp4'.format(time.time()) + + if is_recording: + if self.video_writer is None or frame.shape != self.video_writer.size: + try: + self.video_writer = video_writing.FFMPEG_VideoWriter(filename, + frame.shape, + 15) + except OSError: + message = ('Either ffmpeg is not installed, or something else went ' + 'wrong. Saving individual frames to disk instead.') + print(message) + self.video_writer = video_writing.PNGWriter(self.PLUGIN_LOGDIR, + frame.shape) + self.video_writer.write_frame(frame) + elif not is_recording and self.video_writer is not None: + self.video_writer.close() + self.video_writer = None + + + # TODO: blanket try and except for production? I don't someone's script to die + # after weeks of running because of a visualization. + def update(self, arrays=None, frame=None): + '''Creates a frame and writes it to disk. + + Args: + arrays: a list of np arrays. Use the "custom" option in the client. + frame: a 2D np array. This way the plugin can be used for video of any + kind, not just the visualization that comes with the plugin. + + frame can also be a function, which only is evaluated when the + "frame" option is selected by the client. + ''' + new_config = self._get_config() + + if self._enough_time_has_passed(self.previous_config['FPS']): + self.visualizer.update(new_config) + self.last_update_time = time.time() + final_image = self._update_frame(arrays, frame, new_config) + self._update_recording(final_image, new_config) + + + ############################################################################## + + @staticmethod + def gradient_helper(optimizer, loss, var_list=None): + '''A helper to get the gradients out at each step. + + Args: + optimizer: the optimizer op. + loss: the op that computes your loss value. + + Returns: the gradient tensors and the train_step op. + ''' + if var_list is None: + var_list = tf.trainable_variables() + + grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list) + grads = [pair[0] for pair in grads_and_vars] + + return grads, optimizer.apply_gradients(grads_and_vars) diff --git a/tensorboard/plugins/beholder/client_side/BUILD b/tensorboard/plugins/beholder/client_side/BUILD new file mode 100644 index 00000000000..c03ce116832 --- /dev/null +++ b/tensorboard/plugins/beholder/client_side/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = ["//visibility:public"]) + +load("@org_tensorflow_tensorboard//tensorboard/defs:web.bzl", "ts_web_library") + +ts_web_library( + name = "dashboard", + srcs = [ + "beholder-dashboard.html", + "beholder-video.html", + "beholder-info.html", + ], + path = "/beholder", + deps = [ + "@org_tensorflow_tensorboard//tensorboard/components/tf_backend", + "@org_tensorflow_tensorboard//tensorboard/components/tf_card_heading", + "@org_tensorflow_tensorboard//tensorboard/components/tf_categorization_utils", + "@org_tensorflow_tensorboard//tensorboard/components/tf_color_scale", + "@org_tensorflow_tensorboard//tensorboard/components/tf_dashboard_common", + "@org_tensorflow_tensorboard//tensorboard/components/tf_imports:d3", + "@org_tensorflow_tensorboard//tensorboard/components/tf_imports:lodash", + "@org_tensorflow_tensorboard//tensorboard/components/tf_imports:polymer", + "@org_tensorflow_tensorboard//tensorboard/components/tf_runs_selector", + "@org_polymer_paper_radio_group", + "@org_polymer_paper_button", + "@org_polymer_paper_dialog", + "@org_polymer_paper_icon_button", + "@org_polymer_paper_slider", + "@org_polymer_paper_spinner", + "@org_polymer_paper_tooltip", + ], +) diff --git a/tensorboard/plugins/beholder/client_side/beholder-dashboard.html b/tensorboard/plugins/beholder/client_side/beholder-dashboard.html new file mode 100644 index 00000000000..19dbb95a6cd --- /dev/null +++ b/tensorboard/plugins/beholder/client_side/beholder-dashboard.html @@ -0,0 +1,434 @@ + + + + + + + + + + + + + + + + + diff --git a/tensorboard/plugins/beholder/client_side/beholder-info.html b/tensorboard/plugins/beholder/client_side/beholder-info.html new file mode 100644 index 00000000000..64fa5b2b0eb --- /dev/null +++ b/tensorboard/plugins/beholder/client_side/beholder-info.html @@ -0,0 +1,89 @@ + + + + + + + + + + diff --git a/tensorboard/plugins/beholder/client_side/beholder-video.html b/tensorboard/plugins/beholder/client_side/beholder-video.html new file mode 100644 index 00000000000..2bf3a822387 --- /dev/null +++ b/tensorboard/plugins/beholder/client_side/beholder-video.html @@ -0,0 +1,90 @@ + + + + + + + + + + diff --git a/tensorboard/plugins/beholder/demos/demo/BUILD b/tensorboard/plugins/beholder/demos/demo/BUILD new file mode 100644 index 00000000000..022b7ba0f11 --- /dev/null +++ b/tensorboard/plugins/beholder/demos/demo/BUILD @@ -0,0 +1,6 @@ +py_binary( + name = "demo", + deps = ["//tensorboard/plugins/beholder"], + srcs = ["demo.py"], + srcs_version = "PY2AND3", +) diff --git a/tensorboard/plugins/beholder/demos/demo/__init__.py b/tensorboard/plugins/beholder/demos/demo/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorboard/plugins/beholder/demos/demo/demo.py b/tensorboard/plugins/beholder/demos/demo/demo.py new file mode 100644 index 00000000000..223e792c633 --- /dev/null +++ b/tensorboard/plugins/beholder/demos/demo/demo.py @@ -0,0 +1,225 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple MNIST classifier which displays summaries in TensorBoard. + +This is an unimpressive MNIST model, but it is a good example of using +tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of +naming summary tags so that they are grouped meaningfully in TensorBoard. + +It demonstrates the functionality of every TensorBoard dashboard. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import input_data + +from beholder.beholder import Beholder + +FLAGS = None + +LOG_DIRECTORY = '/tmp/beholder-demo' + +def train(): + mnist = input_data.read_data_sets(FLAGS.data_dir, + one_hot=True, + fake_data=FLAGS.fake_data) + + sess = tf.InteractiveSession() + + with tf.name_scope('input'): + x = tf.placeholder(tf.float32, [None, 784], name='x-input') + y_ = tf.placeholder(tf.float32, [None, 10], name='y-input') + + with tf.name_scope('input_reshape'): + image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) + tf.summary.image('input', image_shaped_input, 10) + + def weight_variable(shape): + """Create a weight variable with appropriate initialization.""" + initial = tf.truncated_normal(shape, stddev=0.01) + return tf.Variable(initial) + + def bias_variable(shape): + """Create a bias variable with appropriate initialization.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + + def variable_summaries(var): + """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" + with tf.name_scope('summaries'): + mean = tf.reduce_mean(var) + tf.summary.scalar('mean', mean) + with tf.name_scope('stddev'): + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) + tf.summary.scalar('stddev', stddev) + tf.summary.scalar('max', tf.reduce_max(var)) + tf.summary.scalar('min', tf.reduce_min(var)) + tf.summary.histogram('histogram', var) + + def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu): + """Reusable code for making a simple neural net layer. + + It does a matrix multiply, bias add, and then uses ReLU to nonlinearize. + It also sets up name scoping so that the resultant graph is easy to read, + and adds a number of summary ops. + """ + # Adding a name scope ensures logical grouping of the layers in the graph. + with tf.name_scope(layer_name): + # This Variable will hold the state of the weights for the layer + with tf.name_scope('weights'): + weights = weight_variable([input_dim, output_dim]) + variable_summaries(weights) + with tf.name_scope('biases'): + biases = bias_variable([output_dim]) + variable_summaries(biases) + with tf.name_scope('Wx_plus_b'): + preactivate = tf.matmul(input_tensor, weights) + biases + tf.summary.histogram('pre_activations', preactivate) + activations = act(preactivate, name='activation') + tf.summary.histogram('activations', activations) + return activations + + #conv1 + kernel = tf.Variable(tf.truncated_normal([5, 5, 1, 10], dtype=tf.float32, + stddev=1e-1), name='conv-weights') + conv = tf.nn.conv2d(image_shaped_input, kernel, [1, 1, 1, 1], padding='VALID') + biases = tf.Variable(tf.constant(0.0, shape=[kernel.get_shape().as_list()[-1]], dtype=tf.float32), + trainable=True, name='biases') + out = tf.nn.bias_add(conv, biases) + conv1 = tf.nn.relu(out, name='relu') + + #conv2 + kernel2 = tf.Variable(tf.truncated_normal([3, 3, 10, 20], dtype=tf.float32, + stddev=1e-1), name='conv-weights2') + conv2 = tf.nn.conv2d(conv1, kernel2, [1, 1, 1, 1], padding='VALID') + biases2 = tf.Variable(tf.constant(0.0, shape=[kernel2.get_shape().as_list()[-1]], dtype=tf.float32), + trainable=True, name='biases') + out2 = tf.nn.bias_add(conv2, biases2) + conv2 = tf.nn.relu(out2, name='relu') + + flattened = tf.contrib.layers.flatten(conv2) + + + # hidden1 = nn_layer(x, x.get_shape().as_list()[1], 10, 'layer1') + hidden1 = nn_layer(flattened, flattened.get_shape().as_list()[1], 10, 'layer1') + + with tf.name_scope('dropout'): + keep_prob = tf.placeholder(tf.float32) + tf.summary.scalar('dropout_keep_probability', keep_prob) + dropped = tf.nn.dropout(hidden1, keep_prob) + + y = nn_layer(dropped, 10, 10, 'layer2', act=tf.identity) + + with tf.name_scope('cross_entropy'): + diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y) + with tf.name_scope('total'): + cross_entropy = tf.reduce_mean(diff) + tf.summary.scalar('cross_entropy', cross_entropy) + + with tf.name_scope('train'): + optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) + gradients, train_step = Beholder.gradient_helper(optimizer, cross_entropy) + + with tf.name_scope('accuracy'): + with tf.name_scope('correct_prediction'): + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + with tf.name_scope('accuracy'): + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + tf.summary.scalar('accuracy', accuracy) + + merged = tf.summary.merge_all() + train_writer = tf.summary.FileWriter(LOG_DIRECTORY + '/train', sess.graph) + test_writer = tf.summary.FileWriter(LOG_DIRECTORY + '/test') + tf.global_variables_initializer().run() + + visualizer = Beholder(session=sess, + logdir=LOG_DIRECTORY) + + + def feed_dict(train): + if train or FLAGS.fake_data: + xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data) + k = FLAGS.dropout + else: + xs, ys = mnist.test.images, mnist.test.labels + k = 1.0 + return {x: xs, y_: ys, keep_prob: k} + + for i in range(FLAGS.max_steps): + # if i % 10 == 0: # Record summaries and test-set accuracy + summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) + test_writer.add_summary(summary, i) + print('Accuracy at step %s: %s' % (i, acc)) + # else: # Record train set summaries, and train + # if i % 100 == 99: # Record execution stats + # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + # run_metadata = tf.RunMetadata() + # summary, _ = sess.run([merged, train_step], + # feed_dict=feed_dict(True), + # options=run_options, + # run_metadata=run_metadata) + # train_writer.add_run_metadata(run_metadata, 'step%03d' % i) + # train_writer.add_summary(summary, i) + # print('Adding run metadata for', i) + # else: # Record a summary + print('i', i) + feed_dictionary = feed_dict(True) + summary, gradient_arrays, activations, _ = sess.run([merged, gradients, [image_shaped_input, conv1, conv2, hidden1, y], train_step], feed_dict=feed_dictionary) + first_of_batch = sess.run(x, feed_dict=feed_dictionary)[0].reshape(28, 28) + + visualizer.update( + arrays=activations + [first_of_batch] + gradient_arrays, + frame=first_of_batch, + ) + train_writer.add_summary(summary, i) + + train_writer.close() + test_writer.close() + +def main(_): + if not tf.gfile.Exists(LOG_DIRECTORY): + tf.gfile.MakeDirs(LOG_DIRECTORY) + train() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--fake_data', nargs='?', const=True, type=bool, + default=False, + help='If true, uses fake data for unit testing.') + parser.add_argument('--max_steps', type=int, default=1000000, + help='Number of steps to run trainer.') + parser.add_argument('--learning_rate', type=float, default=0.001, + help='Initial learning rate') + parser.add_argument('--dropout', type=float, default=0.9, + help='Keep probability for training dropout.') + parser.add_argument( + '--data_dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + parser.add_argument( + '--log_dir', + type=str, + default='/tmp/tensorflow/mnist/logs/mnist_with_summaries', + help='Summaries log directory') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorboard/plugins/beholder/file_system_tools.py b/tensorboard/plugins/beholder/file_system_tools.py new file mode 100644 index 00000000000..2c7634eeea2 --- /dev/null +++ b/tensorboard/plugins/beholder/file_system_tools.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pickle + +from google.protobuf import message +import tensorflow as tf + +def write_file(contents, path, mode='wb'): + with tf.gfile.Open(path, mode) as new_file: + new_file.write(contents) + + +def read_tensor_summary(path): + with tf.gfile.Open(path, 'rb') as summary_file: + summary_string = summary_file.read() + + if not summary_string: + raise message.DecodeError('Empty summary.') + + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary_string) + tensor_proto = summary_proto.value[0].tensor + array = tf.make_ndarray(tensor_proto) + + return array + + +def write_pickle(obj, path): + with tf.gfile.Open(path, 'wb') as new_file: + pickle.dump(obj, new_file) + + +def read_pickle(path, default=None): + try: + with tf.gfile.Open(path, 'rb') as pickle_file: + result = pickle.load(pickle_file) + + except (IOError, EOFError, ValueError, tf.errors.NotFoundError) as e: + # TODO: log this somehow? Could swallow errors I don't intend. + if default is not None: + result = default + else: + raise e + + return result + + +def resources_path(): + script_directory = os.path.dirname(__file__) + filename = os.path.join(script_directory, 'resources') + return filename diff --git a/tensorboard/plugins/beholder/im_util.py b/tensorboard/plugins/beholder/im_util.py new file mode 100644 index 00000000000..b1d87542066 --- /dev/null +++ b/tensorboard/plugins/beholder/im_util.py @@ -0,0 +1,262 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import threading + +import numpy as np +import tensorflow as tf + +from tensorboard.plugins.beholder.file_system_tools import resources_path + +# pylint: disable=not-context-manager + +def global_extrema(arrays): + return min([x.min() for x in arrays]), max([x.max() for x in arrays]) + + +def scale_sections(sections, scaling_scope): + ''' + input: unscaled sections. + returns: sections scaled to [0, 255] + ''' + new_sections = [] + + if scaling_scope == 'layer': + for section in sections: + new_sections.append(scale_image_for_display(section)) + + elif scaling_scope == 'network': + global_min, global_max = global_extrema(sections) + + for section in sections: + new_sections.append(scale_image_for_display(section, + global_min, + global_max)) + return new_sections + + +def scale_image_for_display(image, minimum=None, maximum=None): + image = image.astype(float) + + minimum = image.min() if minimum is None else minimum + image -= minimum + + maximum = image.max() if maximum is None else maximum + + if maximum == 0: + return image + else: + image *= 255 / maximum + return image.astype(np.uint8) + + +def pad_to_shape(array, shape, constant=245): + padding = [] + + for actual_dim, target_dim in zip(array.shape, shape): + start_padding = 0 + end_padding = target_dim - actual_dim + + padding.append((start_padding, end_padding)) + + return np.pad(array, padding, mode='constant', constant_values=constant) + +# New matplotlib colormaps by Nathaniel J. Smith, Stefan van der Walt, +# and (in the case of viridis) Eric Firing. +# +# This file and the colormaps in it are released under the CC0 license / +# public domain dedication. We would appreciate credit if you use or +# redistribute these colormaps, but do not impose any legal restrictions. +# +# To the extent possible under law, the persons who associated CC0 with +# mpl-colormaps have waived all copyright and related or neighboring rights +# to mpl-colormaps. +# +# You should have received a copy of the CC0 legalcode along with this +# work. If not, see . +colormaps = np.load('{}/colormaps.npy'.format(resources_path())) +magma_data, inferno_data, plasma_data, viridis_data = colormaps + + +def apply_colormap(image, colormap='magma'): + if colormap == 'grayscale': + return image + + data_map = { + 'magma': magma_data, + 'inferno': inferno_data, + 'plasma': plasma_data, + 'viridis': viridis_data, + } + + colormap_data = data_map[colormap] + return (colormap_data[image]*255).astype(np.uint8) + +# Taken from https://github.com/tensorflow/tensorboard/blob/ +# /28f58888ebb22e2db0f4f1f60cd96138ef72b2ef/tensorboard/util.py + +# Modified by Chris Anderson to not use the GPU. +class PersistentOpEvaluator(object): + """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently. + Extend this class to create a particular kind of op evaluator, like an + image encoder. In `initialize_graph`, create an appropriate TensorFlow + graph with placeholder inputs. In `run`, evaluate this graph and + return its result. This class will manage a singleton graph and + session to preserve memory usage, and will ensure that this graph and + session do not interfere with other concurrent sessions. + A subclass of this class offers a threadsafe, highly parallel Python + entry point for evaluating a particular TensorFlow graph. + Example usage: + class FluxCapacitanceEvaluator(PersistentOpEvaluator): + \"\"\"Compute the flux capacitance required for a system. + Arguments: + x: Available power input, as a `float`, in jigawatts. + Returns: + A `float`, in nanofarads. + \"\"\" + def initialize_graph(self): + self._placeholder = tf.placeholder(some_dtype) + self._op = some_op(self._placeholder) + def run(self, x): + return self._op.eval(feed_dict: {self._placeholder: x}) + evaluate_flux_capacitance = FluxCapacitanceEvaluator() + for x in xs: + evaluate_flux_capacitance(x) + """ + + def __init__(self): + super(PersistentOpEvaluator, self).__init__() + self._session = None + self._initialization_lock = threading.Lock() + + + def _lazily_initialize(self): + """Initialize the graph and session, if this has not yet been done.""" + with self._initialization_lock: + if self._session: + return + graph = tf.Graph() + with graph.as_default(): + self.initialize_graph() + + config = tf.ConfigProto(device_count={'GPU': 0}) + self._session = tf.Session(graph=graph, config=config) + + + def initialize_graph(self): + """Create the TensorFlow graph needed to compute this operation. + This should write ops to the default graph and return `None`. + """ + raise NotImplementedError('Subclasses must implement "initialize_graph".') + + + def run(self, *args, **kwargs): + """Evaluate the ops with the given input. + When this function is called, the default session will have the + graph defined by a previous call to `initialize_graph`. This + function should evaluate any ops necessary to compute the result of + the query for the given *args and **kwargs, likely returning the + result of a call to `some_op.eval(...)`. + """ + raise NotImplementedError('Subclasses must implement "run".') + + + def __call__(self, *args, **kwargs): + self._lazily_initialize() + with self._session.as_default(): + return self.run(*args, **kwargs) + + +class PNGDecoder(PersistentOpEvaluator): + + def __init__(self): + super(PNGDecoder, self).__init__() + self._image_placeholder = None + self._decode_op = None + + + def initialize_graph(self): + self._image_placeholder = tf.placeholder(dtype=tf.string) + self._decode_op = tf.image.decode_png(self._image_placeholder) + + + # pylint: disable=arguments-differ + def run(self, image): + return self._decode_op.eval(feed_dict={ + self._image_placeholder: image, + }) + + +class PNGEncoder(PersistentOpEvaluator): + + def __init__(self): + super(PNGEncoder, self).__init__() + self._image_placeholder = None + self._encode_op = None + + + def initialize_graph(self): + self._image_placeholder = tf.placeholder(dtype=tf.uint8) + self._encode_op = tf.image.encode_png(self._image_placeholder) + + + # pylint: disable=arguments-differ + def run(self, image): + if len(image.shape) == 2: + image = image.reshape([image.shape[0], image.shape[1], 1]) + + return self._encode_op.eval(feed_dict={ + self._image_placeholder: image, + }) + + +class Resizer(PersistentOpEvaluator): + + def __init__(self): + super(Resizer, self).__init__() + self._image_placeholder = None + self._size_placeholder = None + self._resize_op = None + + + def initialize_graph(self): + self._image_placeholder = tf.placeholder(dtype=tf.float32) + self._size_placeholder = tf.placeholder(dtype=tf.int32) + self._resize_op = tf.image.resize_nearest_neighbor(self._image_placeholder, + self._size_placeholder) + + # pylint: disable=arguments-differ + def run(self, image, height, width): + if len(image.shape) == 2: + image = image.reshape([image.shape[0], image.shape[1], 1]) + + resized = np.squeeze(self._resize_op.eval(feed_dict={ + self._image_placeholder: [image], + self._size_placeholder: [height, width] + })) + + return resized + + +decode_png = PNGDecoder() +encode_png = PNGEncoder() +resize = Resizer() + + +def read_image(filename): + with tf.gfile.Open(filename, 'rb') as image_file: + return np.array(decode_png(image_file.read())) + + +def write_image(array, filename): + with tf.gfile.Open(filename, 'w') as image_file: + image_file.write(encode_png(array)) + + +def get_image_relative_to_script(filename): + script_directory = os.path.dirname(__file__) + filename = os.path.join(script_directory, 'resources', filename) + + return read_image(filename) diff --git a/tensorboard/plugins/beholder/resources/arrays-missing.png b/tensorboard/plugins/beholder/resources/arrays-missing.png new file mode 100644 index 00000000000..9474efc0339 Binary files /dev/null and b/tensorboard/plugins/beholder/resources/arrays-missing.png differ diff --git a/tensorboard/plugins/beholder/resources/colormaps.npy b/tensorboard/plugins/beholder/resources/colormaps.npy new file mode 100644 index 00000000000..b0549818313 Binary files /dev/null and b/tensorboard/plugins/beholder/resources/colormaps.npy differ diff --git a/tensorboard/plugins/beholder/resources/frame-missing.png b/tensorboard/plugins/beholder/resources/frame-missing.png new file mode 100644 index 00000000000..401ad75735b Binary files /dev/null and b/tensorboard/plugins/beholder/resources/frame-missing.png differ diff --git a/tensorboard/plugins/beholder/resources/no-data.png b/tensorboard/plugins/beholder/resources/no-data.png new file mode 100644 index 00000000000..ff2c4e95db2 Binary files /dev/null and b/tensorboard/plugins/beholder/resources/no-data.png differ diff --git a/tensorboard/plugins/beholder/server_side/BUILD b/tensorboard/plugins/beholder/server_side/BUILD new file mode 100644 index 00000000000..7fa4d53983f --- /dev/null +++ b/tensorboard/plugins/beholder/server_side/BUILD @@ -0,0 +1,20 @@ +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "beholder_plugin", + srcs = [ + "beholder_plugin.py", + "//tensorboard/plugins/beholder:shared_config.py" + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorboard/plugins/beholder", + "//tensorboard/plugins/beholder:im_util", + "//tensorboard/plugins/beholder:file_system_tools", + "@org_pocoo_werkzeug", + "@org_tensorflow_tensorboard//tensorboard/backend:http_util", + "@org_tensorflow_tensorboard//tensorboard/backend/event_processing:plugin_asset_util", + "@org_tensorflow_tensorboard//tensorboard/backend/event_processing:event_accumulator", + "@org_tensorflow_tensorboard//tensorboard/plugins:base_plugin", + ], +) diff --git a/tensorboard/plugins/beholder/server_side/beholder_plugin.py b/tensorboard/plugins/beholder/server_side/beholder_plugin.py new file mode 100644 index 00000000000..fbb4a48e5fc --- /dev/null +++ b/tensorboard/plugins/beholder/server_side/beholder_plugin.py @@ -0,0 +1,160 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import time + +from google.protobuf import message +import numpy as np +import tensorboard +from tensorboard.backend import http_util +from tensorboard.backend.event_processing import plugin_asset_util as pau +from tensorboard.plugins import base_plugin +import tensorflow as tf +from werkzeug import wrappers + +from beholder.im_util import get_image_relative_to_script, encode_png +from beholder.shared_config import PLUGIN_NAME, SECTION_HEIGHT, IMAGE_WIDTH +from beholder.shared_config import SECTION_INFO_FILENAME, CONFIG_FILENAME,\ + TAG_NAME, SUMMARY_FILENAME, DEFAULT_CONFIG +from beholder.file_system_tools import read_tensor_summary, read_pickle,\ + write_pickle + +import sys +print(sys.version) + +class BeholderPlugin(base_plugin.TBPlugin): + + plugin_name = PLUGIN_NAME + + def __init__(self, context): + self._MULTIPLEXER = context.multiplexer + self.PLUGIN_LOGDIR = pau.PluginDirectory(context.logdir, PLUGIN_NAME) + self.FPS = 10 + self.most_recent_frame = get_image_relative_to_script('no-data.png') + self.most_recent_info = [{ + 'name': 'Waiting for data...', + }] + + if not tf.gfile.Exists(self.PLUGIN_LOGDIR): + tf.gfile.MakeDirs(self.PLUGIN_LOGDIR) + write_pickle(DEFAULT_CONFIG, '{}/{}'.format(self.PLUGIN_LOGDIR, + CONFIG_FILENAME)) + + + def get_plugin_apps(self): + return { + '/change-config': self._serve_change_config, + '/beholder-frame': self._serve_beholder_frame, + '/section-info': self._serve_section_info, + '/ping': self._serve_ping, + '/tags': self._serve_tags, + '/is-active': self._serve_is_active, + } + + + def is_active(self): + summary_filename = '{}/{}'.format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) + info_filename = '{}/{}'.format(self.PLUGIN_LOGDIR, SECTION_INFO_FILENAME) + return tf.gfile.Exists(summary_filename) and\ + tf.gfile.Exists(info_filename) + + + @wrappers.Request.application + def _serve_is_active(self, request): + return http_util.Respond(request, + {'is_active': self.is_active()}, + 'application/json') + + + def _fetch_current_frame(self): + path = '{}/{}'.format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) + + try: + frame = read_tensor_summary(path).astype(np.uint8) + self.most_recent_frame = frame + return frame + + except (message.DecodeError, IOError, tf.errors.NotFoundError): + return self.most_recent_frame + + + @wrappers.Request.application + def _serve_tags(self, request): + if self.is_active: + runs_and_tags = { + 'plugins/{}'.format(PLUGIN_NAME): {'tensors': [TAG_NAME]} + } + else: + runs_and_tags = {} + + return http_util.Respond(request, + runs_and_tags, + 'application/json') + + + @wrappers.Request.application + def _serve_change_config(self, request): + config = {} + + for key, value in request.form.items(): + try: + config[key] = int(value) + except ValueError: + if value == 'false': + config[key] = False + elif value == 'true': + config[key] = True + else: + config[key] = value + + self.FPS = config['FPS'] + + write_pickle(config, '{}/{}'.format(self.PLUGIN_LOGDIR, CONFIG_FILENAME)) + return http_util.Respond(request, {'config': config}, 'application/json') + + + @wrappers.Request.application + def _serve_section_info(self, request): + path = '{}/{}'.format(self.PLUGIN_LOGDIR, SECTION_INFO_FILENAME) + info = read_pickle(path, default=self.most_recent_info) + self.most_recent_info = info + return http_util.Respond(request, info, 'application/json') + + + def _frame_generator(self): + + while True: + last_duration = 0 + + if self.FPS == 0: + continue + else: + time.sleep(max(0, 1/(self.FPS) - last_duration)) + + start_time = time.time() + array = self._fetch_current_frame() + image_bytes = encode_png(array) + + frame_text = b'--frame\r\n' + content_type = b'Content-Type: image/png\r\n\r\n' + + response_content = frame_text + content_type + image_bytes + b'\r\n\r\n' + + last_duration = time.time() - start_time + yield response_content + + + @wrappers.Request.application + def _serve_beholder_frame(self, request): # pylint: disable=unused-argument + # Thanks to Miguel Grinberg for this technique: + # https://blog.miguelgrinberg.com/post/video-streaming-with-flask + mimetype = 'multipart/x-mixed-replace; boundary=frame' + return wrappers.Response(response=self._frame_generator(), + status=200, + mimetype=mimetype) + + @wrappers.Request.application + def _serve_ping(self, request): # pylint: disable=unused-argument + return http_util.Respond(request, {'status': 'alive'}, 'application/json') diff --git a/tensorboard/plugins/beholder/shared_config.py b/tensorboard/plugins/beholder/shared_config.py new file mode 100644 index 00000000000..6f4695a4549 --- /dev/null +++ b/tensorboard/plugins/beholder/shared_config.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +PLUGIN_NAME = 'beholder' +TAG_NAME = 'beholder-frame' +SUMMARY_FILENAME = 'frame.summary' +CONFIG_FILENAME = 'config.pkl' +SECTION_INFO_FILENAME = 'section-info.pkl' + +DEFAULT_CONFIG = { + 'values': 'trainable_variables', + 'mode': 'variance', + 'scaling': 'layer', + 'window_size': 15, + 'FPS': 10, + 'is_recording': False, + 'show_all': False, + 'colormap': 'magma' +} + +SECTION_HEIGHT = 128 +IMAGE_WIDTH = 512 + 256 + +TB_WHITE = 245 diff --git a/tensorboard/plugins/beholder/video_writing.py b/tensorboard/plugins/beholder/video_writing.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorboard/plugins/beholder/visualizer.py b/tensorboard/plugins/beholder/visualizer.py new file mode 100644 index 00000000000..22d9b5a7fd3 --- /dev/null +++ b/tensorboard/plugins/beholder/visualizer.py @@ -0,0 +1,293 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import deque +from math import floor, sqrt + +import numpy as np +import tensorflow as tf + +from tensorboard.plugins.beholder import im_util +from tensorboard.plugins.beholder.shared_config import SECTION_HEIGHT,\ + IMAGE_WIDTH, DEFAULT_CONFIG, SECTION_INFO_FILENAME +from tensorboard.plugins.beholder.file_system_tools import write_pickle + +MIN_SQUARE_SIZE = 3 + +class Visualizer(object): + + def __init__(self, logdir): + self.logdir = logdir + self.sections_over_time = deque([], DEFAULT_CONFIG['window_size']) + self.config = dict(DEFAULT_CONFIG) + self.old_config = dict(DEFAULT_CONFIG) + + + def _reshape_conv_array(self, array, section_height, image_width): + '''Reshape a rank 4 array to be rank 2, where each column of block_width is + a filter, and each row of block height is an input channel. For example: + + [[[[ 11, 21, 31, 41], + [ 51, 61, 71, 81], + [ 91, 101, 111, 121]], + [[ 12, 22, 32, 42], + [ 52, 62, 72, 82], + [ 92, 102, 112, 122]], + [[ 13, 23, 33, 43], + [ 53, 63, 73, 83], + [ 93, 103, 113, 123]]], + [[[ 14, 24, 34, 44], + [ 54, 64, 74, 84], + [ 94, 104, 114, 124]], + [[ 15, 25, 35, 45], + [ 55, 65, 75, 85], + [ 95, 105, 115, 125]], + [[ 16, 26, 36, 46], + [ 56, 66, 76, 86], + [ 96, 106, 116, 126]]], + [[[ 17, 27, 37, 47], + [ 57, 67, 77, 87], + [ 97, 107, 117, 127]], + [[ 18, 28, 38, 48], + [ 58, 68, 78, 88], + [ 98, 108, 118, 128]], + [[ 19, 29, 39, 49], + [ 59, 69, 79, 89], + [ 99, 109, 119, 129]]]] + + should be reshaped to: + + [[ 11, 12, 13, 21, 22, 23, 31, 32, 33, 41, 42, 43], + [ 14, 15, 16, 24, 25, 26, 34, 35, 36, 44, 45, 46], + [ 17, 18, 19, 27, 28, 29, 37, 38, 39, 47, 48, 49], + [ 51, 52, 53, 61, 62, 63, 71, 72, 73, 81, 82, 83], + [ 54, 55, 56, 64, 65, 66, 74, 75, 76, 84, 85, 86], + [ 57, 58, 59, 67, 68, 69, 77, 78, 79, 87, 88, 89], + [ 91, 92, 93, 101, 102, 103, 111, 112, 113, 121, 122, 123], + [ 94, 95, 96, 104, 105, 106, 114, 115, 116, 124, 125, 126], + [ 97, 98, 99, 107, 108, 109, 117, 118, 119, 127, 128, 129]] + ''' + + # E.g. [100, 24, 24, 10]: this shouldn't be reshaped like normal. + if array.shape[1] == array.shape[2] and array.shape[0] != array.shape[1]: + array = np.rollaxis(np.rollaxis(array, 2), 2) + + block_height, block_width, in_channels = array.shape[:3] + rows = [] + + max_element_count = section_height * int(image_width / MIN_SQUARE_SIZE) + element_count = 0 + + for i in range(in_channels): + rows.append(array[:, :, i, :].reshape(block_height, -1, order='F')) + + # This line should be left in this position. Gives it one extra row. + if element_count >= max_element_count and not self.config['show_all']: + break + + element_count += block_height * in_channels * block_width + + return np.vstack(rows) + + + def _reshape_irregular_array(self, array, section_height, image_width): + '''Reshapes arrays of ranks not in {1, 2, 4} + ''' + section_area = section_height * image_width + flattened_array = np.ravel(array) + + if not self.config['show_all']: + flattened_array = flattened_array[:int(section_area/MIN_SQUARE_SIZE)] + + cell_count = np.prod(flattened_array.shape) + cell_area = section_area / cell_count + + cell_side_length = max(1, floor(sqrt(cell_area))) + row_count = max(1, int(section_height / cell_side_length)) + col_count = int(cell_count / row_count) + + # Reshape the truncated array so that it has the same aspect ratio as + # the section. + + # Truncate whatever remaining values there are that don't fit. Hopefully + # it doesn't matter that the last few (< section count) aren't there. + section = np.reshape(flattened_array[:row_count * col_count], + (row_count, col_count)) + + return section + + + def _determine_image_width(self, arrays, show_all): + final_width = IMAGE_WIDTH + + if show_all: + for array in arrays: + rank = len(array.shape) + + if rank == 1: + width = len(array) + elif rank == 2: + width = array.shape[1] + elif rank == 4: + width = array.shape[1] * array.shape[3] + else: + width = IMAGE_WIDTH + + if width > final_width: + final_width = width + + return final_width + + + def _determine_section_height(self, array, show_all): + rank = len(array.shape) + height = SECTION_HEIGHT + + if show_all: + if rank == 1: + height = SECTION_HEIGHT + if rank == 2: + height = max(SECTION_HEIGHT, array.shape[0]) + elif rank == 4: + height = max(SECTION_HEIGHT, array.shape[0] * array.shape[2]) + else: + height = max(SECTION_HEIGHT, np.prod(array.shape) // IMAGE_WIDTH) + + return height + + + def _arrays_to_sections(self, arrays): + ''' + input: unprocessed numpy arrays. + returns: columns of the size that they will appear in the image, not scaled + for display. That needs to wait until after variance is computed. + ''' + sections = [] + sections_to_resize_later = {} + show_all = self.config['show_all'] + image_width = self._determine_image_width(arrays, show_all) + + for array_number, array in enumerate(arrays): + rank = len(array.shape) + section_height = self._determine_section_height(array, show_all) + + if rank == 1: + section = np.atleast_2d(array) + elif rank == 2: + section = array + elif rank == 4: + section = self._reshape_conv_array(array, section_height, image_width) + else: + section = self._reshape_irregular_array(array, + section_height, + image_width) + # Only calculate variance for what we have to. In some cases (biases), + # the section is larger than the array, so we don't want to calculate + # variance for the same value over and over - better to resize later. + # About a 6-7x speedup for a big network with a big variance window. + section_size = section_height * image_width + array_size = np.prod(array.shape) + + if section_size > array_size: + sections.append(section) + sections_to_resize_later[array_number] = section_height + else: + sections.append(im_util.resize(section, section_height, image_width)) + + self.sections_over_time.append(sections) + + if self.config['mode'] == 'variance': + sections = self._sections_to_variance_sections(self.sections_over_time) + + for array_number, height in sections_to_resize_later.items(): + sections[array_number] = im_util.resize(sections[array_number], + height, + image_width) + return sections + + + def _sections_to_variance_sections(self, sections_over_time): + '''Computes the variance of corresponding sections over time. + + Returns: + a list of np arrays. + ''' + variance_sections = [] + + for i in range(len(sections_over_time[0])): + time_sections = [sections[i] for sections in sections_over_time] + variance = np.var(time_sections, axis=0) + variance_sections.append(variance) + + return variance_sections + + + def _sections_to_image(self, sections): + padding_size = 5 + + sections = im_util.scale_sections(sections, self.config['scaling']) + + final_stack = [sections[0]] + padding = np.zeros((padding_size, sections[0].shape[1])) + + for section in sections[1:]: + final_stack.append(padding) + final_stack.append(section) + + return np.vstack(final_stack).astype(np.uint8) + + + def _maybe_clear_deque(self): + '''Clears the deque if certain parts of the config have changed.''' + + for config_item in ['values', 'mode', 'show_all']: + if self.config[config_item] != self.old_config[config_item]: + self.sections_over_time.clear() + break + + self.old_config = self.config + + window_size = self.config['window_size'] + if window_size != self.sections_over_time.maxlen: + self.sections_over_time = deque(self.sections_over_time, window_size) + + + def _save_section_info(self, arrays, sections): + infos = [] + + if self.config['values'] == 'trainable_variables': + names = [x.name for x in tf.trainable_variables()] + else: + names = range(len(arrays)) + + for array, section, name in zip(arrays, sections, names): + info = {} + + info['name'] = name + info['shape'] = str(array.shape) + info['min'] = '{:.3e}'.format(section.min()) + info['mean'] = '{:.3e}'.format(section.mean()) + info['max'] = '{:.3e}'.format(section.max()) + info['range'] = '{:.3e}'.format(section.max() - section.min()) + info['height'] = section.shape[0] + + infos.append(info) + + write_pickle(infos, '{}/{}'.format(self.logdir, SECTION_INFO_FILENAME)) + + + def build_frame(self, arrays): + self._maybe_clear_deque() + + arrays = arrays if isinstance(arrays, list) else [arrays] + + sections = self._arrays_to_sections(arrays) + self._save_section_info(arrays, sections) + final_image = self._sections_to_image(sections) + final_image = im_util.apply_colormap(final_image, self.config['colormap']) + + return final_image + + def update(self, config): + self.config = config