Chunkwise netcdf export #1092

Closed · wants to merge 5 commits
65 changes: 52 additions & 13 deletions parcels/particlefile/particlefilesoa.py
@@ -1,5 +1,6 @@
"""Module controlling the writing of ParticleSets to NetCDF file"""
import os
import psutil
from glob import glob
import numpy as np

@@ -59,7 +60,7 @@ def get_pset_info_attributes(self):
'file_list', 'file_list_once', 'maxid_written', 'parcels_mesh', 'metadata']
return attributes

def read_from_npy(self, file_list, time_steps, var):
def read_from_npy(self, file_list, time_steps, var, chunk_ids):
"""
Read NPY-files for one variable using a loop over all files.

@@ -70,8 +71,11 @@ def read_from_npy(self, file_list, time_steps, var):
:param time_steps: Number of time steps that were written in out directory
:param var: name of the variable to read
:param chunk_ids: particle ids to read in this chunk
"""
data = np.nan * np.zeros((self.maxid_written+1, time_steps))
time_index = np.zeros(self.maxid_written+1, dtype=np.int64)

n_ids = len(chunk_ids)

data = np.nan * np.zeros((n_ids, time_steps))
time_index = np.zeros(n_ids, dtype=np.int64)
t_ind_used = np.zeros(time_steps, dtype=np.int64)

# loop over all files
@@ -84,16 +88,20 @@ def read_from_npy(self, file_list, time_steps, var):
'"parcels_convert_npydir_to_netcdf %s" to convert these to '
'a NetCDF file yourself.\nTo avoid this error, make sure you '
'close() your ParticleFile at the end of your script.' % self.tempwritedir)
id_ind = np.array(data_dict["id"], dtype=np.int64)
t_ind = time_index[id_ind] if 'once' not in file_list[0] else 0

id_avail = np.array(data_dict["id"], dtype=np.int64)
id_mask_full = np.in1d(id_avail, chunk_ids) # which ids in data are present in this chunk
id_mask_chunk = np.in1d(chunk_ids, id_avail) # which ids in this chunk are present in data
t_ind = time_index[id_mask_chunk] if 'once' not in file_list[0] else 0
t_ind_used[t_ind] = 1
data[id_ind, t_ind] = data_dict[var]
time_index[id_ind] = time_index[id_ind] + 1
data[id_mask_chunk, t_ind] = data_dict[var][id_mask_full]
time_index[id_mask_chunk] = time_index[id_mask_chunk] + 1

# remove rows and columns that are completely filled with nan values
tmp = data[time_index > 0, :]
return tmp[:, t_ind_used == 1]
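
For reference, a minimal standalone sketch of the id-masking step above, with made-up ids and values (nothing here is taken from the PR). The pairwise alignment of the two masks relies on both id arrays being sorted ascending and duplicate-free, which holds because chunk_ids is a slice of an np.unique result and each frame stores each id at most once.

import numpy as np

chunk_ids = np.array([3, 4, 5, 6])        # ids assigned to this chunk (sorted, unique)
frame_ids = np.array([1, 4, 6])           # ids stored in one time-step's NPY frame
frame_vals = np.array([10., 40., 60.])    # e.g. the 'lon' values for those ids

mask_full = np.in1d(frame_ids, chunk_ids)   # [False, True, True]: frame entries belonging to this chunk
mask_chunk = np.in1d(chunk_ids, frame_ids)  # [False, True, False, True]: chunk rows present in this frame

row = np.full(len(chunk_ids), np.nan)
row[mask_chunk] = frame_vals[mask_full]     # ids 4 and 6 receive 40. and 60.
print(row)                                  # [nan 40. nan 60.]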


def export(self):
"""
Exports outputs in temporary NPY-files to NetCDF file
@@ -117,6 +125,7 @@ def export(self):

global_maxid_written = -1
global_time_written = []
global_id = []
global_file_list = []
if len(self.var_names_once) > 0:
global_file_list_once = []
@@ -127,21 +136,51 @@ def export(self):
for npyfile in pset_info_local['file_list']:
tmp_dict = np.load(npyfile, allow_pickle=True).item()
global_time_written.append([t for t in tmp_dict['time']])
global_id.append([i for i in tmp_dict['id']])
global_file_list += pset_info_local['file_list']
if len(self.var_names_once) > 0:
global_file_list_once += pset_info_local['file_list_once']
self.maxid_written = global_maxid_written

# These steps seem to be quite expensive...
self.time_written = np.unique(global_time_written)
self.id_present = np.unique([pid for frame in global_id for pid in frame])

for var in self.var_names:
data = self.read_from_npy(global_file_list, len(self.time_written), var)
if var == self.var_names[0]:
self.open_netcdf_file(data.shape)
varout = 'z' if var == 'depth' else var
getattr(self, varout)[:, :] = data
# Check available memory to decide whether the full (particles x time-steps) output array fits in RAM
avail_mem = psutil.virtual_memory()[1]
req_mem = len(self.id_present)*len(self.time_written)*8*1.2  # float64 bytes plus a 20% safety margin
# avail_mem = req_mem/2 # ! HACK FOR TESTING !

if req_mem > avail_mem:
# Read id_per_chunk ids at a time to keep memory use down
total_chunks = int(np.ceil(req_mem/avail_mem))
id_per_chunk = int(np.ceil(len(self.id_present)/total_chunks))
else:
total_chunks = 1
id_per_chunk = len(self.id_present)

for chunk in range(total_chunks):
# Start and (exclusive) end indices into self.id_present for this chunk
idx_range = [0, 0]
idx_range[0] = int(chunk*id_per_chunk)
idx_range[1] = int(np.min(((chunk+1)*id_per_chunk, len(self.id_present))))

# Read chunk-sized data from NPY-files
data = self.read_from_npy(global_file_list, len(self.time_written), var, self.id_present[idx_range[0]:idx_range[1]])
if (var == self.var_names[0]) and (chunk == 0):
# !! unacceptable assumption !!
# Assumes that the number of time-steps in the first chunk
# == number of time-steps across all chunks.
self.open_netcdf_file((len(self.id_present), data.shape[1]))

varout = 'z' if var == 'depth' else var
# Write this chunk's rows into the matching slice of the NetCDF variable
getattr(self, varout)[idx_range[0]:idx_range[1], :] = data

if len(self.var_names_once) > 0:
for var in self.var_names_once:
getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var)
getattr(self, var)[:] = self.read_from_npy(global_file_list_once, 1, var, self.id_present)

self.close_netcdf_file()
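
For reference, a short sketch of the chunk-size arithmetic used in export() above, with made-up array sizes; psutil.virtual_memory().available is the same field that the diff reads as virtual_memory()[1], and the read/write calls are shown only as comments.

import numpy as np
import psutil

n_ids = 1_000_000    # stands in for len(self.id_present)
n_steps = 2_000      # stands in for len(self.time_written)

avail_mem = psutil.virtual_memory().available
req_mem = n_ids * n_steps * 8 * 1.2          # float64 bytes plus a 20% safety margin

if req_mem > avail_mem:
    total_chunks = int(np.ceil(req_mem / avail_mem))
    id_per_chunk = int(np.ceil(n_ids / total_chunks))
else:
    total_chunks, id_per_chunk = 1, n_ids

for chunk in range(total_chunks):
    start = chunk * id_per_chunk
    stop = min((chunk + 1) * id_per_chunk, n_ids)
    # data = self.read_from_npy(file_list, n_steps, var, id_present[start:stop])
    # getattr(self, varout)[start:stop, :] = data   # write this chunk's rows

Because each chunk covers a contiguous slice of the sorted id array, every chunk writes a disjoint row range of the NetCDF variable and the chunks together cover all particles exactly once.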