Skip to content

Commit

Permalink
Merge pull request #1680 from OceanParcels/v/small-changes
Browse files Browse the repository at this point in the history
Enable pyupgrade, add Grid repr, and other changes
  • Loading branch information
VeckoTheGecko committed Sep 12, 2024
2 parents d4d35c1 + 5c4e289 commit b007dcf
Show file tree
Hide file tree
Showing 28 changed files with 141 additions and 80 deletions.
1 change: 0 additions & 1 deletion .binder/environment.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
name: parcels_binder
channels:
- conda-forge
- defaults
dependencies:
- parcels
3 changes: 3 additions & 0 deletions docs/documentation/additional_examples.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Python Example Scripts
======================

example_brownian.py
-------------------

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_delaystart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The simplest way to delaye the start of a particle is to use the `time` argument for each particle\n"
"The simplest way to delay the start of a particle is to use the `time` argument for each particle\n"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_timestamps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"outputs": [],
"source": [
"timestamps = np.expand_dims(\n",
" np.array([np.datetime64(\"2001-%.2d-15\" % m) for m in range(1, 13)]), axis=1\n",
" np.array([np.datetime64(f\"2001-{m:02d}-15\") for m in range(1, 13)]), axis=1\n",
")"
]
},
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_unitconverters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The units for Brownian diffusion are in $m^2/s$. If (and only if!) the diffusion fields are called `kh_zonal` and `kh_meridional`, Parcels will automatically assign the correct Unitconverter objects to these fields.\n"
"The units for Brownian diffusion are in $m^2/s$. If (and only if!) the diffusion fields are called \"Kh_zonal\" and \"Kh_meridional\", Parcels will automatically assign the correct Unitconverter objects to these fields.\n"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import ast
import datetime
import os
from typing import Callable, Literal
from collections.abc import Callable
from typing import Literal


class ParcelsAST(ast.AST):
Expand Down
2 changes: 1 addition & 1 deletion parcels/compilation/codecompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, cppargs=None, ldargs=None, incdirs=None, libdirs=None, libs=N
self._ldargs += lflags
self._ldargs += ldargs
if len(Lflags) > 0:
self._ldargs += ["-Wl, -rpath=%s" % (":".join(libdirs))]
self._ldargs += [f"-Wl, -rpath={':'.join(libdirs)}"]
self._ldargs += arch_flag
self._incdirs = incdirs
self._libdirs = libdirs
Expand Down
32 changes: 16 additions & 16 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __getattr__(self, attr):
elif isinstance(getattr(self.obj, attr), VectorField):
return VectorFieldNode(getattr(self.obj, attr), ccode=f"{self.ccode}->{attr}")
else:
return ConstNode(getattr(self.obj, attr), ccode="%s" % (attr))
return ConstNode(getattr(self.obj, attr), ccode=f"{attr}")


