Skip to content

Commit

Permalink
Loosening the checks in eval script for CO3Dv2 style eval
Browse files Browse the repository at this point in the history
Summary:
V2 dataset does not have the concept of known/unseen frames. Test-time conditining is done with train-set frames, which violates the previous check.

Also fixing a corner case in VideoWriter.

Reviewed By: bottler

Differential Revision: D42706976

fbshipit-source-id: d43be3dd3060d18cb9f46d5dcf6252d9f084110f
  • Loading branch information
shapovalov authored and facebook-github-bot committed Jan 26, 2023
1 parent 9dc28f5 commit 54eb76d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
15 changes: 4 additions & 11 deletions pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,10 @@ def eval_batch(
frame_type = [frame_type]

is_train = is_train_frame(frame_type)
if not (is_train[0] == is_train).all():
raise ValueError("All frames in the eval batch have to be either train/test.")

# pyre-fixme[16]: `Optional` has no attribute `device`.
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)

if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
if len(is_train) > 1 and (is_train[1] != is_train[1:]).any():
raise ValueError(
"For evaluation the first element of the batch has to be"
+ " a target view while the rest should be source views."
) # TODO: do we need to enforce this?
"All (conditioning) frames in the eval batch have to be either train/test."
)

for k in [
"depth_map",
Expand Down Expand Up @@ -362,7 +355,7 @@ def eval_batch(

results["meta"] = {
# store the size of the batch (corresponds to n_src_views+1)
"batch_size": int(is_known.numel()),
"batch_size": len(frame_type),
# store the type of the target frame
# pyre-fixme[16]: `None` has no attribute `__getitem__`.
"frame_type": str(frame_data.frame_type[0]),
Expand Down
5 changes: 4 additions & 1 deletion pytorch3d/implicitron/tools/video_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ def get_video(self, quiet: bool = True) -> str:
quiet: If `True`, suppresses logging messages.
Returns:
video_path: The path to the generated video.
video_path: The path to the generated video if any frames were added.
Otherwise returns an empty string.
"""
if self.frame_num == 0:
return ""

regexp = os.path.join(self.cache_dir, self.regexp)

Expand Down

0 comments on commit 54eb76d

Please sign in to comment.