Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DataTree repr to indicate inheritance #9470

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Datatree:
def setup(self):
    """Create the dict fixtures stored on the instance (``d_few``, ``d_many``)."""
    run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
    self.d_few = {"run1": run1}
    # Each key gets its own freshly built Dataset. The earlier
    # ``run1.copy()`` assignment to ``d_many`` was overwritten right away
    # by this one, so only this assignment is kept.
    self.d_many = {f"run{i}": xr.Dataset({"a": 1}) for i in range(100)}

def time_from_dict_few(self):
DataTree.from_dict(self.d_few)
Expand Down
140 changes: 109 additions & 31 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import contextlib
import functools
import math
from collections import defaultdict
from collections.abc import Collection, Hashable, Sequence
from collections import ChainMap, defaultdict
from collections.abc import Collection, Hashable, Mapping, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
from textwrap import dedent
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
Expand All @@ -29,6 +29,7 @@
if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree
from xarray.core.variable import Variable

UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

Expand Down Expand Up @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width):

def summarize_variable(
name: Hashable,
var,
var: Variable,
col_width: int,
max_width: int | None = None,
is_index: bool = False,
Expand Down Expand Up @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
)


def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None):
    """Render the "Inherited coordinates" section for ``node``.

    The section lists the entries of ``node._coord_variables`` selected by
    ``_inherited_vars``. When ``col_width`` is ``None`` the width is
    computed from those same entries.
    """
    inherited = _inherited_vars(node._coord_variables)
    width = _calculate_col_width(inherited) if col_width is None else col_width
    return _mapping_repr(
        inherited,
        title="Inherited coordinates",
        summarizer=summarize_variable,
        expand_option_name="display_expand_coords",
        col_width=width,
        indexes=node._indexes,
        max_rows=max_rows,
    )


def inline_index_repr(index: pd.Index, max_width=None):
if hasattr(index, "_repr_inline_"):
repr_ = index._repr_inline_(max_width=max_width)
Expand Down Expand Up @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool):
}


def indexes_repr(indexes, max_rows: int | None = None) -> str:
def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str:
col_width = _calculate_col_width(chain.from_iterable(indexes))

