From eb803f87023f7f7f64075a4d30f82538c889dd8c Mon Sep 17 00:00:00 2001 From: Austin Welch Date: Mon, 26 Dec 2022 04:07:22 -0500 Subject: [PATCH] [Enhance] Add progress argument in load_from_http (#770) --- mmengine/runner/checkpoint.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 3a786cfa21..91b471b368 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -320,7 +320,10 @@ def load_from_local(filename, map_location): @CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) -def load_from_http(filename, map_location=None, model_dir=None): +def load_from_http(filename, + map_location=None, + model_dir=None, + progress=os.isatty(0)): """load checkpoint through HTTP or HTTPS scheme path. In distributed setting, this function only download checkpoint at local rank 0. @@ -337,12 +340,18 @@ def load_from_http(filename, map_location=None, model_dir=None): rank, world_size = get_dist_info() if rank == 0: checkpoint = load_url( - filename, model_dir=model_dir, map_location=map_location) + filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) if world_size > 1: torch.distributed.barrier() if rank > 0: checkpoint = load_url( - filename, model_dir=model_dir, map_location=map_location) + filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) return checkpoint