diff --git a/parcels/collection/collectionaos.py b/parcels/collection/collectionaos.py
index cb947b075..adddb8fe7 100644
--- a/parcels/collection/collectionaos.py
+++ b/parcels/collection/collectionaos.py
@@ -32,7 +32,9 @@ def _to_write_particles(pd, time):
     """We don't want to write a particle that is not started yet.
     Particle will be written if particle.time is between time-dt/2 and time+dt (/2)
     """
-    return [i for i, p in enumerate(pd) if time - np.abs(p.dt/2) <= p.time < time + np.abs(p.dt) and np.isfinite(p.id)]
+    return [i for i, p in enumerate(pd) if (((time - np.abs(p.dt/2) <= p.time < time + np.abs(p.dt))
+                                             or (np.isnan(p.dt) and np.equal(time, p.time)))
+                                            and np.isfinite(p.id))]


 def _is_particle_started_yet(particle, time):
@@ -928,7 +930,6 @@ def toDictionary(self, pfile, time, deleted_only=False):
                     data_dict[var] = np.array([np.int64(getattr(p, var)) for p in self._data[indices_to_write]])
                 else:
                     data_dict[var] = np.array([getattr(p, var) for p in self._data[indices_to_write]])
-            pfile.maxid_written = np.maximum(pfile.maxid_written, np.max(data_dict['id']))

             pset_errs = [p for p in self._data[indices_to_write] if p.state != OperationCode.Delete and abs(time-p.time) > 1e-3 and np.isfinite(p.time)]
             for p in pset_errs:
diff --git a/parcels/collection/collectionsoa.py b/parcels/collection/collectionsoa.py
index 2997a3f75..3d24c0867 100644
--- a/parcels/collection/collectionsoa.py
+++ b/parcels/collection/collectionsoa.py
@@ -30,8 +30,9 @@ def _to_write_particles(pd, time):
     """We don't want to write a particle that is not started yet.
     Particle will be written if particle.time is between time-dt/2 and time+dt (/2)
     """
-    return (np.less_equal(time - np.abs(pd['dt']/2), pd['time'], where=np.isfinite(pd['time']))
-            & np.greater_equal(time + np.abs(pd['dt'] / 2), pd['time'], where=np.isfinite(pd['time']))
+    return ((np.less_equal(time - np.abs(pd['dt']/2), pd['time'], where=np.isfinite(pd['time']))
+             & np.greater_equal(time + np.abs(pd['dt'] / 2), pd['time'], where=np.isfinite(pd['time']))
+             | ((np.isnan(pd['dt'])) & np.equal(time, pd['time'], where=np.isfinite(pd['time']))))
             & (np.isfinite(pd['id']))
             & (np.isfinite(pd['time'])))

@@ -849,7 +850,6 @@ def toDictionary(self, pfile, time, deleted_only=False):
         if np.any(indices_to_write):
             for var in pfile.var_names:
                 data_dict[var] = self._data[var][indices_to_write]
-            pfile.maxid_written = np.maximum(pfile.maxid_written, np.max(data_dict['id']))

             pset_errs = ((self._data['state'][indices_to_write] != OperationCode.Delete)
                          & np.greater(np.abs(time - self._data['time'][indices_to_write]), 1e-3, where=np.isfinite(self._data['time'][indices_to_write])))
             if np.count_nonzero(pset_errs) > 0:
diff --git a/parcels/particlefile/particlefileaos.py b/parcels/particlefile/particlefileaos.py
index cc67be939..ca6839996 100644
--- a/parcels/particlefile/particlefileaos.py
+++ b/parcels/particlefile/particlefileaos.py
@@ -56,10 +56,10 @@ def get_pset_info_attributes(self):
         For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden.
         """
         attributes = ['name', 'var_names', 'var_names_once', 'time_origin', 'lonlatdepth_dtype',
-                      'file_list', 'file_list_once', 'maxid_written', 'parcels_mesh', 'metadata']
+                      'file_list', 'file_list_once', 'parcels_mesh', 'metadata']
         return attributes

-    def read_from_npy(self, file_list, time_steps, var):
+    def read_from_npy(self, file_list, n_timesteps, var):
         """
         Read NPY-files for one variable using a loop over all files.
@@ -67,12 +67,17 @@ def read_from_npy(self, file_list, time_steps, var):
         For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden.

         :param file_list: List that contains all file names in the output directory
-        :param time_steps: Number of time steps that were written in out directory
+        :param n_timesteps: Dictionary with (for each particle) the number of time steps that were written in the output directory
         :param var: name of the variable to read
         """
-        data = np.nan * np.zeros((self.maxid_written+1, time_steps))
-        time_index = np.zeros(self.maxid_written+1, dtype=np.int64)
-        t_ind_used = np.zeros(time_steps, dtype=np.int64)
+        max_timesteps = max(n_timesteps.values()) if n_timesteps.keys() else 0
+        data = np.nan * np.zeros((len(n_timesteps), max_timesteps))
+        time_index = np.zeros(len(n_timesteps))
+        id_index = {}
+        count = 0
+        for i in sorted(n_timesteps.keys()):
+            id_index[i] = count
+            count += 1

         # loop over all files
         for npyfile in file_list:
@@ -84,15 +89,14 @@ def read_from_npy(self, file_list, time_steps, var):
                     '"parcels_convert_npydir_to_netcdf %s" to convert these to '
                     'a NetCDF file yourself.\nTo avoid this error, make sure you '
                     'close() your ParticleFile at the end of your script.' % self.tempwritedir)
-            id_ind = np.array(data_dict["id"], dtype=np.int64)
-            t_ind = time_index[id_ind] if 'once' not in file_list[0] else 0
-            t_ind_used[t_ind] = 1
-            data[id_ind, t_ind] = data_dict[var]
-            time_index[id_ind] = time_index[id_ind] + 1
+            for ii, i in enumerate(data_dict["id"]):
+                id_ind = id_index[i]
+                t_ind = int(time_index[id_ind]) if 'once' not in file_list[0] else 0
+                data[id_ind, t_ind] = data_dict[var][ii]
+                time_index[id_ind] = time_index[id_ind] + 1

         # remove rows and columns that are completely filled with nan values
-        tmp = data[time_index > 0, :]
-        return tmp[:, t_ind_used == 1]
+        return data[time_index > 0, :]

     def export(self):
         """
@@ -114,34 +118,36 @@ def export(self):
         if len(temp_names) == 0:
             raise RuntimeError("No npy files found in %s" % self.tempwritedir_base)

-        global_maxid_written = -1
-        global_time_written = []
+        n_timesteps = {}
         global_file_list = []
-        global_file_list_once = None
         if len(self.var_names_once) > 0:
             global_file_list_once = []
         for tempwritedir in temp_names:
             if os.path.exists(tempwritedir):
                 pset_info_local = np.load(os.path.join(tempwritedir, 'pset_info.npy'), allow_pickle=True).item()
-                global_maxid_written = np.max([global_maxid_written, pset_info_local['maxid_written']])
                 for npyfile in pset_info_local['file_list']:
                     tmp_dict = np.load(npyfile, allow_pickle=True).item()
-                    global_time_written.append([t for t in tmp_dict['time']])
+                    for i in tmp_dict['id']:
+                        if i in n_timesteps:
+                            n_timesteps[i] += 1
+                        else:
+                            n_timesteps[i] = 1
                 global_file_list += pset_info_local['file_list']
                 if len(self.var_names_once) > 0:
                     global_file_list_once += pset_info_local['file_list_once']
-        self.maxid_written = global_maxid_written
-        self.time_written = np.unique(global_time_written)

         for var in self.var_names:
-            data = self.read_from_npy(global_file_list, len(self.time_written), var)
+            data = self.read_from_npy(global_file_list, n_timesteps, var)
             if var == self.var_names[0]:
                 self.open_netcdf_file(data.shape)
             varout = 'z' if var == 'depth' else var
             getattr(self, varout)[:, :] = data

         if len(self.var_names_once) > 0:
+            n_timesteps_once = {}
+            for i in n_timesteps:
+                n_timesteps_once[i] = 1
             for var in self.var_names_once:
-                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var)
+                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, n_timesteps_once, var)

         self.close_netcdf_file()
diff --git a/parcels/particlefile/particlefilesoa.py b/parcels/particlefile/particlefilesoa.py
index 690f400d6..2621b8d58 100644
--- a/parcels/particlefile/particlefilesoa.py
+++ b/parcels/particlefile/particlefilesoa.py
@@ -56,10 +56,10 @@ def get_pset_info_attributes(self):
         For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden.
         """
         attributes = ['name', 'var_names', 'var_names_once', 'time_origin', 'lonlatdepth_dtype',
-                      'file_list', 'file_list_once', 'maxid_written', 'parcels_mesh', 'metadata']
+                      'file_list', 'file_list_once', 'parcels_mesh', 'metadata']
         return attributes

-    def read_from_npy(self, file_list, time_steps, var):
+    def read_from_npy(self, file_list, n_timesteps, var):
         """
         Read NPY-files for one variable using a loop over all files.

@@ -67,12 +67,17 @@ def read_from_npy(self, file_list, time_steps, var):
         For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden.

         :param file_list: List that contains all file names in the output directory
-        :param time_steps: Number of time steps that were written in out directory
+        :param n_timesteps: Dictionary with (for each particle) the number of time steps that were written in the output directory
         :param var: name of the variable to read
         """
-        data = np.nan * np.zeros((self.maxid_written+1, time_steps))
-        time_index = np.zeros(self.maxid_written+1, dtype=np.int64)
-        t_ind_used = np.zeros(time_steps, dtype=np.int64)
+        max_timesteps = max(n_timesteps.values()) if n_timesteps.keys() else 0
+        data = np.nan * np.zeros((len(n_timesteps), max_timesteps))
+        time_index = np.zeros(len(n_timesteps))
+        id_index = {}
+        count = 0
+        for i in sorted(n_timesteps.keys()):
+            id_index[i] = count
+            count += 1

         # loop over all files
         for npyfile in file_list:
@@ -84,15 +89,14 @@ def read_from_npy(self, file_list, time_steps, var):
                     '"parcels_convert_npydir_to_netcdf %s" to convert these to '
                     'a NetCDF file yourself.\nTo avoid this error, make sure you '
                     'close() your ParticleFile at the end of your script.' % self.tempwritedir)
-            id_ind = np.array(data_dict["id"], dtype=np.int64)
-            t_ind = time_index[id_ind] if 'once' not in file_list[0] else 0
-            t_ind_used[t_ind] = 1
-            data[id_ind, t_ind] = data_dict[var]
-            time_index[id_ind] = time_index[id_ind] + 1
+            for ii, i in enumerate(data_dict["id"]):
+                id_ind = id_index[i]
+                t_ind = int(time_index[id_ind]) if 'once' not in file_list[0] else 0
+                data[id_ind, t_ind] = data_dict[var][ii]
+                time_index[id_ind] = time_index[id_ind] + 1

         # remove rows and columns that are completely filled with nan values
-        tmp = data[time_index > 0, :]
-        return tmp[:, t_ind_used == 1]
+        return data[time_index > 0, :]

     def export(self):
         """
@@ -115,33 +119,36 @@ def export(self):
         if len(temp_names) == 0:
             raise RuntimeError("No npy files found in %s" % self.tempwritedir_base)

-        global_maxid_written = -1
-        global_time_written = []
+        n_timesteps = {}
         global_file_list = []
         if len(self.var_names_once) > 0:
             global_file_list_once = []
         for tempwritedir in temp_names:
             if os.path.exists(tempwritedir):
                 pset_info_local = np.load(os.path.join(tempwritedir, 'pset_info.npy'), allow_pickle=True).item()
-                global_maxid_written = np.max([global_maxid_written, pset_info_local['maxid_written']])
                 for npyfile in pset_info_local['file_list']:
                     tmp_dict = np.load(npyfile, allow_pickle=True).item()
-                    global_time_written.append([t for t in tmp_dict['time']])
+                    for i in tmp_dict['id']:
+                        if i in n_timesteps:
+                            n_timesteps[i] += 1
+                        else:
+                            n_timesteps[i] = 1
                 global_file_list += pset_info_local['file_list']
                 if len(self.var_names_once) > 0:
                     global_file_list_once += pset_info_local['file_list_once']
-        self.maxid_written = global_maxid_written
-        self.time_written = np.unique(global_time_written)

         for var in self.var_names:
-            data = self.read_from_npy(global_file_list, len(self.time_written), var)
+            data = self.read_from_npy(global_file_list, n_timesteps, var)
             if var == self.var_names[0]:
                 self.open_netcdf_file(data.shape)
             varout = 'z' if var == 'depth' else var
             getattr(self, varout)[:, :] = data

         if len(self.var_names_once) > 0:
+            n_timesteps_once = {}
+            for i in n_timesteps:
+                n_timesteps_once[i] = 1
             for var in self.var_names_once:
-                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var)
+                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, n_timesteps_once, var)

         self.close_netcdf_file()
diff --git a/tests/test_particle_file.py b/tests/test_particle_file.py
index 17d9e13dc..211e7cf78 100644
--- a/tests/test_particle_file.py
+++ b/tests/test_particle_file.py
@@ -62,12 +62,16 @@ def test_pfile_array_remove_particles(fieldset, pset_mode, mode, tmpdir, npart=1
     filepath = tmpdir.join("pfile_array_remove_particles.nc")
     pset = pset_type[pset_mode]['pset'](fieldset, pclass=ptype[mode],
                                         lon=np.linspace(0, 1, npart),
-                                        lat=0.5*np.ones(npart))
+                                        lat=0.5*np.ones(npart), time=0)
     pfile = pset.ParticleFile(filepath)
     pfile.write(pset, 0)
     pset.remove_indices(3)
+    for p in pset:
+        p.time = 1
     pfile.write(pset, 1)

     ncfile = close_and_compare_netcdffiles(filepath, pfile)
+    timearr = ncfile.variables['time'][:]
+    assert type(timearr[3, 1]) is not type(timearr[3, 0])  # noqa
     ncfile.close()

@@ -104,7 +108,7 @@ def test_pfile_array_remove_all_particles(fieldset, pset_mode, mode, tmpdir, npa
     filepath = tmpdir.join("pfile_array_remove_particles.nc")
     pset = pset_type[pset_mode]['pset'](fieldset, pclass=ptype[mode],
                                         lon=np.linspace(0, 1, npart),
-                                        lat=0.5*np.ones(npart))
+                                        lat=0.5*np.ones(npart), time=0)
     pfile = pset.ParticleFile(filepath)
     pfile.write(pset, 0)
     for _ in range(npart):
@@ -112,6 +116,7 @@ def test_pfile_array_remove_all_particles(fieldset, pset_mode, mode, tmpdir, npa
     pfile.write(pset, 1)
     pfile.write(pset, 2)

     ncfile = close_and_compare_netcdffiles(filepath, pfile)
+    assert ncfile.variables['time'][:].shape == (npart, 1)
     ncfile.close()
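
For reference, the new SoA write-selection logic as a self-contained sketch (not part of the diff; the function name to_write_mask and the sample arrays are hypothetical). It illustrates the case this change adds: a particle whose dt is NaN (never set) is now written whenever its time equals the requested output time exactly, instead of being dropped:

    import numpy as np

    def to_write_mask(pd, time):
        # A particle is written when its time lies within half a time step of
        # the output time, OR, when dt is NaN, when its time matches the
        # output time exactly (the new branch in _to_write_particles).
        finite = np.isfinite(pd['time'])
        in_window = (np.less_equal(time - np.abs(pd['dt'] / 2), pd['time'], where=finite)
                     & np.greater_equal(time + np.abs(pd['dt'] / 2), pd['time'], where=finite))
        nan_dt = np.isnan(pd['dt']) & np.equal(time, pd['time'], where=finite)
        return (in_window | nan_dt) & np.isfinite(pd['id']) & finite

    pd = {'id': np.array([0., 1.]),
          'time': np.array([0., 0.]),
          'dt': np.array([60., np.nan])}
    print(to_write_mask(pd, 0.))  # [ True  True]; without the nan_dt branch: [ True False]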
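Likewise, a minimal sketch of the bookkeeping that replaces maxid_written: export() now counts written time steps per particle id, and read_from_npy() maps the (possibly non-contiguous) ids onto consecutive output rows. The dumps list and variable names below are hypothetical stand-ins for the temporary .npy dictionaries the real code loads from disk:

    import numpy as np

    # Two hypothetical .npy dumps: the ids (plus one variable) written at two
    # successive output times; particle 1 is deleted after the first write.
    dumps = [{'id': [0, 1, 2], 'lon': [0.0, 0.1, 0.2]},
             {'id': [0, 2], 'lon': [0.5, 0.7]}]

    # export(): count, per particle id, how many time steps were written.
    n_timesteps = {}
    for dump in dumps:
        for i in dump['id']:
            n_timesteps[i] = n_timesteps.get(i, 0) + 1

    # read_from_npy(): map sorted ids onto consecutive row indices, so the
    # output array needs len(n_timesteps) rows rather than maxid_written + 1.
    id_index = {i: row for row, i in enumerate(sorted(n_timesteps))}
    data = np.full((len(n_timesteps), max(n_timesteps.values())), np.nan)
    time_index = np.zeros(len(n_timesteps), dtype=int)
    for dump in dumps:
        for ii, i in enumerate(dump['id']):
            row = id_index[i]
            data[row, time_index[row]] = dump['lon'][ii]
            time_index[row] += 1

    print(data)  # [[0.  0.5]
                 #  [0.1 nan]
                 #  [0.2 0.7]]

Because rows are indexed by rank among the written ids rather than by the id value itself, deleted or never-started particles no longer inflate the output array, and ids need not be contiguous.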