Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Let PyogrioReader return geodataframe only instead of tuple #33

Merged
merged 1 commit into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions zen3geo/datapipes/pyogrio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
DataPipes for :doc:`pyogrio <pyogrio:index>`.
"""
from typing import Any, Dict, Iterator, Optional, Tuple
from typing import Any, Dict, Iterator, Optional

try:
import pyogrio
Expand All @@ -13,15 +13,15 @@


@functional_datapipe("read_from_pyogrio")
class PyogrioReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
class PyogrioReaderIterDataPipe(IterDataPipe[StreamWrapper]):
"""
Takes vector files (e.g. FlatGeoBuf, GeoPackage, GeoJSON) from local disk
or URLs (as long as they can be read by pyogrio) and yields tuples of
filename and :py:class:`geopandas.GeoDataFrame` objects
or URLs (as long as they can be read by pyogrio) and yields
:py:class:`geopandas.GeoDataFrame` objects
(functional name: ``read_from_pyogrio``).

Based on
https://github.com/pytorch/data/blob/v0.3.0/torchdata/datapipes/iter/load/iopath.py#L37-L83
https://github.com/pytorch/data/blob/v0.4.0/torchdata/datapipes/iter/load/iopath.py#L42-L97

Parameters
----------
Expand All @@ -34,9 +34,8 @@ class PyogrioReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):

Yields
------
stream_obj : Tuple[str, geopandas.GeoDataFrame]
A tuple consisting of the filename that was passed in, and a
:py:class:`geopandas.GeoDataFrame` object containing the vector data.
stream_obj : geopandas.GeoDataFrame
A :py:class:`geopandas.GeoDataFrame` object containing the vector data.

Raises
------
Expand All @@ -55,15 +54,13 @@ class PyogrioReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
>>> from zen3geo.datapipes import PyogrioReader
...
>>> # Read in GeoPackage data using DataPipe
>>> file_url: str = "https://github.com/geopandas/pyogrio/raw/v0.4.0a1/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg"
>>> file_url: str = "https://github.com/geopandas/pyogrio/raw/v0.4.0/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg"
>>> dp = IterableWrapper(iterable=[file_url])
>>> dp_pyogrio = dp.read_from_pyogrio()
...
>>> # Loop or iterate over the DataPipe stream
>>> it = iter(dp_pyogrio)
>>> filename, geodataframe = next(it)
>>> filename
'https://github.com/geopandas/pyogrio/raw/v0.4.0a1/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg'
>>> geodataframe = next(it)
>>> geodataframe
StreamWrapper< col_bool col_int8 ... col_float64 geometry
0 1.0 1.0 ... 1.5 POINT (0.00000 0.00000)
Expand All @@ -87,12 +84,9 @@ def __init__(
self.source_datapipe: IterDataPipe[str] = source_datapipe
self.kwargs = kwargs

def __iter__(self) -> Iterator[Tuple]:
def __iter__(self) -> Iterator[StreamWrapper]:
for filename in self.source_datapipe:
yield (
filename,
StreamWrapper(pyogrio.read_dataframe(filename, **self.kwargs)),
)
yield StreamWrapper(pyogrio.read_dataframe(filename, **self.kwargs))

def __len__(self) -> int:
return len(self.source_datapipe)
5 changes: 2 additions & 3 deletions zen3geo/tests/test_datapipes_pyogrio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_pyogrio_reader():
Ensure that PyogrioReader works to read in a GeoPackage file and outputs a
geopandas.GeoDataFrame object.
"""
file_url: str = "https://github.com/geopandas/pyogrio/raw/v0.4.0a1/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg"
file_url: str = "https://github.com/geopandas/pyogrio/raw/v0.4.0/pyogrio/tests/fixtures/test_gpkg_nulls.gpkg"
dp = IterableWrapper(iterable=[file_url])

# Using class constructors
Expand All @@ -24,9 +24,8 @@ def test_pyogrio_reader():

assert len(dp_pyogrio) == 1
it = iter(dp_pyogrio)
filename, geodataframe = next(it)
geodataframe = next(it)

assert isinstance(filename, str)
assert geodataframe.shape == (4, 12)
assert any(geodataframe.isna())
assert all(geodataframe.geom_type == "Point")