return _mapping_repr(
indexes,
"Indexes",
title,
summarize_index,
"display_expand_indexes",
col_width=col_width,
Expand Down Expand Up @@ -571,8 +587,10 @@ def _element_formatter(
return "".join(out)


def dim_summary_limited(
    sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None
) -> str:
    """Format dimension sizes as ``name: size`` entries.

    Parameters
    ----------
    sizes : Mapping
        Maps each dimension name to its length. Callers pass a ``.sizes``
        mapping (or equivalent) rather than the owning object.
    col_width : int
        Forwarded unchanged to ``_element_formatter``.
    max_rows : int, optional
        Forwarded unchanged to ``_element_formatter``.

    Returns
    -------
    str
        The string produced by ``_element_formatter`` for the entries.
    """
    elements = [f"{name}: {size}" for name, size in sizes.items()]
    return _element_formatter(elements, col_width, max_rows)


Expand Down Expand Up @@ -676,7 +694,7 @@ def array_repr(arr):
data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"])

start = f"<xarray.{type(arr).__name__} {name_str}"
dims = dim_summary_limited(arr, col_width=len(start) + 1, max_rows=max_rows)
dims = dim_summary_limited(arr.sizes, col_width=len(start) + 1, max_rows=max_rows)
nbytes_str = render_human_readable_nbytes(arr.nbytes)
summary = [
f"{start}({dims})> Size: {nbytes_str}",
Expand Down Expand Up @@ -721,7 +739,9 @@ def dataset_repr(ds):
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -756,7 +776,9 @@ def dims_and_coords_repr(ds) -> str:
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -1048,39 +1070,95 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
return "\n".join(summary)


def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
# TODO: change this to inherited=False, in order to clarify what is
# inherited? https://github.com/pydata/xarray/issues/9463
node_view = node._to_dataset_view(rebuild_dims=False, inherited=True)
ds_info = "\n" + repr(node_view)
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
def _inherited_vars(mapping: ChainMap) -> dict:
return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]}
Comment on lines +1073 to +1074
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is anywhere else that this function might be useful.



def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
    """Build the repr text for a single tree node.

    Parameters
    ----------
    node : DataTree
        Node to describe.
    show_inherited : bool
        If True, the dimension summary uses ``node.sizes`` and an
        "Inherited coordinates" section is added when
        ``_inherited_vars(node._coord_variables)`` is non-empty. If False,
        the summary uses ``node._node_dims`` and that section is omitted.

    Returns
    -------
    str
        Newline-joined sections; the first line is ``Group: <path>``.
    """
    summary = [f"Group: {node.path}"]

    col_width = _calculate_col_width(node.variables)
    max_rows = OPTIONS["display_max_rows"]

    inherited_coords = _inherited_vars(node._coord_variables)

    # Only show dimensions if also showing a variable or coordinates section.
    show_dims = (
        node._node_coord_variables
        or (show_inherited and inherited_coords)
        or node._data_variables
    )

    dim_sizes = node.sizes if show_inherited else node._node_dims

    if show_dims:
        # Includes inherited dimensions.
        dims_start = pretty_print("Dimensions:", col_width)
        dims_values = dim_summary_limited(
            dim_sizes, col_width=col_width + 1, max_rows=max_rows
        )
        summary.append(f"{dims_start}({dims_values})")

    if node._node_coord_variables:
        summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows))

    if show_inherited and inherited_coords:
        summary.append(
            inherited_coords_repr(node, col_width=col_width, max_rows=max_rows)
        )

    if show_dims:
        unindexed_dims_str = unindexed_dims_repr(
            dim_sizes, node.coords, max_rows=max_rows
        )
        if unindexed_dims_str:
            summary.append(unindexed_dims_str)

    if node._data_variables:
        summary.append(
            data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows)
        )

    # TODO: only show indexes defined at this node, with a separate section for
    # inherited indexes (if show_inherited=True)
    display_default_indexes = _get_boolean_with_default(
        "display_default_indexes", False
    )
    xindexes = filter_nondefault_indexes(
        _get_indexes_dict(node.xindexes), not display_default_indexes
    )
    if xindexes:
        summary.append(indexes_repr(xindexes, max_rows=max_rows))

    if node.attrs:
        summary.append(attrs_repr(node.attrs, max_rows=max_rows))

    return "\n".join(summary)


def datatree_repr(dt: DataTree) -> str:
    """A printable representation of the structure of this entire tree.

    The first line is ``<xarray.DataTree>`` (with the repr of ``dt.name``
    when it is not ``None``). Each node yielded by ``RenderDataTree`` then
    contributes its ``Group: ...`` line followed by its detail lines.
    """
    renderer = RenderDataTree(dt)

    name_info = "" if dt.name is None else f" {dt.name!r}"
    header = f"<xarray.DataTree{name_info}>"

    lines = [header]
    show_inherited = True
    for pre, fill, node in renderer:
        node_repr = _datatree_node_repr(node, show_inherited=show_inherited)
        show_inherited = False  # only show inherited coords on the root

        raw_repr_lines = node_repr.splitlines()

        node_line = f"{pre}{raw_repr_lines[0]}"
        lines.append(node_line)

        # The prefix is the same for every detail line of a node, so it is
        # computed once. Nodes with children use the vertical connector;
        # nodes without children use blank padding of the same width.
        vertical = renderer.style.vertical
        gutter = vertical if len(node.children) > 0 else " " * len(vertical)
        lines.extend(f"{fill}{gutter}{line}" for line in raw_repr_lines[1:])

    return "\n".join(lines)

Expand Down
Loading
Loading