diff --git a/parcels/particlefile/particlefilesoa.py b/parcels/particlefile/particlefilesoa.py
index 690f400d6..a2299fed8 100644
--- a/parcels/particlefile/particlefilesoa.py
+++ b/parcels/particlefile/particlefilesoa.py
@@ -1,5 +1,6 @@
 """Module controlling the writing of ParticleSets to NetCDF file"""
 import os
+import psutil
 from glob import glob
 
 import numpy as np
@@ -59,7 +60,7 @@ def get_pset_info_attributes(self):
                       'file_list', 'file_list_once', 'maxid_written', 'parcels_mesh', 'metadata']
         return attributes
 
-    def read_from_npy(self, file_list, time_steps, var):
+    def read_from_npy(self, file_list, time_steps, var, chunk_ids):
         """
         Read NPY-files for one variable using a loop over all files.
 
@@ -70,8 +71,11 @@ def read_from_npy(self, file_list, time_steps, var):
         :param time_steps: Number of time steps that were written in out directory
         :param var: name of the variable to read
         """
-        data = np.nan * np.zeros((self.maxid_written+1, time_steps))
-        time_index = np.zeros(self.maxid_written+1, dtype=np.int64)
+
+        n_ids = len(chunk_ids)
+
+        data = np.nan * np.zeros((n_ids, time_steps))
+        time_index = np.zeros(n_ids, dtype=np.int64)
         t_ind_used = np.zeros(time_steps, dtype=np.int64)
 
         # loop over all files
@@ -84,16 +88,20 @@ def read_from_npy(self, file_list, time_steps, var):
                                    '"parcels_convert_npydir_to_netcdf %s" to convert these to '
                                    'a NetCDF file yourself.\nTo avoid this error, make sure you '
                                    'close() your ParticleFile at the end of your script.' % self.tempwritedir)
-            id_ind = np.array(data_dict["id"], dtype=np.int64)
-            t_ind = time_index[id_ind] if 'once' not in file_list[0] else 0
+
+            id_avail = np.array(data_dict["id"], dtype=np.int64)
+            id_mask_full = np.in1d(id_avail, chunk_ids)  # which ids in data are present in this chunk
+            id_mask_chunk = np.in1d(chunk_ids, id_avail)  # which ids in this chunk are present in data
+            t_ind = time_index[id_mask_chunk] if 'once' not in file_list[0] else 0
             t_ind_used[t_ind] = 1
-            data[id_ind, t_ind] = data_dict[var]
-            time_index[id_ind] = time_index[id_ind] + 1
+            data[id_mask_chunk, t_ind] = data_dict[var][id_mask_full]
+            time_index[id_mask_chunk] = time_index[id_mask_chunk] + 1
 
         # remove rows and columns that are completely filled with nan values
         tmp = data[time_index > 0, :]
         return tmp[:, t_ind_used == 1]
 
+
     def export(self):
         """
         Exports outputs in temporary NPY-files to NetCDF file
@@ -117,6 +125,7 @@ def export(self):
 
         global_maxid_written = -1
         global_time_written = []
+        global_id = []
         global_file_list = []
         if len(self.var_names_once) > 0:
             global_file_list_once = []
@@ -127,21 +136,51 @@ def export(self):
                 for npyfile in pset_info_local['file_list']:
                     tmp_dict = np.load(npyfile, allow_pickle=True).item()
                     global_time_written.append([t for t in tmp_dict['time']])
+                    global_id.append([i for i in tmp_dict['id']])
                 global_file_list += pset_info_local['file_list']
                 if len(self.var_names_once) > 0:
                     global_file_list_once += pset_info_local['file_list_once']
         self.maxid_written = global_maxid_written
+
+        # These steps seem to be quite expensive...
         self.time_written = np.unique(global_time_written)
+        self.id_present = np.unique([pid for frame in global_id for pid in frame])
 
         for var in self.var_names:
-            data = self.read_from_npy(global_file_list, len(self.time_written), var)
-            if var == self.var_names[0]:
-                self.open_netcdf_file(data.shape)
-            varout = 'z' if var == 'depth' else var
-            getattr(self, varout)[:, :] = data
+            # Find available memory to check if output file is too large
+            avail_mem = psutil.virtual_memory()[1]
+            req_mem = len(self.id_present)*len(self.time_written)*8*1.2
+            # avail_mem = req_mem/2  # ! HACK FOR TESTING !
+
+            if req_mem > avail_mem:
+                # Read id_per_chunk ids at a time to keep memory use down
+                total_chunks = int(np.ceil(req_mem/avail_mem))
+                id_per_chunk = int(np.ceil(len(self.id_present)/total_chunks))
+            else:
+                total_chunks = 1
+                id_per_chunk = len(self.id_present)
+
+            for chunk in range(total_chunks):
+                # Minimum and maximum particle indices for this chunk
+                idx_range = [0, 0]
+                idx_range[0] = int(chunk*id_per_chunk)
+                idx_range[1] = int(np.min(((chunk+1)*id_per_chunk,
+                                           len(self.id_present))))
+
+                # Read chunk-sized data from NPY-files
+                data = self.read_from_npy(global_file_list, len(self.time_written), var, self.id_present[idx_range[0]:idx_range[1]])
+                if (var == self.var_names[0]) & (chunk == 0):
+                    # !! unacceptable assumption !!
+                    # Assumes that the number of time-steps in the first chunk
+                    # == number of time-steps across all chunks.
+                    self.open_netcdf_file((len(self.id_present), data.shape[1]))
+
+                varout = 'z' if var == 'depth' else var
+                # Write to correct location in netcdf file
+                getattr(self, varout)[idx_range[0]:idx_range[1], :] = data
 
         if len(self.var_names_once) > 0:
             for var in self.var_names_once:
-                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var)
+                getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var, self.id_present)
 
         self.close_netcdf_file()
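
Below is a rough, standalone sketch of the chunking arithmetic and the np.in1d masking used in the patch above. The particle/time counts and the available-memory figure are made-up stand-ins (in the patch, avail_mem comes from psutil.virtual_memory()); it only illustrates how the chunk size is derived and how the ids stored in one NPY file are mapped onto the rows of a chunk.

import numpy as np

# Hypothetical sizes standing in for len(self.id_present) and len(self.time_written)
n_particles, n_times = 1000, 250
id_present = np.arange(n_particles, dtype=np.int64)

# Memory needed for one float64 variable, with the same 20% margin as the patch
req_mem = n_particles * n_times * 8 * 1.2
avail_mem = 500000  # pretend value; the patch reads this from psutil

if req_mem > avail_mem:
    total_chunks = int(np.ceil(req_mem / avail_mem))
    id_per_chunk = int(np.ceil(n_particles / total_chunks))
else:
    total_chunks, id_per_chunk = 1, n_particles

# Map the ids found in one NPY file onto the rows of the first chunk,
# mirroring the np.in1d masks in read_from_npy
chunk_ids = id_present[:id_per_chunk]
id_avail = np.array([3, 7, 999999], dtype=np.int64)  # ids stored in a file
id_mask_full = np.in1d(id_avail, chunk_ids)   # file entries belonging to this chunk
id_mask_chunk = np.in1d(chunk_ids, id_avail)  # chunk rows that receive data

print(total_chunks, id_per_chunk, id_mask_full.sum(), id_mask_chunk.sum())
# -> 5 200 2 2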