From d08fe6d45a537d6d57d704c37d33ae01359a029a Mon Sep 17 00:00:00 2001 From: Virendra Kumar Pathak Date: Sun, 14 May 2023 01:24:31 -0700 Subject: [PATCH] Softly deprecate the get_str=False flag. Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller. Reviewed By: shapovalov Differential Revision: D45356240 fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145 --- .../implicitron_trainer/impl/training_loop.py | 4 +- pytorch3d/implicitron/tools/stats.py | 88 ++++++++++++------- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/projects/implicitron_trainer/impl/training_loop.py b/projects/implicitron_trainer/impl/training_loop.py index 6c4ae1b80..3eeec4c1c 100644 --- a/projects/implicitron_trainer/impl/training_loop.py +++ b/projects/implicitron_trainer/impl/training_loop.py @@ -256,7 +256,6 @@ def load_stats( list(log_vars), plot_file=os.path.join(exp_dir, "train_stats.pdf"), visdom_env=visdom_env_charts, - verbose=False, visdom_server=self.visdom_server, visdom_port=self.visdom_port, ) @@ -382,7 +381,8 @@ def _training_or_validation_epoch( # print textual status update if it % self.metric_print_interval == 0 or last_iter: - stats.print(stat_set=trainmode, max_it=n_batches) + std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches) + logger.info(std_out) # visualize results if ( diff --git a/pytorch3d/implicitron/tools/stats.py b/pytorch3d/implicitron/tools/stats.py index a2826a093..32e7d703d 100644 --- a/pytorch3d/implicitron/tools/stats.py +++ b/pytorch3d/implicitron/tools/stats.py @@ -6,6 +6,7 @@ import gzip import json +import logging import time import warnings from collections.abc import Iterable @@ -17,6 +18,8 @@ from matplotlib import colors as mcolors from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection +logger = logging.getLogger(__name__) + class AverageMeter(object): """Computes and stores the average and current value""" @@ -91,7 +94,9 @@ class Stats(object): # stats.update() automatically parses the 'objective' and 'top1e' from # the "output" dict and stores this into the db stats.update(output) - stats.print() # prints the averages over given epoch + # prints the metric averages over given epoch + std_out = stats.get_status_string() + logger.info(str_out) # stores the training plots into '/tmp/epoch_stats.pdf' # and plots into a visdom server running at localhost (if running) stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') @@ -101,7 +106,6 @@ class Stats(object): def __init__( self, log_vars, - verbose=False, epoch=-1, visdom_env="main", do_plot=True, @@ -110,7 +114,6 @@ def __init__( visdom_port=8097, ): - self.verbose = verbose self.log_vars = log_vars self.visdom_env = visdom_env self.visdom_server = visdom_server @@ -156,15 +159,14 @@ def __exit__(self, type, value, traceback): iserr = type is not None and issubclass(type, Exception) iserr = iserr or (type is KeyboardInterrupt) if iserr: - print("error inside 'with' block") + logger.error("error inside 'with' block") return if self.do_plot: self.plot_stats(self.visdom_env) def reset(self): # to be called after each epoch stat_sets = list(self.stats.keys()) - if self.verbose: - print("stats: epoch %d - reset" % self.epoch) + logger.debug(f"stats: epoch {self.epoch} - reset") self.it = {k: -1 for k in stat_sets} for stat_set in stat_sets: for stat in self.stats[stat_set]: @@ -172,16 +174,14 @@ def reset(self): # to be called after each epoch def hard_reset(self, epoch=-1): # to be called during object __init__ self.epoch = epoch - if self.verbose: - print("stats: epoch %d - hard reset" % self.epoch) + logger.debug(f"stats: epoch {self.epoch} - hard reset") self.stats = {} # reset self.reset() def new_epoch(self): - if self.verbose: - print("stats: new epoch %d" % (self.epoch + 1)) + logger.debug(f"stats: new epoch {(self.epoch + 1)}") self.epoch += 1 self.reset() # zero the stats + increase epoch counter @@ -193,18 +193,17 @@ def gather_value(self, val): val = float(val.sum()) return val - def add_log_vars(self, added_log_vars, verbose=True): + def add_log_vars(self, added_log_vars): for add_log_var in added_log_vars: if add_log_var not in self.stats: - if verbose: - print(f"Adding {add_log_var}") + logger.debug(f"Adding {add_log_var}") self.log_vars.append(add_log_var) def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"): if self.epoch == -1: # uninitialized - print( - "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called" + logger.warning( + "epoch==-1 means uninitialized stats structure -> new_epoch() called" ) self.new_epoch() @@ -284,6 +283,12 @@ def print( skip_nan=False, stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"), ): + """ + stats.print() is deprecated. Please use get_status_string() instead. + example: + std_out = stats.get_status_string() + logger.info(str_out) + """ epoch = self.epoch stats = self.stats @@ -311,8 +316,30 @@ def print( if get_str: return str_out else: + warnings.warn( + "get_str=False is deprecated." + "Please enable this flag to get receive the output string.", + DeprecationWarning, + ) print(str_out) + def get_status_string( + self, + max_it=None, + stat_set="train", + vars_print=None, + skip_nan=False, + stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"), + ): + return self.print( + max_it=max_it, + stat_set=stat_set, + vars_print=vars_print, + get_str=True, + skip_nan=skip_nan, + stat_format=stat_format, + ) + def plot_stats( self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None ): @@ -329,16 +356,15 @@ def plot_stats( stat_sets = list(self.stats.keys()) - print( - "printing charts to visdom env '%s' (%s:%d)" - % (visdom_env, visdom_server, visdom_port) + logger.debug( + f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})" ) novisdom = False viz = get_visdom_connection(server=visdom_server, port=visdom_port) if viz is None or not viz.check_connection(): - print("no visdom server! -> skipping visdom plots") + logger.info("no visdom server! -> skipping visdom plots") novisdom = True lines = [] @@ -385,7 +411,7 @@ def plot_stats( ) if plot_file: - print("exporting stats to %s" % plot_file) + logger.info(f"plotting stats to {plot_file}") ncol = 3 nrow = int(np.ceil(float(len(lines)) / ncol)) matplotlib.rcParams.update({"font.size": 5}) @@ -423,7 +449,7 @@ def plot_stats( except PermissionError: warnings.warn("Cant dump stats due to insufficient permissions!") - def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True): + def synchronize_logged_vars(self, log_vars, default_val=float("NaN")): stat_sets = list(self.stats.keys()) @@ -431,7 +457,7 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr for stat_set in stat_sets: for stat in self.stats[stat_set].keys(): if stat not in log_vars: - print("additional stat %s:%s -> removing" % (stat_set, stat)) + logger.warning(f"additional stat {stat_set}:{stat} -> removing") self.stats[stat_set] = { stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars @@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr for stat_set in stat_sets: for stat in log_vars: if stat not in self.stats[stat_set]: - if verbose: - print( - "missing stat %s:%s -> filling with default values (%1.2f)" - % (stat_set, stat, default_val) - ) + logger.info( + "missing stat %s:%s -> filling with default values (%1.2f)" + % (stat_set, stat, default_val) + ) elif len(self.stats[stat_set][stat].history) != self.epoch + 1: h = self.stats[stat_set][stat].history if len(h) == 0: # just never updated stat ... skip continue else: - if verbose: - print( - "incomplete stat %s:%s -> reseting with default values (%1.2f)" - % (stat_set, stat, default_val) - ) + logger.info( + "incomplete stat %s:%s -> reseting with default values (%1.2f)" + % (stat_set, stat, default_val) + ) else: continue