
Allow Variable type as dim argument to concat #8384

Merged
merged 1 commit on Oct 29, 2023

Conversation

maresb
Contributor

@maresb maresb commented Oct 27, 2023

Given the following excerpt from the xr.concat() source code

if isinstance(dim, DataArray):
    dim_var = dim.variable
elif isinstance(dim, Variable):
    dim_var = dim
else:
    dim_var = None

it seems like it's explicitly intended that the dim= argument can be of type xr.Variable. However, it's not indicated as such in the type hints or the documentation. This leads to a type error in pyright when a Variable type is used for dim.

I'm submitting this PR to fix this apparent shortcoming. Or have I overlooked some reason why this should not be the case? Thanks for your consideration!
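For illustration, the dispatch above can be sketched with toy stand-ins (hypothetical minimal classes, not xarray's real ones) to show the behavior that the type hints failed to advertise:

```python
from typing import Optional

class Variable:
    """Toy stand-in for xr.Variable."""
    def __init__(self, name):
        self.name = name

class DataArray:
    """Toy stand-in for xr.DataArray wrapping a Variable."""
    def __init__(self, variable):
        self.variable = variable

def normalize_dim(dim) -> Optional[Variable]:
    # Mirrors the excerpt above: both DataArray and Variable inputs are
    # handled, so dim was evidently meant to accept a Variable all along.
    if isinstance(dim, DataArray):
        return dim.variable
    elif isinstance(dim, Variable):
        return dim
    else:
        return None

v = Variable("y")
assert normalize_dim(v) is v
assert normalize_dim(DataArray(v)) is v
assert normalize_dim("y") is None
```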

@welcome

welcome bot commented Oct 27, 2023

Thank you for opening this pull request! It may take us a few days to respond here, so thank you for being patient.
If you have questions, some answers may be found in our contributing guidelines.

Collaborator

@max-sixty max-sixty left a comment


Thanks @maresb !

In order to test this, would you be up for finding a test which passes a Variable, and adding a -> None to its return type? Then mypy will check the types. (And mypy should raise an error on the old code.)

@maresb
Contributor Author

maresb commented Oct 28, 2023

Thanks @max-sixty for the response! If I understand what you mean, then this test already meets your criteria except for mypy erroring on old code:

[EDIT: Code snippet removed; see the comment below, since I had selected the wrong function.]

So now the mystery is why mypy allows this. By replacing the annotations which I modified in this PR with pd.Index, I infer that Variable is matching the pd.Index type, but I have no idea why. 🤔 To be sure it's not somehow secretly a subclass, I verified that assert isinstance(coord, pd.Index) fails.

In contrast, if I mamba install pyright and run pyright xarray/tests/test_concat.py on old code I get

  /home/mares/repos/xarray/xarray/tests/test_concat.py:902:18 - error: No overloads for "concat" match the provided arguments (reportGeneralTypeIssues)
  /home/mares/repos/xarray/xarray/tests/test_concat.py:902:31 - error: Argument of type "Variable" cannot be assigned to parameter "dim" of type "Hashable | T_DataArray@concat | Index" in function "concat"
    Type "Variable" cannot be assigned to type "Hashable | T_DataArray@concat | Index"
      "Variable" is incompatible with protocol "Hashable"
        "__hash__" is an incompatible type
          Type "None" cannot be assigned to type "() -> int"
      Type "Variable" cannot be assigned to type "DataArray"
        "Variable" is incompatible with "DataArray"
      "Variable" is incompatible with "Index" (reportGeneralTypeIssues)

among other errors, which is more logical to me. Any idea what's going on?
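The Hashable part of that pyright output can be reproduced without xarray: Python sets __hash__ = None on any class that defines __eq__ without __hash__, which is presumably (an assumption on my part about xarray's internals) why Variable fails the protocol:

```python
class Unhashable:
    # Defining __eq__ without __hash__ makes Python set __hash__ to None,
    # so instances are unhashable and the class no longer satisfies the
    # Hashable protocol -- matching pyright's
    # 'Type "None" cannot be assigned to type "() -> int"' complaint.
    def __eq__(self, other):
        return NotImplemented

assert Unhashable.__hash__ is None

raised = False
try:
    hash(Unhashable())
except TypeError:
    raised = True
assert raised
```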

@max-sixty
Collaborator

That is odd!

BTW I think the function that tests this is the one above:

def test_concat_dim_is_variable(self) -> None:
    objs = [Dataset({"x": 0}), Dataset({"x": 1})]
    coord = Variable("y", [3, 4], attrs={"foo": "bar"})
    expected = Dataset({"x": ("y", [0, 1]), "y": coord})
    actual = concat(objs, coord)
    assert_identical(actual, expected)

I had a look through — it's possible that concat is overloaded and the implementation signature is untyped. I'm not sure.

I tried removing the overload and using T_Xarray, but a) it raises an error in dataarray_plot.py that I can't work out, and b) it doesn't highlight the error. Pasting the diff below in case @headtr1ck or others have any ideas.

diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index a136480b..acbd8c80 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Hashable, Iterable
-from typing import TYPE_CHECKING, Any, Union, overload
+from typing import TYPE_CHECKING, Any, Iterator, Sequence, Union, overload
 
 import numpy as np
 import pandas as pd
@@ -16,7 +16,7 @@
     merge_attrs,
     merge_collected,
 )
-from xarray.core.types import T_DataArray, T_Dataset
+from xarray.core.types import T_DataArray, T_Dataset, T_Xarray
 from xarray.core.variable import Variable
 from xarray.core.variable import concat as concat_vars
 
@@ -31,47 +31,17 @@
     T_DataVars = Union[ConcatOptions, Iterable[Hashable]]
 
 
-@overload
 def concat(
-    objs: Iterable[T_Dataset],
+    objs: Iterable[T_Xarray],
     dim: Hashable | T_DataArray | pd.Index,
     data_vars: T_DataVars = "all",
-    coords: ConcatOptions | list[Hashable] = "different",
-    compat: CompatOptions = "equals",
-    positions: Iterable[Iterable[int]] | None = None,
-    fill_value: object = dtypes.NA,
-    join: JoinOptions = "outer",
-    combine_attrs: CombineAttrsOptions = "override",
-) -> T_Dataset:
-    ...
-
-
-@overload
-def concat(
-    objs: Iterable[T_DataArray],
-    dim: Hashable | T_DataArray | pd.Index,
-    data_vars: T_DataVars = "all",
-    coords: ConcatOptions | list[Hashable] = "different",
-    compat: CompatOptions = "equals",
-    positions: Iterable[Iterable[int]] | None = None,
-    fill_value: object = dtypes.NA,
-    join: JoinOptions = "outer",
-    combine_attrs: CombineAttrsOptions = "override",
-) -> T_DataArray:
-    ...
-
-
-def concat(
-    objs,
-    dim,
-    data_vars: T_DataVars = "all",
     coords="different",
     compat: CompatOptions = "equals",
     positions=None,
     fill_value=dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-):
+) -> T_Xarray:
     """Concatenate xarray objects along a new or existing dimension.
 
     Parameters
@@ -449,7 +419,7 @@ def _parse_datasets(
 
 
 def _dataset_concat(
-    datasets: list[T_Dataset],
+    datasets: Iterable[T_Dataset],
     dim: str | T_DataArray | pd.Index,
     data_vars: T_DataVars,
     coords: str | list[str],
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py
index 61f2014f..bdb97cac 100644
--- a/xarray/plot/dataarray_plot.py
+++ b/xarray/plot/dataarray_plot.py
@@ -193,8 +193,9 @@ def _prepare_plot1d_data(
             for v in ["z", "x"]:
                 dim = coords_to_plot.get(v, None)
                 if (dim is not None) and (dim in darray.dims):
-                    darray_nan = np.nan * darray.isel({dim: -1})
-                    darray = concat([darray, darray_nan], dim=dim)
+                    darray_nan: T_DataArray = np.nan * darray.isel({dim: -1})
+                    arrays: list[T_DataArray] = [darray, darray_nan]
+                    darray = concat(arrays, dim=dim)
                     dims_T.append(coords_to_plot[v])
 
         # Lines should never connect to the same coordinate when stacked,

Regardless, let's merge, thanks for the change!

@max-sixty max-sixty merged commit 04eb342 into pydata:main Oct 29, 2023
29 checks passed
@welcome

welcome bot commented Oct 29, 2023

Congratulations on completing your first pull request! Welcome to Xarray! We are proud of you, and hope to see you again!

@maresb maresb deleted the add-Varaible-type-to-concat-dim branch October 29, 2023 09:36
@maresb
Contributor Author

maresb commented Oct 29, 2023

Thanks @max-sixty for the merge!

I had a look at your diff and the corresponding mypy error in xarray.plot.dataarray_plot:_prepare_plot1d_data, and I believe that if we want to use TypeVars wherever possible, then the overload is necessary due to a very subtle typing issue.

First of all, one easy fix for

darray_nan = np.nan * darray.isel({dim: -1})

which gives a DataArray instead of T_DataArray is to replace it with:

end_of_v = darray.isel({dim: -1})
darray_nan = end_of_v.copy(data=np.full_like(darray.data, np.nan))

Then the error is that concat(arrays, dim=dim) returns type DataArray instead of T_DataArray. This is the really subtle issue I mentioned above. Explanation:

In the TypeVar docs, it states:

Note that type variables can be bound, constrained, or neither, but cannot be both bound and constrained.

...

Bound type variables and constrained type variables have different semantics in several important ways. Using a bound type variable means that the TypeVar will be solved using the most specific type possible

...

Using a constrained type variable, however, means that the TypeVar can only ever be solved as being exactly one of the constraints given

Looking at the definition

T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")

we see that T_Xarray is a constrained TypeVar, since it lists positional constraints rather than using the bound= keyword. (While multiple constraints, as above, are supported, multiple bounds are unfortunately not currently supported.)

Thus according to the docs, T_Xarray must resolve to either DataArray or Dataset, and therefore it doesn't resolve to T_DataArray as is required by the return type of _prepare_plot1d_data.
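A minimal sketch of the bound-versus-constrained distinction, using toy stand-in classes rather than xarray's real ones:

```python
from typing import TypeVar, Union

class DataArray: ...
class DataArraySubclass(DataArray): ...
class Dataset: ...

# Constrained: solved as exactly one of the listed types.
T_Constrained = TypeVar("T_Constrained", DataArray, Dataset)
# Bound: solved as the most specific matching type, preserving subclasses.
T_Bound = TypeVar("T_Bound", bound=Union[DataArray, Dataset])

def via_constraint(x: T_Constrained) -> T_Constrained:
    return x

def via_bound(x: T_Bound) -> T_Bound:
    return x

sub = DataArraySubclass()
# Statically, via_constraint(sub) is typed DataArray (the chosen
# constraint), while via_bound(sub) is typed DataArraySubclass.
# At runtime both of course return the same object:
assert via_constraint(sub) is sub
assert via_bound(sub) is sub
```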

One almost-solution is to switch from T_Xarray to

xarray/xarray/core/types.py

Lines 164 to 167 in f63ede9

# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or
# "Dataset". Use it for functions that might return either type, but where the exact
# type cannot be determined statically using the type system.
T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])

But this would be bad since the (first) objs argument to concat is Iterable[T], and we should disallow T=Union[DataArray, Dataset].

In conclusion, the minimally-disruptive solution appears to be to keep the overloads.
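For reference, the shape of the overload pattern being kept can be sketched like this (a toy single-argument signature; the real concat also takes dim, data_vars, and more):

```python
from typing import Iterable, TypeVar, overload

class Dataset: ...
class DataArray: ...

T_Dataset = TypeVar("T_Dataset", bound=Dataset)
T_DataArray = TypeVar("T_DataArray", bound=DataArray)

@overload
def concat(objs: Iterable[T_Dataset]) -> T_Dataset: ...
@overload
def concat(objs: Iterable[T_DataArray]) -> T_DataArray: ...
def concat(objs):
    # Toy implementation: just return the first object; a real concat
    # would combine the inputs along a dimension.  The two overloads
    # give Dataset-in/Dataset-out and DataArray-in/DataArray-out
    # without ever allowing a mixed iterable.
    return next(iter(objs))

ds = Dataset()
assert concat([ds]) is ds
da = DataArray()
assert concat([da]) is da
```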

The more obvious solution would be to not use T_DataArray but simply DataArray since AFAICT the only reason for T_DataArray would be to preserve subclasses, and subclasses in Xarray don't really seem well-supported, but that's a pretty deep design decision. (I am a bit curious about the rationale though.)

@max-sixty
Collaborator

That's a very impressive explanation @maresb — thank you!

The more obvious solution would be to not use T_DataArray but simply DataArray since AFAICT the only reason for T_DataArray would be to preserve subclasses, and subclasses in Xarray don't really seem well-supported, but that's a pretty deep design decision. (I am a bit curious about the rationale though.)

We would like to support them better, so we should try to move towards T_DataArray and work on allowing subclassing in more methods.

Thus according to the docs, T_Xarray must resolve to either DataArray or Dataset, and therefore it doesn't resolve to T_DataArray as is required by the return type of _prepare_plot1d_data.

So it sounds like this problem is intractable — that we can't use T_Xarray together with T_DataArray? And so we can't support:

  • some methods that only return DataArrays and their subclasses
  • some methods that return either DataArray or Dataset, and retain their concrete type through their methods

"retain their concrete type through their methods" is important — IIUC, if we use T_DataArrayOrSet, then the return type of ds.isel(x=2) is a union of both, rather than Dataset, which would be a shame.

@maresb
Contributor Author

maresb commented Oct 29, 2023

So it sounds like this problem is intractable — that we can't use T_Xarray together with T_DataArray?

In my current understanding this is correct. (I'm tenacious but no expert in the internals of Python typing.)

some methods that only return DataArrays and their subclasses

Not with T_Xarray, but we can use T_DataArray in this case.

some methods that return either DataArray or Dataset, and retain their concrete type through their methods

Not with T_Xarray, but I think overload should cover most cases for this.

IIUC, if we use T_DataArrayOrSet, then the return type of ds.isel(x=2) is a union of both, rather than Dataset, which would be a shame

I don't understand exactly what you mean by your ds.isel(x=2) example. I honestly would have expected T_DataArrayOrSet to do the right thing in most cases, but if this is what you mean, then indeed the following example annoyingly fails:

def sel_x(data: T_DataArrayOrSet) -> T_DataArrayOrSet:
    return data.sel(x=2)

It seems that there is some level of type "evaluation" in mypy since sel has return type Self for both Dataset and DataArray cases, resulting in Dataset | DataArray instead of T_DataArrayOrSet. I'm not sure if this is a failure of mypy or an inevitability of type systems. FWIW, pyright doesn't complain here.
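A self-contained sketch of that failure mode, with toy classes whose sel returns a Self-style TypeVar (my stand-ins, not xarray's actual definitions):

```python
from typing import TypeVar, Union

T_Ds = TypeVar("T_Ds", bound="Dataset")
T_Da = TypeVar("T_Da", bound="DataArray")

class Dataset:
    # Self-style annotation: sel returns the same (sub)class it was
    # called on, emulated with a bound TypeVar on self.
    def sel(self: T_Ds, **indexers) -> T_Ds:
        return self

class DataArray:
    def sel(self: T_Da, **indexers) -> T_Da:
        return self

T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union[Dataset, DataArray])

def sel_x(data: T_DataArrayOrSet) -> T_DataArrayOrSet:
    # mypy reportedly infers "Dataset | DataArray" for data.sel(x=2)
    # here and rejects it against the TypeVar; pyright accepts it.
    return data.sel(x=2)

ds = Dataset()
assert sel_x(ds) is ds
```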

Ugliness of overload aside, do you see any problems that can't be solved with overload?

@max-sixty
Collaborator

I don't understand exactly what you mean with your ds.isel(x=2) example. I honestly would have expected T_DataArrayOrSet to do the right thing in most cases, but if this is what you mean, then indeed the following example annoyingly fails

def sel_x(data: T_DataArrayOrSet) -> T_DataArrayOrSet:
    return data.sel(x=2)

The disadvantage of T_DataArrayOrSet is that it's not possible to state "if we get a dataarray, we're returning a dataarray; if we get a dataset, we're returning a dataset", because it's a union rather than a generic type. So our API would always be returning "either a dataarray or a dataset".

That example is a good one — though if we typed .sel as T_DataArrayOrSet, then it would work IIUC. But if we typed .sel as T_DataArrayOrSet, then this wouldn't work:

ds = ds.sel(x=2)

...because the variable would be a T_Dataset, but the expression would be a union of T_Dataset & T_DataArray.

Ugliness of overload aside, do you see any problems that can't be solved with overload?

Yes, I think you're correct! I.e., in place of making things generic, we could just write out all the concrete types, almost like manual C++ templating.

@maresb
Contributor Author

maresb commented Oct 29, 2023

The disadvantage of T_DataArrayOrSet is that it's not possible to state "if we get a dataarray, we're returning a dataarray; if we get a dataset, we're returning a dataset", because it's a union rather than a generic type.

But semantically T_DataArrayOrSet is supposed to do exactly what we want. For example,

def identity(data: T_DataArrayOrSet) -> T_DataArrayOrSet:
    return data

class DatasetSubclass(Dataset):
    pass

ds = DatasetSubclass({"x": 0})
ds2 = identity(ds)
reveal_type(ds2)  # DatasetSubclass

It's just that for whatever reason mypy is choking on the combination of T_DataArrayOrSet and sel. I think the reason we can't type sel with T_DataArrayOrSet is simply that sel has two independent implementations on DataArray and Dataset. Maybe it would work if we defined sel on a common parent class?

As far as I can tell, the only fundamental reason to avoid T_DataArrayOrSet is situations like the objs=Iterable[T] argument of concat where T needs to be "either Dataset or DataArray but not the union".
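A toy sketch of that failure mode (stand-in classes, not xarray's): with a union bound, a heterogeneous list solves T to the union, so a checker accepts a call that should arguably be rejected.

```python
from typing import Iterable, TypeVar, Union

class Dataset: ...
class DataArray: ...

T = TypeVar("T", bound=Union[Dataset, DataArray])

def concat(objs: Iterable[T]) -> T:
    # Toy implementation: just return the first object.
    return next(iter(objs))

# The list's element type is inferred as Dataset | DataArray, which
# satisfies the union bound, so checkers accept this call statically
# even though concatenating mixed objects should be an error:
mixed = [Dataset(), DataArray()]
first = concat(mixed)
assert isinstance(first, Dataset)
```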

@max-sixty
Collaborator

OK, that's v helpful — I was wrong in my assumptions about these two constructions — they're actually different:

T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
T_DataArrayOrSet = Union["Dataset", "DataArray"]

It's only the Union["Dataset", "DataArray"] construction (which we don't use) for which my claim above is correct:

The disadvantage of Union["Dataset", "DataArray"] is that it's not possible to state "if we get a dataarray, we're returning a dataarray; if we get a dataset, we're returning a dataset", because it's a union rather than a generic type. So our API would always be returning "either a dataarray or a dataset".


As far as I can tell, the only fundamental reason to avoid T_DataArrayOrSet is situations like the objs=Iterable[T] argument of concat where T needs to be "either Dataset or DataArray but not the union".

So I think that's correct. And not exactly what the comments above the types in our code state (!)

I think the reason we can't type sel with T_DataArrayOrSet is simply that sel has two independent implementations on DataArray and Dataset.

Yes, I think that's also correct (which makes the somewhat misguided comments to use T_Xarray less bad, since we generally do need to use that type).

Maybe it would work if we defined sel on a common parent class?

I'd be very open to this if it worked out overall.

OTOH I would not favor having DataWithCoords.sel as an if / else based on the type, with no shared code — then we've added an extra layer of indirection without any other benefits.


Thanks a lot for the feedback @maresb ! Lmk if there are some code or doc / comment changes we could make from this...
