From 67cc75f50b133f7814368b47e92c48c43ac729ca Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 14:42:08 -0700 Subject: [PATCH 1/5] Update DataTree repr to indicate inheritance Fixes https://github.com/pydata/xarray/issues/9463 --- xarray/core/formatting.py | 118 +++++++++++++++++++++++++++------- xarray/tests/test_datatree.py | 10 +-- 2 files changed, 96 insertions(+), 32 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 657c9a2dbfb..e4b9d928d5b 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -6,7 +6,7 @@ import contextlib import functools import math -from collections import defaultdict +from collections import ChainMap, defaultdict from collections.abc import Collection, Hashable, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest @@ -29,6 +29,7 @@ if TYPE_CHECKING: from xarray.core.coordinates import AbstractCoordinates from xarray.core.datatree import DataTree + from xarray.core.variable import Variable UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( name: Hashable, - var, + var: Variable, col_width: int, max_width: int | None = None, is_index: bool = False, @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): ) +def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None): + coords = _inherited_vars(node._coord_variables) + if col_width is None: + col_width = _calculate_col_width(coords) + return _mapping_repr( + coords, + title="Inherited coordinates", + summarizer=summarize_variable, + expand_option_name="display_expand_coords", + col_width=col_width, + indexes=node._indexes, + max_rows=max_rows, + ) + + def inline_index_repr(index: pd.Index, max_width=None): if hasattr(index, "_repr_inline_"): repr_ = index._repr_inline_(max_width=max_width) @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool): } -def indexes_repr(indexes, max_rows: int | None = None) -> str: +def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str: col_width = _calculate_col_width(chain.from_iterable(indexes)) return _mapping_repr( indexes, - "Indexes", + title, summarize_index, "display_expand_indexes", col_width=col_width, @@ -1048,19 +1064,71 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): return "\n".join(summary) -def _single_node_repr(node: DataTree) -> str: - """Information about this node, not including its relationships to other nodes.""" - if node.has_data or node.has_attrs: - # TODO: change this to inherited=False, in order to clarify what is - # inherited? https://github.com/pydata/xarray/issues/9463 - node_view = node._to_dataset_view(rebuild_dims=False, inherited=True) - ds_info = "\n" + repr(node_view) - else: - ds_info = "" - return f"Group: {node.path}{ds_info}" +def _inherited_vars(mapping: ChainMap) -> dict: + return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]} + + +def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: + summary = [f"Group: {node.path}"] + + col_width = _calculate_col_width(node.variables) + max_rows = OPTIONS["display_max_rows"] + + inherited_coords = _inherited_vars(node._coord_variables) + # Only show dimensions if also showing a variable or coordinates section. + show_dims = ( + node._node_coord_variables + or (show_inherited and inherited_coords) + or node._data_variables + ) + + if show_dims: + # Includes inherited dimensions. + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited( + node, col_width=col_width + 1, max_rows=max_rows + ) + summary.append(f"{dims_start}({dims_values})") -def datatree_repr(dt: DataTree): + if node._node_coord_variables: + summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows)) + + if show_inherited and inherited_coords: + summary.append( + inherited_coords_repr(node, col_width=col_width, max_rows=max_rows) + ) + + if show_dims: + unindexed_dims_str = unindexed_dims_repr( + node.dims, node.coords, max_rows=max_rows + ) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + if node._data_variables: + summary.append( + data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows) + ) + + # TODO: only show indexes defined at this node, with a separate section for + # inherited indexes (if show_inherited=True) + display_default_indexes = _get_boolean_with_default( + "display_default_indexes", False + ) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(node.xindexes), not display_default_indexes + ) + if xindexes: + summary.append(indexes_repr(xindexes, max_rows=max_rows)) + + if node.attrs: + summary.append(attrs_repr(node.attrs, max_rows=max_rows)) + + return "\n".join(summary) + + +def datatree_repr(dt: DataTree) -> str: """A printable representation of the structure of this entire tree.""" renderer = RenderDataTree(dt) @@ -1068,19 +1136,21 @@ def datatree_repr(dt: DataTree): header = f"" lines = [header] + show_inherited = True for pre, fill, node in renderer: - node_repr = _single_node_repr(node) + node_repr = _datatree_node_repr(node, show_inherited=show_inherited) + show_inherited = False # only show inherited coords on the root - node_line = f"{pre}{node_repr.splitlines()[0]}" + raw_repr_lines = node_repr.splitlines() + + node_line = f"{pre}{raw_repr_lines[0]}" lines.append(node_line) - if node.has_data or node.has_attrs: - ds_repr = node_repr.splitlines()[2:] - for line in ds_repr: - if len(node.children) > 0: - lines.append(f"{fill}{renderer.style.vertical}{line}") - else: - lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + for line in raw_repr_lines[1:]: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") return "\n".join(lines) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f1f74d240f0..cbdbd541fb0 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -798,20 +798,16 @@ def test_repr(self): │ e (x) float64 16B 1.0 2.0 └── Group: /b │ Dimensions: (x: 2, y: 1) - │ Coordinates: - │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 Dimensions without coordinates: y Data variables: g float64 8B 4.0 - """ + """ ).strip() assert result == expected @@ -821,7 +817,7 @@ def test_repr(self): Group: /b │ Dimensions: (x: 2, y: 1) - │ Coordinates: + │ Inherited coordinates: │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: @@ -829,8 +825,6 @@ def test_repr(self): ├── Group: /b/c └── Group: /b/d Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 Dimensions without coordinates: y Data variables: g float64 8B 4.0 From 3de69f73e75f7332ee7473c7a85f3d1b26c0e5a4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 15:02:52 -0700 Subject: [PATCH 2/5] fix whitespace --- xarray/tests/test_datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index cbdbd541fb0..83c0e728970 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -807,7 +807,7 @@ def test_repr(self): Dimensions without coordinates: y Data variables: g float64 8B 4.0 - """ + """ ).strip() assert result == expected From 83f524b4041651192373a9172a484b02bfbec84e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 17:15:05 -0700 Subject: [PATCH 3/5] add more repr tests, fix failure --- xarray/tests/test_datatree.py | 52 +++++++++++++++++++++++++++++++++ xarray/tests/test_formatting.py | 3 -- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 83c0e728970..dfd4b493344 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -832,6 +832,58 @@ def test_repr(self): ).strip() assert result == expected + def test_repr2(self): + tree = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1]}), + "/first_child": None, + "/second_child": Dataset({"foo": ("x", [0])}), + } + ) + + result = repr(tree) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 1) + │ Coordinates: + │ * x (x) int64 8B 1 + ├── Group: /first_child + └── Group: /second_child + Dimensions: (x: 1) + Data variables: + foo (x) int64 8B 0 + """ + ).strip() + assert result == expected + + result = repr(tree["first_child"]) + expected = dedent( + """ + + Group: /first_child + Dimensions: (x: 1) + Inherited coordinates: + * x (x) int64 8B 1 + """ + ).strip() + assert result == expected + + result = repr(tree["second_child"]) + expected = dedent( + """ + + Group: /second_child + Dimensions: (x: 1) + Inherited coordinates: + * x (x) int64 8B 1 + Data variables: + foo (x) int64 8B 0 + """ + ).strip() + assert result == expected + def _exact_match(message: str) -> str: return re.escape(dedent(message).strip()) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index e7076151314..696c849cea1 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -632,9 +632,6 @@ def test_datatree_print_empty_node_with_attrs(self): """\ Group: / - Dimensions: () - Data variables: - *empty* Attributes: note: has attrs""" ) From 3b42219ce7209de88d95261ffc4e3b6825ca84aa Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 18:13:07 -0700 Subject: [PATCH 4/5] fix failure on windows --- asv_bench/benchmarks/datatree.py | 2 +- xarray/tests/test_datatree.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py index 13eedd0a518..9f1774f60ac 100644 --- a/asv_bench/benchmarks/datatree.py +++ b/asv_bench/benchmarks/datatree.py @@ -6,7 +6,7 @@ class Datatree: def setup(self): run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) self.d_few = {"run1": run1} - self.d_many = {f"run{i}": run1.copy() for i in range(100)} + self.d_many = {f"run{i}": xr.Dataset({"a": 1}) for i in range(100)} def time_from_dict_few(self): DataTree.from_dict(self.d_few) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index dfd4b493344..ba3041f271f 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -835,9 +835,9 @@ def test_repr(self): def test_repr2(self): tree = DataTree.from_dict( { - "/": Dataset(coords={"x": [1]}), + "/": Dataset(coords={"x": [1.0]}), "/first_child": None, - "/second_child": Dataset({"foo": ("x", [0])}), + "/second_child": Dataset({"foo": ("x", [0.0])}), } ) @@ -848,12 +848,12 @@ def test_repr2(self): Group: / │ Dimensions: (x: 1) │ Coordinates: - │ * x (x) int64 8B 1 + │ * x (x) float64 8B 1.0 ├── Group: /first_child └── Group: /second_child Dimensions: (x: 1) Data variables: - foo (x) int64 8B 0 + foo (x) float64 8B 0.0 """ ).strip() assert result == expected @@ -865,7 +865,7 @@ def test_repr2(self): Group: /first_child Dimensions: (x: 1) Inherited coordinates: - * x (x) int64 8B 1 + * x (x) float64 8B 1.0 """ ).strip() assert result == expected @@ -877,9 +877,9 @@ def test_repr2(self): Group: /second_child Dimensions: (x: 1) Inherited coordinates: - * x (x) int64 8B 1 + * x (x) float64 8B 1.0 Data variables: - foo (x) int64 8B 0 + foo (x) float64 8B 0.0 """ ).strip() assert result == expected From 70e661dc3d3ff952bc4a1f6453a7e4f7067dbf12 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 18:34:40 -0700 Subject: [PATCH 5/5] fix repr for inherited dimensions --- xarray/core/formatting.py | 26 ++++++++----- xarray/tests/test_datatree.py | 67 +++++++++++++++++++++++++++++---- xarray/tests/test_formatting.py | 2 +- 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index e4b9d928d5b..3f42d4828a3 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -7,12 +7,12 @@ import functools import math from collections import ChainMap, defaultdict -from collections.abc import Collection, Hashable, Sequence +from collections.abc import Collection, Hashable, Mapping, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr from textwrap import dedent -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -587,8 +587,10 @@ def _element_formatter( return "".join(out) -def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str: - elements = [f"{k}: {v}" for k, v in obj.sizes.items()] +def dim_summary_limited( + sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None +) -> str: + elements = [f"{k}: {v}" for k, v in sizes.items()] return _element_formatter(elements, col_width, max_rows) @@ -692,7 +694,7 @@ def array_repr(arr): data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"]) start = f" Size: {nbytes_str}", @@ -737,7 +739,9 @@ def dataset_repr(ds): max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -772,7 +776,9 @@ def dims_and_coords_repr(ds) -> str: max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -1083,11 +1089,13 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: or node._data_variables ) + dim_sizes = node.sizes if show_inherited else node._node_dims + if show_dims: # Includes inherited dimensions. dims_start = pretty_print("Dimensions:", col_width) dims_values = dim_summary_limited( - node, col_width=col_width + 1, max_rows=max_rows + dim_sizes, col_width=col_width + 1, max_rows=max_rows ) summary.append(f"{dims_start}({dims_values})") @@ -1101,7 +1109,7 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: if show_dims: unindexed_dims_str = unindexed_dims_repr( - node.dims, node.coords, max_rows=max_rows + dim_sizes, node.coords, max_rows=max_rows ) if unindexed_dims_str: summary.append(unindexed_dims_str) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index ba3041f271f..ac074b90d62 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -773,7 +773,8 @@ def test_operation_with_attrs_but_no_data(self): class TestRepr: - def test_repr(self): + + def test_repr_four_nodes(self): dt = DataTree.from_dict( { "/": xr.Dataset( @@ -797,14 +798,13 @@ def test_repr(self): │ Data variables: │ e (x) float64 16B 1.0 2.0 └── Group: /b - │ Dimensions: (x: 2, y: 1) + │ Dimensions: (y: 1) │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: (x: 2, y: 1) - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ @@ -824,15 +824,29 @@ def test_repr(self): │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: (x: 2, y: 1) - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ ).strip() assert result == expected - def test_repr2(self): + result = repr(dt.b.d) + expected = dedent( + """ + + Group: /b/d + Dimensions: (x: 2, y: 1) + Inherited coordinates: + * x (x) float64 16B 2.0 3.0 + Dimensions without coordinates: y + Data variables: + g float64 8B 4.0 + """ + ).strip() + assert result == expected + + def test_repr_two_children(self): tree = DataTree.from_dict( { "/": Dataset(coords={"x": [1.0]}), @@ -884,6 +898,45 @@ def test_repr2(self): ).strip() assert result == expected + def test_repr_inherited_dims(self): + tree = DataTree.from_dict( + { + "/": Dataset({"foo": ("x", [1.0])}), + "/child": Dataset({"bar": ("y", [2.0])}), + } + ) + + result = repr(tree) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 1) + │ Dimensions without coordinates: x + │ Data variables: + │ foo (x) float64 8B 1.0 + └── Group: /child + Dimensions: (y: 1) + Dimensions without coordinates: y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + + result = repr(tree["child"]) + expected = dedent( + """ + + Group: /child + Dimensions: (x: 1, y: 1) + Dimensions without coordinates: x, y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + def _exact_match(message: str) -> str: return re.escape(dedent(message).strip()) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 696c849cea1..039bbfb4606 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -883,7 +883,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( - ds, col_width=col_width + 1, max_rows=display_max_rows + ds.sizes, col_width=col_width + 1, max_rows=display_max_rows ) expected_size = "1kB" expected = f"""\