Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DataTree repr to indicate inheritance #9470

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 94 additions & 24 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import contextlib
import functools
import math
from collections import defaultdict
from collections import ChainMap, defaultdict
from collections.abc import Collection, Hashable, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
Expand All @@ -29,6 +29,7 @@
if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree
from xarray.core.variable import Variable

UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

Expand Down Expand Up @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width):

def summarize_variable(
name: Hashable,
var,
var: Variable,
col_width: int,
max_width: int | None = None,
is_index: bool = False,
Expand Down Expand Up @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
)


def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None):
coords = _inherited_vars(node._coord_variables)
if col_width is None:
col_width = _calculate_col_width(coords)
return _mapping_repr(
coords,
title="Inherited coordinates",
summarizer=summarize_variable,
expand_option_name="display_expand_coords",
col_width=col_width,
indexes=node._indexes,
max_rows=max_rows,
)


def inline_index_repr(index: pd.Index, max_width=None):
if hasattr(index, "_repr_inline_"):
repr_ = index._repr_inline_(max_width=max_width)
Expand Down Expand Up @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool):
}


def indexes_repr(indexes, max_rows: int | None = None) -> str:
def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str:
col_width = _calculate_col_width(chain.from_iterable(indexes))

return _mapping_repr(
indexes,
"Indexes",
title,
summarize_index,
"display_expand_indexes",
col_width=col_width,
Expand Down Expand Up @@ -1048,39 +1064,93 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
return "\n".join(summary)


def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
# TODO: change this to inherited=False, in order to clarify what is
# inherited? https://github.com/pydata/xarray/issues/9463
node_view = node._to_dataset_view(rebuild_dims=False, inherited=True)
ds_info = "\n" + repr(node_view)
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
def _inherited_vars(mapping: ChainMap) -> dict:
return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]}
Comment on lines +1073 to +1074
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is anywhere else that this function might be useful.



def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
summary = [f"Group: {node.path}"]

col_width = _calculate_col_width(node.variables)
max_rows = OPTIONS["display_max_rows"]

inherited_coords = _inherited_vars(node._coord_variables)

# Only show dimensions if also showing a variable or coordinates section.
show_dims = (
node._node_coord_variables
or (show_inherited and inherited_coords)
or node._data_variables
)

if show_dims:
# Includes inherited dimensions.
dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(
node, col_width=col_width + 1, max_rows=max_rows
shoyer marked this conversation as resolved.
Show resolved Hide resolved
)
summary.append(f"{dims_start}({dims_values})")

def datatree_repr(dt: DataTree):
if node._node_coord_variables:
summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows))

if show_inherited and inherited_coords:
summary.append(
inherited_coords_repr(node, col_width=col_width, max_rows=max_rows)
)

if show_dims:
unindexed_dims_str = unindexed_dims_repr(
node.dims, node.coords, max_rows=max_rows
)
if unindexed_dims_str:
summary.append(unindexed_dims_str)

if node._data_variables:
summary.append(
data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows)
)

# TODO: only show indexes defined at this node, with a separate section for
# inherited indexes (if show_inherited=True)
display_default_indexes = _get_boolean_with_default(
"display_default_indexes", False
)
xindexes = filter_nondefault_indexes(
_get_indexes_dict(node.xindexes), not display_default_indexes
)
if xindexes:
summary.append(indexes_repr(xindexes, max_rows=max_rows))

if node.attrs:
summary.append(attrs_repr(node.attrs, max_rows=max_rows))

return "\n".join(summary)


def datatree_repr(dt: DataTree) -> str:
"""A printable representation of the structure of this entire tree."""
renderer = RenderDataTree(dt)

name_info = "" if dt.name is None else f" {dt.name!r}"
header = f"<xarray.DataTree{name_info}>"

lines = [header]
show_inherited = True
for pre, fill, node in renderer:
node_repr = _single_node_repr(node)
node_repr = _datatree_node_repr(node, show_inherited=show_inherited)
show_inherited = False # only show inherited coords on the root

node_line = f"{pre}{node_repr.splitlines()[0]}"
raw_repr_lines = node_repr.splitlines()

node_line = f"{pre}{raw_repr_lines[0]}"
lines.append(node_line)

if node.has_data or node.has_attrs:
ds_repr = node_repr.splitlines()[2:]
for line in ds_repr:
if len(node.children) > 0:
lines.append(f"{fill}{renderer.style.vertical}{line}")
else:
lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")
for line in raw_repr_lines[1:]:
if len(node.children) > 0:
lines.append(f"{fill}{renderer.style.vertical}{line}")
else:
lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")

return "\n".join(lines)

Expand Down
10 changes: 2 additions & 8 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,20 +798,16 @@ def test_repr(self):
│ e (x) float64 16B 1.0 2.0
└── Group: /b
│ Dimensions: (x: 2, y: 1)
│ Coordinates:
│ * x (x) float64 16B 2.0 3.0
│ Dimensions without coordinates: y
│ Data variables:
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Coordinates:
* x (x) float64 16B 2.0 3.0
Dimensions without coordinates: y
Data variables:
g float64 8B 4.0
"""
"""
).strip()
assert result == expected

Expand All @@ -821,16 +817,14 @@ def test_repr(self):
<xarray.DataTree 'b'>
Group: /b
│ Dimensions: (x: 2, y: 1)
Coordinates:
Inherited coordinates:
│ * x (x) float64 16B 2.0 3.0
│ Dimensions without coordinates: y
│ Data variables:
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Coordinates:
* x (x) float64 16B 2.0 3.0
Dimensions without coordinates: y
Data variables:
g float64 8B 4.0
Expand Down
Loading