Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix field plotting #1247

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment_py3_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ dependencies:
- nbval
- scikit-learn
- pykdtree
- cartopy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is cartopy now need? That was not the case before, and I suggest not to keep it

Suggested change
- cartopy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that will make the new test fail that actually tests platecarree plotting

- zarr
1 change: 1 addition & 0 deletions environment_py3_win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ dependencies:
- pytest
- nbval
- pykdtree
- cartopy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above, I suggest not to make cartopy an explicit dependency (as we had not done before)

Suggested change
- cartopy

- zarr
17 changes: 13 additions & 4 deletions parcels/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import timedelta as delta

import numpy as np
import dask.array as da
import copy

from parcels.field import Field
Expand Down Expand Up @@ -101,7 +102,7 @@ def plotparticles(particles, with_particles=True, show_time=None, field=None, do


def plotfield(field, show_time=None, domain=None, depth_level=0, projection='PlateCarree', land=True,
vmin=None, vmax=None, savefile=None, **kwargs):
vmin=None, vmax=None, savefile=None, use3D=False, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this variable still needed now? I am a bit confused what this does

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possibly not, indeed

"""Function to plot a Parcels Field

:param show_time: Time in seconds from start after which to show the Field
Expand All @@ -113,6 +114,7 @@ def plotfield(field, show_time=None, domain=None, depth_level=0, projection='Pla
:param vmax: maximum colour scale (only in single-plot mode)
:param savefile: Name of a file to save the plot to
:param animation: Boolean whether result is a single plot, or an animation
:param use3D: tells if requested projection is 2D or 3D
"""

if type(field) is VectorField:
Expand All @@ -130,7 +132,7 @@ def plotfield(field, show_time=None, domain=None, depth_level=0, projection='Pla
logger.warning('Field.show() does not always correctly determine the domain for curvilinear grids. '
'Use plotting with caution and perhaps use domain argument as in the NEMO 3D tutorial')

plt, fig, ax, cartopy = create_parcelsfig_axis(spherical, land, projection=projection, cartopy_features=kwargs.pop('cartopy_features', []))
plt, fig, ax, cartopy = create_parcelsfig_axis(spherical, land, projection=projection, cartopy_features=kwargs.pop('cartopy_features', []), use3D=use3D)
if plt is None:
return None, None, None, None # creating axes was not possible

Expand Down Expand Up @@ -165,6 +167,8 @@ def plotfield(field, show_time=None, domain=None, depth_level=0, projection='Pla
data[i] = np.squeeze(fld.data)[depth_level, latS:latN, lonW:lonE]
else:
data[i] = np.squeeze(fld.data)[latS:latN, lonW:lonE]
if isinstance(data[i], da.Array):
data[i] = np.array(data[i])

if plottype == 'vector':
if field[0].interp_method == 'cgrid_velocity':
Expand Down Expand Up @@ -264,21 +268,26 @@ def plotfield(field, show_time=None, domain=None, depth_level=0, projection='Pla
return plt, fig, ax, cartopy


def create_parcelsfig_axis(spherical, land=True, projection='PlateCarree', central_longitude=0, cartopy_features=[]):
def create_parcelsfig_axis(spherical, land=True, projection='PlateCarree', central_longitude=0, cartopy_features=[], use3D=False):
try:
import matplotlib.pyplot as plt
except:
logger.info("Visualisation is not possible. Matplotlib not found.")
return None, None, None, None # creating axes was not possible

if spherical and projection:
if use3D:
cartopy = None
fig, ax = plt.subplots(1, 1, subplot_kw={'projection': '3d'})
ax.grid()
elif spherical and projection:
try:
import cartopy
except:
logger.info("Visualisation of field with geographic coordinates is not possible. Cartopy not found.")
return None, None, None, None # creating axes was not possible

projection = cartopy.crs.PlateCarree(central_longitude) if projection == 'PlateCarree' else projection
# projection = '3d' if use3D else projection
CKehl marked this conversation as resolved.
Show resolved Hide resolved
fig, ax = plt.subplots(1, 1, subplot_kw={'projection': projection})
try: # gridlines not supported for all projections
if isinstance(projection, cartopy.crs.PlateCarree) and central_longitude != 0:
Expand Down
13 changes: 9 additions & 4 deletions parcels/scripts/plottrajectoriesfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',

if tracerfile is not None and mode != 'hist2d':
tracerfld = Field.from_netcdf(tracerfile, tracerfield, {'lon': tracerlon, 'lat': tracerlat})
plt, fig, ax, cartopy = plotfield(tracerfld)
plt, fig, ax, cartopy = plotfield(tracerfld, use3D=(mode == '3d'))
if plt is None:
return # creating axes was not possible
titlestr = ' and ' + tracerfield
else:
spherical = False if mode == '3d' or mesh == 'flat' else True
plt, fig, ax, cartopy = create_parcelsfig_axis(spherical=spherical, central_longitude=central_longitude)
plt, fig, ax, cartopy = create_parcelsfig_axis(spherical=spherical, central_longitude=central_longitude, use3D=(mode == '3d'))
if plt is None:
return # creating axes was not possible
titlestr = ''
Expand All @@ -71,7 +71,8 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
if mode == '3d':
from mpl_toolkits.mplot3d import Axes3D # noqa
plt.clf() # clear the figure
ax = fig.gca(projection='3d')
# ax = fig.gca(projection='3d')
ax = fig.gca()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this change. I think this is what breaks the unit-tests. The ax is now not set as 3D anymore, so there is no set_zlevel method. Revert back to original?

Copy link
Contributor Author

@CKehl CKehl Oct 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change was necessary cause that line broke bccb214. In short: in Windows, the workflow obtains a newer version of matplotlib, and in matplotlib 3.6.0, the function figure.gca() has no projection parameters. Unexpected change, but therefore I needed to make this projection option part of the figure creation procedure. I guess help would be welcome from bccb214 onwards.

for p in range(len(lon)):
ax.plot(lon[p, :], lat[p, :], z[p, :], '.-')
ax.set_xlabel('Longitude')
Expand Down Expand Up @@ -116,7 +117,11 @@ def timestr(plottimes, index):
return str(plottimes[index])

if cartopy:
scat = ax.scatter(lon[b], lat[b], s=20, color='k', transform=cartopy.crs.Geodetic())
scat = None
try:
scat = ax.scatter(lon[b], lat[b], s=20, color='k', transform=cartopy.crs.Geodetic())
except (ValueError,):
scat = ax.scatter(lon[b], lat[b], s=20, color='k', transform=cartopy.crs.PlateCarree())
else:
scat = ax.scatter(lon[b], lat[b], s=20, color='k')
ttl = ax.set_title('Particles' + titlestr + ' at time ' + timestr(plottimes, 0))
Expand Down
54 changes: 54 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from parcels import FieldSet, ParticleSet, JITParticle, AdvectionRK4, ErrorCode
from datetime import timedelta as delta
import numpy as np
import dask.array as da # NOQA
from os import path
from matplotlib import pyplot as plt # NOQA
import pytest


def periodicBC(particle, fieldset, time):
if particle.lon > 180:
particle.lon -= 360


def test_field_from_netcdf():
data_path = path.join(path.dirname(__file__), 'test_data/')

filenames = {'U': {'lon': data_path + 'mask_nemo_cross_180lon.nc',
'lat': data_path + 'mask_nemo_cross_180lon.nc',
'data': data_path + 'Uu_eastward_nemo_cross_180lon.nc'},
'V': {'lon': data_path + 'mask_nemo_cross_180lon.nc',
'lat': data_path + 'mask_nemo_cross_180lon.nc',
'data': data_path + 'Vv_eastward_nemo_cross_180lon.nc'}
}
variables = {'U': 'U',
'V': 'V'}
dimensions = {'lon': 'glamf', 'lat': 'gphif'}
return FieldSet.from_netcdf(filenames, variables, dimensions, interp_method='cgrid_velocity', chunksize='auto', allow_time_extrapolation=True)


@pytest.fixture(name="fieldset")
def fieldset_fixture():
return test_field_from_netcdf()


def test_pset_create_field(fieldset, npart=100):
lonp = -180 * np.ones(npart)
latp = [i for i in np.linspace(-70, 88, npart)]
pset = ParticleSet.from_list(fieldset, JITParticle, lon=lonp, lat=latp)
return pset


def DeleteParticle(particle, fieldset, time):
particle.delete()


if __name__ == '__main__':
fset = test_field_from_netcdf()
print(fset)
pset = test_pset_create_field(fset)
kernels = pset.Kernel(AdvectionRK4) + periodicBC
pset.execute(kernels, dt=delta(hours=1), output_file=None,
recovery={ErrorCode.ErrorOutOfBounds: DeleteParticle})
pset.show(field=fset.U)
Comment on lines +1 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this new set of functions test that was not tested before? Why is it needed?

4 changes: 2 additions & 2 deletions tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def create_outputfiles(dir, pset_mode):
npart = 10
delaytime = delta(hours=1)
endtime = delta(hours=24)
x = 3. * (1. / 1.852 / 60)
x = 3. * (1. / 1.852 / 60.)
y = (fieldset.U.lat[0] + x, fieldset.U.lat[-1] - x)
lat = np.linspace(y[0], y[1], npart)

fp = dir.join("DelayParticle.nc")
output_file = pset.ParticleFile(name=fp, outputdt=delaytime)

for t in range(npart):
time = 0 if len(pset) == 0 else pset[0].time
time = 0. if len(pset) == 0 else pset[0].time
pset.add(pset_type[pset_mode]['pset'](pclass=JITParticle, lon=x, lat=lat[t], fieldset=fieldset, time=time))
pset.execute(AdvectionRK4, runtime=delaytime, dt=delta(minutes=5), output_file=output_file)

Expand Down