class FieldNode(IntrinsicNode):
Expand Down Expand Up @@ -489,13 +489,13 @@ def visit_FunctionDef(self, node):
c.Value("double", "time"),
]
for field in self.field_args.values():
args += [c.Pointer(c.Value("CField", "%s" % field.ccode_name))]
args += [c.Pointer(c.Value("CField", f"{field.ccode_name}"))]
for field in self.vector_field_args.values():
for fcomponent in ["U", "V", "W"]:
try:
f = getattr(field, fcomponent)
if f.ccode_name not in self.field_args:
args += [c.Pointer(c.Value("CField", "%s" % f.ccode_name))]
args += [c.Pointer(c.Value("CField", f"{f.ccode_name}"))]
self.field_args[f.ccode_name] = f
except:
pass # field.W does not always exist
Expand Down Expand Up @@ -528,9 +528,9 @@ def visit_Call(self, node):
if isinstance(node.func, PrintNode):
# Write our own Print parser because Python3-AST does not seem to have one
if isinstance(node.args[0], ast.Str):
node.ccode = str(c.Statement('printf("%s\\n")' % (node.args[0].s)))
node.ccode = str(c.Statement(f'printf("{node.args[0].s}\\n")'))
elif isinstance(node.args[0], ast.Name):
node.ccode = str(c.Statement('printf("%%f\\n", %s)' % (node.args[0].id)))
node.ccode = str(c.Statement(f'printf("%f\\n", {node.args[0].id})'))
elif isinstance(node.args[0], ast.BinOp):
if hasattr(node.args[0].right, "ccode"):
args = node.args[0].right.ccode
Expand All @@ -545,12 +545,12 @@ def visit_Call(self, node):
args.append(a.id)
else:
args = []
s = 'printf("%s\\n"' % node.args[0].left.s
s = f'printf("{node.args[0].left.s}\\n"'
if isinstance(args, str):
s = s + (", %s)" % args)
s = s + f", {args})"
else:
for arg in args:
s = s + (", %s" % arg)
s = s + (f", {arg}")
s = s + ")"
node.ccode = str(c.Statement(s))
else:
Expand All @@ -568,7 +568,7 @@ def visit_Call(self, node):
elif isinstance(a, ParticleNode):
continue
elif pointer_args:
a.ccode = "&%s" % a.ccode
a.ccode = f"&{a.ccode}"
ccode_args = ", ".join([a.ccode for a in node.args[pointer_args:]])
try:
if isinstance(node.func, str):
Expand Down Expand Up @@ -742,7 +742,7 @@ def visit_BoolOp(self, node):
self.visit(node.op)
for v in node.values:
self.visit(v)
op_str = " %s " % node.op.ccode
op_str = f" {node.op.ccode} "
node.ccode = op_str.join([v.ccode for v in node.values])

def visit_Eq(self, node):
Expand Down Expand Up @@ -813,7 +813,7 @@ def visit_ConstNode(self, node):

def visit_Return(self, node):
self.visit(node.value)
node.ccode = c.Statement("return %s" % node.value.ccode)
node.ccode = c.Statement(f"return {node.value.ccode}")

def visit_FieldEvalNode(self, node):
self.visit(node.field)
Expand Down Expand Up @@ -909,16 +909,16 @@ def visit_Print(self, node):
for n in node.values:
self.visit(n)
if hasattr(node.values[0], "s"):
node.ccode = c.Statement('printf("%s\\n")' % (n.ccode))
node.ccode = c.Statement(f'printf("{n.ccode}\\n")')
return
if hasattr(node.values[0], "s_print"):
args = node.values[0].right.ccode
s = 'printf("%s\\n"' % node.values[0].left.ccode
s = f'printf("{node.values[0].left.ccode}\\n"'
if isinstance(args, str):
s = s + (", %s)" % args)
s = s + f", {args})"
else:
for arg in args:
s = s + (", %s" % arg)
s = s + (f", {arg}")
s = s + ")"
node.ccode = c.Statement(s)
return
Expand Down Expand Up @@ -973,7 +973,7 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
c.Value("double", "dt"),
]
for field, _ in field_args.items():
args += [c.Pointer(c.Value("CField", "%s" % field))]
args += [c.Pointer(c.Value("CField", f"{field}"))]
for const, _ in const_args.items():
args += [c.Value("double", const)] # are we SURE those const's are double's ?
fargs_str = ", ".join(["particles->time_nextloop[pnum]"] + list(field_args.keys()) + list(const_args.keys()))
Expand Down
22 changes: 9 additions & 13 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import datetime
import math
import warnings
from collections.abc import Iterable
from ctypes import POINTER, Structure, c_float, c_int, pointer
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Type
from typing import TYPE_CHECKING

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -222,7 +223,7 @@ def __init__(
stacklevel=2,
)

