diff --git a/pygmt/src/project.py b/pygmt/src/project.py index 99738bfd9c8..833c58ce299 100644 --- a/pygmt/src/project.py +++ b/pygmt/src/project.py @@ -2,15 +2,18 @@ project - Project data onto lines or great circles, or generate tracks. """ +from typing import Literal + +import numpy as np import pandas as pd from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, kwargs_to_strings, use_alias, + validate_output_table_type, ) @@ -32,7 +35,15 @@ f="coltypes", ) @kwargs_to_strings(E="sequence", L="sequence", T="sequence", W="sequence", C="sequence") -def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs): +def project( + data=None, + x=None, + y=None, + z=None, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +) -> pd.DataFrame | np.ndarray | None: r""" Project data onto lines or great circles, or generate tracks. @@ -105,6 +116,8 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs): Pass in (x, y, z) or (longitude, latitude, elevation) values by providing a file name to an ASCII data table, a 2-D {table-classes}. + {output_type} + {outfile} center : str or list *cx*/*cy*. @@ -196,22 +209,18 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs): *direction* is counter-clockwise from the horizontal instead of an *azimuth*. - outfile : str - The file name for the output ASCII file. - {coltypes} Returns ------- - track: pandas.DataFrame or None - Return type depends on whether the ``outfile`` parameter is set: + ret + Return type depends on ``outfile`` and ``output_type``: - - :class:`pandas.DataFrame` table with (x, y, ..., newcolname) if - ``outfile`` is not set - - None if ``outfile`` is set (output will be stored in file set - by ``outfile``) + - ``None`` if ``outfile`` is set (output will be stored in file set by + ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) """ - if kwargs.get("C") is None: raise GMTInvalidInput("The `center` parameter must be specified.") if kwargs.get("G") is None and data is None: @@ -223,29 +232,31 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs): "The `convention` parameter is not allowed with `generate`." ) - with GMTTempFile(suffix=".csv") as tmpfile: - if outfile is None: # Output to tmpfile if outfile is not set - outfile = tmpfile.name - with Session() as lib: - if kwargs.get("G") is None: - with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=False - ) as vintbl: - # Run project on the temporary (csv) data table - arg_str = build_arg_string(kwargs, infile=vintbl, outfile=outfile) - else: - arg_str = build_arg_string(kwargs, outfile=outfile) - lib.call_module(module="project", args=arg_str) - - # if user did not set outfile, return pd.DataFrame - if outfile == tmpfile.name: - if kwargs.get("G") is not None: - column_names = list("rsp") - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) - else: - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - # return None if outfile set, output in outfile - elif outfile != tmpfile.name: - result = None - - return result + output_type = validate_output_table_type(output_type, outfile=outfile) + + column_names = None + if output_type == "pandas" and kwargs.get("G") is not None: + column_names = list("rsp") + + with Session() as lib: + with ( + lib.virtualfile_in( + check_kind="vector", + data=data, + x=x, + y=y, + z=z, + required_z=False, + required_data=False, + ) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="project", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset( + output_type=output_type, + vfname=vouttbl, + column_names=column_names, + )