self.fieldset: "FieldSet" | None = None
self.fieldset: FieldSet | None = None
if allow_time_extrapolation is None:
self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False
else:
Expand Down Expand Up @@ -299,7 +300,7 @@ def __init__(
# since some datasets do not provide the deeper level of data (which is ignored by the interpolation).
self.data_full_zdim = kwargs.pop("data_full_zdim", None)
self.data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays
self.c_data_chunks: list["PointerType" | None] = [] # C-pointers to the data_chunks array
self.c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array
self.nchunks: tuple[int, ...] = ()
self.chunk_set: bool = False
self.filebuffers = [None] * 2
Expand Down Expand Up @@ -565,13 +566,10 @@ def from_netcdf(
"time dimension in indices is not necessary anymore. It is then ignored.", FieldSetWarning, stacklevel=2
)

if "full_load" in kwargs: # for backward compatibility with Parcels < v2.0.0
deferred_load = not kwargs["full_load"]

if grid.time.size <= 2 or deferred_load is False:
if grid.time.size <= 2:
deferred_load = False

_field_fb_class: Type[DeferredDaskFileBuffer | DaskFileBuffer | DeferredNetcdfFileBuffer | NetcdfFileBuffer]
_field_fb_class: type[DeferredDaskFileBuffer | DaskFileBuffer | DeferredNetcdfFileBuffer | NetcdfFileBuffer]
if chunksize not in [False, None]:
if deferred_load:
_field_fb_class = DeferredDaskFileBuffer
Expand Down Expand Up @@ -828,11 +826,9 @@ def calc_cell_edge_sizes(self):
self.cell_edge_sizes = self.grid.cell_edge_sizes
else:
raise ValueError(
(
f"Field.cell_edge_sizes() not implemented for {self.grid.gtype} grids. "
"You can provide Field.grid.cell_edge_sizes yourself by in, e.g., "
"NEMO using the e1u fields etc from the mesh_mask.nc file."
)
f"Field.cell_edge_sizes() not implemented for {self.grid.gtype} grids. "
"You can provide Field.grid.cell_edge_sizes yourself by in, e.g., "
"NEMO using the e1u fields etc from the mesh_mask.nc file."
)

def cell_areas(self):
Expand Down
8 changes: 4 additions & 4 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ def from_parcels(
extra_fields.update({"U": uvar, "V": vvar})
for vars in extra_fields:
dimensions[vars] = deepcopy(default_dims)
dimensions[vars]["depth"] = "depth%s" % vars.lower()
dimensions[vars]["depth"] = f"depth{vars.lower()}"
filenames = {v: str(f"{basename}{v}.nc") for v in extra_fields.keys()}
return cls.from_netcdf(
filenames,
Expand Down Expand Up @@ -1317,7 +1317,7 @@ def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
"""
# check if filename exists
if not os.path.exists(filename):
raise IOError(f"FieldSet module file {filename} does not exist")
raise OSError(f"FieldSet module file {filename} does not exist")

# Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly)
spec = importlib.util.spec_from_file_location(modulename, filename)
Expand All @@ -1326,10 +1326,10 @@ def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
spec.loader.exec_module(fieldset_module)

if not hasattr(fieldset_module, modulename):
raise IOError(f"{filename} does not contain a {modulename} function")
raise OSError(f"{filename} does not contain a {modulename} function")
fieldset = getattr(fieldset_module, modulename)(**kwargs)
if not isinstance(fieldset, FieldSet):
raise IOError(f"Module {filename}.{modulename} does not return a FieldSet object")
raise OSError(f"Module {filename}.{modulename} does not return a FieldSet object")
return fieldset

def get_fields(self):
Expand Down
10 changes: 9 additions & 1 deletion parcels/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def __init__(
self._add_last_periodic_data_timestep = False
self.depth_field = None

def __repr__(self):
with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}):
return (
f"{type(self).__name__}("
f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, "
f"time_origin={self.time_origin!r}, mesh={self.mesh!r})"
)

@staticmethod
def create_grid(
lon: npt.ArrayLike,
Expand Down Expand Up @@ -352,7 +360,7 @@ def __init__(self, lon, lat, time, time_origin, mesh: Mesh):
stacklevel=2,
)

def add_periodic_halo(self, zonal, meridional, halosize=5):
def add_periodic_halo(self, zonal: bool, meridional: bool, halosize: int = 5):
"""Add a 'halo' to the Grid, through extending the Grid (and lon/lat)
similarly to the halo created for the Fields
Expand Down
6 changes: 1 addition & 5 deletions parcels/interaction/interactionkernel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import sys
import warnings
from collections import defaultdict

Expand Down Expand Up @@ -109,10 +108,7 @@ def check_kernel_signature_on_version(self):
numkernelargs = []
if self._pyfunc is not None and isinstance(self._pyfunc, list):
for func in self._pyfunc:
if sys.version_info[0] < 3:
numkernelargs.append(len(inspect.getargspec(func).args))
else:
numkernelargs.append(len(inspect.getfullargspec(func).args))
numkernelargs.append(len(inspect.getfullargspec(func).args))
return numkernelargs

def remove_lib(self):
Expand Down
26 changes: 12 additions & 14 deletions parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _cache_key(self):
field_keys = "-".join(
[f"{name}:{field.units.__class__.__name__}" for name, field in self.field_args.items()]
)
key = self.name + self.ptype._cache_key + field_keys + ("TIME:%f" % ostime())
key = self.name + self.ptype._cache_key + field_keys + (f"TIME:{ostime():f}")
return hashlib.md5(key.encode("utf-8")).hexdigest()

def remove_deleted(self, pset):
Expand Down Expand Up @@ -239,9 +239,10 @@ def __init__(

numkernelargs = self.check_kernel_signature_on_version()

assert (
numkernelargs == 3
), "Since Parcels v2.0, kernels do only take 3 arguments: particle, fieldset, time !! AND !! Argument order in field interpolation is time, depth, lat, lon."
if numkernelargs != 3:
raise ValueError(
"Since Parcels v2.0, kernels do only take 3 arguments: particle, fieldset, time !! AND !! Argument order in field interpolation is time, depth, lat, lon."
)

self.name = f"{ptype.name}{self.funcname}"

Expand Down Expand Up @@ -310,7 +311,7 @@ def _cache_key(self):
field_keys = "-".join(
[f"{name}:{field.units.__class__.__name__}" for name, field in self.field_args.items()]
)
key = self.name + self.ptype._cache_key + field_keys + ("TIME:%f" % ostime())
key = self.name + self.ptype._cache_key + field_keys + (f"TIME:{ostime():f}")
return hashlib.md5(key.encode("utf-8")).hexdigest()

def add_scipy_positionupdate_kernels(self):
Expand All @@ -330,7 +331,7 @@ def Updatecoords(particle, fieldset, time):
particle.depth_nextloop = particle.depth + particle_ddepth # noqa
particle.time_nextloop = particle.time + particle.dt

self._pyfunc = self.__radd__(Setcoords).__add__(Updatecoords)._pyfunc
self._pyfunc = (Setcoords + self + Updatecoords)._pyfunc

def check_fieldsets_in_kernels(self, pyfunc):
"""
Expand Down Expand Up @@ -396,13 +397,10 @@ def check_fieldsets_in_kernels(self, pyfunc):
self.fieldset.add_constant("RK45_max_dt", 60 * 60 * 24)

def check_kernel_signature_on_version(self):
numkernelargs = 0
if self._pyfunc is not None:
if sys.version_info[0] < 3:
numkernelargs = len(inspect.getargspec(self._pyfunc).args)
else:
numkernelargs = len(inspect.getfullargspec(self._pyfunc).args)
return numkernelargs
"""Returns number of arguments in a Python function."""
if self._pyfunc is None:
return 0
return len(inspect.getfullargspec(self._pyfunc).args)

def remove_lib(self):
if self._lib is not None:
Expand Down Expand Up @@ -449,7 +447,7 @@ def get_kernel_compile_files(self):
self._cache_key
) # only required here because loading is done by Kernel class instead of Compiler class
dyn_dir = get_cache_dir()
basename = "%s_0" % cache_name
basename = f"{cache_name}_0"
lib_path = "lib" + basename
src_file_or_files = None
if type(basename) in (list, dict, tuple, ndarray):
Expand Down
Loading

0 comments on commit b007dcf

Please sign in to comment.