Added export location functionality, fixed some issues with parallel processing, and added more logging
beatfactor committed Jul 12, 2024
1 parent 99b89e4 commit 01063bc
Showing 13 changed files with 568 additions and 147 deletions.
126 changes: 101 additions & 25 deletions oceanstream/cli/main.py
@@ -1,17 +1,20 @@
from asyncio import CancelledError

import typer
import asyncio
import os
import logging
import sys
import warnings
import dask

from pathlib import Path
from rich import print
from rich.traceback import install, Traceback
from oceanstream.settings import load_config
from dask.distributed import LocalCluster, Client, Variable
from rich.console import Console

from oceanstream.process import compute_and_export_single_file, process_zarr_files

install(show_locals=False, width=120)

@@ -30,6 +33,11 @@

logging.basicConfig(level="ERROR", format='%(asctime)s - %(levelname)s - %(message)s')

dask.config.set({
'distributed.comm.timeouts.connect': '60s', # Increase the connection timeout
'distributed.comm.timeouts.tcp': '120s', # Increase the TCP timeout
'distributed.comm.retry.count': 0
})

def initialize(settings, file_path, log_level=None):
config_data = load_config(settings["config"])
@@ -115,7 +123,8 @@ def convert(
if filePath.is_file():
from oceanstream.process import convert_raw_file
convert_raw_file(filePath, configData)
print(f"[blue]✅ Converted raw file {source} to Zarr and wrote output to: {configData['output_folder']} [/blue]")
print(
f"[blue]✅ Converted raw file {source} to Zarr and wrote output to: {configData['output_folder']} [/blue]")
elif filePath.is_dir():
from oceanstream.process import convert_raw_files
convert_raw_files(configData, workers_count=workers_count)
@@ -193,7 +202,8 @@ def compute_sv(
sonar_model: str = typer.Option(DEFAULT_SONAR_MODEL, help="Sonar model used to collect the data",
show_choices=["AZFP", "EK60", "ES70", "EK80", "ES80", "EA640", "AD2CP"]),
plot_echogram: bool = typer.Option(False, help="Plot the echogram after processing"),
use_dask: bool = typer.Option(False, help="Start a Local Dask cluster for parallel processing (always enabled for multiple files)"),
use_dask: bool = typer.Option(False,
help="Start a Local Dask cluster for parallel processing (always enabled for multiple files)"),
depth_offset: float = typer.Option(0.0, help="Depth offset for the echogram plot"),
waveform_mode: str = typer.Option("CW", help="Waveform mode, can be either CW or BB",
show_choices=["CW", "BB"]),
@@ -209,6 +219,87 @@ def compute_sv(
"sonar_model": sonar_model,
"output_folder": output or DEFAULT_OUTPUT_FOLDER
}
dask.config.set({'distributed.comm.retry.count': 1})

file_path = Path(source)
config_data = initialize(settings_dict, file_path, log_level=log_level)
single_file = file_path.is_dir() and source.endswith(".zarr")

client = None
cluster = None

if use_dask or not single_file:
cluster = LocalCluster(n_workers=workers_count, threads_per_worker=1)
client = Client(cluster)

try:
if file_path.is_dir() and source.endswith(".zarr"):
console = Console()
with console.status("Processing...", spinner="dots") as status:
status.start()
status.update(
f"[blue] Computing Sv for {file_path}...[/blue]" + use_dask * "– navigate to "
"http://localhost:8787/status for "
"progress")

chunks = None
if use_dask:
chunks = config_data.get('base_chunk_sizes')

compute_and_export_single_file(config_data,
chunks=chunks,
plot_echogram=plot_echogram,
waveform_mode=waveform_mode,
depth_offset=depth_offset)

status.stop()
print("✅ The file have been processed successfully.")
elif file_path.is_dir():
print(f"Dashboard available at {client.dashboard_link}")
process_zarr_files(config_data,
client,
workers_count=workers_count,
chunks=config_data.get('base_chunk_sizes'),
plot_echogram=plot_echogram,
waveform_mode=waveform_mode,
depth_offset=depth_offset)
else:
print(f"[red]❌ The provided path '{source}' is not a valid Zarr root.[/red]")
sys.exit(1)
except KeyboardInterrupt:
logging.info("KeyboardInterrupt received, terminating processes...")
except Exception as e:
logging.exception("Error while processing %s", config_data['raw_path'])
print(Traceback())
finally:
if client is not None:
client.close()

if cluster is not None:
cluster.close()


@app.command()
def export_location(
source: str = typer.Option(..., help="Path to a Zarr root file or a directory containing Zarr files"),
output: str = typer.Option(None,
help="Destination path for saving the exported data. Defaults to a predefined "
"directory if not specified."),
workers_count: int = typer.Option(os.cpu_count(), help="Number of CPU workers to use for parallel processing"),
use_dask: bool = typer.Option(False,
help="Start a Local Dask cluster for parallel processing (always enabled for "
"multiple files)"),
config: str = typer.Option(None, help="Path to a configuration file"),
log_level: str = typer.Option("WARNING", help="Set the logging level",
show_choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
):
"""
Given a Zarr dataset containing Sv data, exports the GPS location data to a JSON file.
"""
settings_dict = {
"config": config,
"output_folder": output or DEFAULT_OUTPUT_FOLDER
}
file_path = Path(source)
config_data = initialize(settings_dict, file_path, log_level=log_level)

@@ -225,27 +316,17 @@ def compute_sv(
if file_path.is_dir() and source.endswith(".zarr"):
status.update(
f"[blue] Computing Sv for {file_path}...[/blue] – navigate to http://localhost:8787/status for progress")
from oceanstream.process import compute_single_file

compute_single_file(config_data,
chunks=config_data.get('base_chunk_sizes'),
plot_echogram=plot_echogram,
waveform_mode=waveform_mode,
depth_offset=depth_offset)
# TODO: Implement export_location_json
elif file_path.is_dir():
status.update(
f"[blue] Processing zarr files in {file_path}...[/blue] – navigate to "
f"http://localhost:8787/status for progress")
from oceanstream.process import process_zarr_files
processed_count_var = Variable('processed_count', client)
process_zarr_files(config_data,
workers_count=workers_count,
status=status,
chunks=config_data.get('base_chunk_sizes'),
plot_echogram=plot_echogram,
waveform_mode=waveform_mode,
processed_count_var=processed_count_var,
depth_offset=depth_offset)
from oceanstream.process import export_location_from_zarr_files

export_location_from_zarr_files(config_data,
workers_count=workers_count,
client=client,
chunks=config_data.get('base_chunk_sizes'))
else:
print(f"[red]❌ The provided path '{source}' is not a valid Zarr root.[/red]")
sys.exit(1)
@@ -261,11 +342,6 @@ def compute_sv(
status.stop()


@app.command()
def export():
typer.echo("Export data...")


def main():
print(BANNER)
warnings.filterwarnings("ignore", category=UserWarning)
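Both the updated compute_sv command and the new export_location command (which Typer will likely expose as export-location on the command line) stand up the same kind of local Dask cluster before doing any work. The sketch below is a minimal, standalone version of that setup, using only the values visible in this diff plus an illustrative workload placeholder:

import os

import dask
from dask.distributed import LocalCluster, Client

# mirror the module-level timeouts set in cli/main.py
dask.config.set({
    'distributed.comm.timeouts.connect': '60s',
    'distributed.comm.timeouts.tcp': '120s',
})

cluster = LocalCluster(n_workers=os.cpu_count(), threads_per_worker=1)
client = Client(cluster)
print(f"Dashboard available at {client.dashboard_link}")  # typically http://localhost:8787/status

try:
    pass  # submit processing work here
finally:
    client.close()
    cluster.close()
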
10 changes: 9 additions & 1 deletion oceanstream/echodata/sv_computation.py
@@ -93,6 +93,7 @@ def compute_sv(echodata: EchoData, **kwargs) -> xr.Dataset:
raise ValueError(str(e))
# Check if the sonar model is supported
sonar_model = echodata.sonar_model

try:
SupportedSonarModelsForSv(sonar_model)
except ValueError:
@@ -103,6 +104,7 @@ def compute_sv(echodata: EchoData, **kwargs) -> xr.Dataset:
{list(SupportedSonarModelsForSv)}."
)
# Compute Sv
print(f"Computing Sv for {sonar_model} sonar model...", kwargs)
Sv = ep.calibrate.compute_Sv(echodata, **kwargs)
# Check if the computed Sv is empty
if Sv["Sv"].values.size == 0:
@@ -113,6 +115,12 @@
def compute_sv_with_encode_mode(
echodata: EchoData, waveform_mode: str = "CW", encode_mode: str = "power"
) -> xr.Dataset:
sv_dataset = compute_sv(echodata, waveform_mode=waveform_mode, encode_mode=encode_mode)
try:
sv_dataset = compute_sv(echodata, waveform_mode=waveform_mode, encode_mode=encode_mode)
except Exception as e:
print(f"\n--------Error computing Sv with encode mode: {e}--------\n")
import traceback
traceback.print_exc()
raise e

return sv_dataset
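The new wrapper prints a traceback and re-raises, so callers still see the original exception. A minimal calling sketch under the signature shown above; the Zarr path and the use of echopype.open_converted to obtain the EchoData object are assumptions for illustration, not part of this commit:

import echopype as ep

from oceanstream.echodata.sv_computation import compute_sv_with_encode_mode

# illustrative path; any EchoData produced by echopype will do
echodata = ep.open_converted("D20240712-T120000.zarr")
try:
    sv_dataset = compute_sv_with_encode_mode(echodata, waveform_mode="CW", encode_mode="power")
except Exception:
    # the wrapper has already printed its own traceback before re-raising
    raise
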
35 changes: 18 additions & 17 deletions oceanstream/echodata/sv_dataset_extension.py
@@ -22,13 +22,13 @@
"""
import warnings
import numpy as np
import xarray as xr

from echopype.consolidate import add_location, add_splitbeam_angle
from echopype.echodata.echodata import EchoData
from xarray import Dataset
from echopype.echodata import EchoData


def enrich_sv_dataset(sv: Dataset, echodata: EchoData, **kwargs) -> Dataset:
def enrich_sv_dataset(sv: xr.Dataset, echodata: EchoData, **kwargs) -> xr.Dataset:
"""
Enhances the input `sv` dataset by adding depth, location, and split-beam angle information.
@@ -58,28 +58,26 @@ def enrich_sv_dataset(sv: Dataset, echodata: EchoData, **kwargs) -> Dataset:
"return_dataset",
]
splitbeam_args = {k: kwargs[k] for k in splitbeam_keys if k in kwargs}
enriched_sv = sv

try:
add_depth(enriched_sv, **depth_args)
add_depth(sv, **depth_args)
except Exception as e:
warnings.warn(f"Failed to add depth due to error: {str(e)}")

# try:
# enriched_sv = add_location(enriched_sv, echodata, **location_args)
# except Exception as e:
# warnings.warn(f"Failed to add location due to error: {str(e)}")
try:
sv = add_location(sv, echodata, **location_args)
except Exception as e:
warnings.warn(f"Failed to add location due to error: {str(e)}")

# try:
# add_splitbeam_angle(enriched_sv, echodata, **splitbeam_args)
# except Exception as e:
# warnings.warn(f"Failed to add split-beam angle due to error: {str(e)}")
# traceback.print_exc()
try:
add_splitbeam_angle(sv, echodata, **splitbeam_args)
except Exception as e:
warnings.warn(f"Failed to add split-beam angle due to error: {str(e)}")

return enriched_sv
return sv


def add_depth(Sv: Dataset, depth_offset: float = 0, tilt: float = 0, downward: bool = True):
def add_depth(Sv: xr.Dataset, depth_offset: float = 0, tilt: float = 0, downward: bool = True):
"""
Given an existing Sv dataset, it adds a data variable called depth containing the depth of
each ping.
@@ -101,11 +99,14 @@ def add_depth(Sv: Dataset, depth_offset: float = 0, tilt: float = 0, downward: b
selected_echo_range = selected_echo_range.values.tolist()
selected_echo_range = [mult * value * np.cos(tilt / 180 * np.pi) + depth_offset for value in selected_echo_range]
Sv = Sv.assign_coords(range_sample=selected_echo_range)
min_val = np.nanmin(selected_echo_range)
max_val = np.nanmax(selected_echo_range)
Sv = Sv.sel(range_sample=slice(min_val, max_val))

return Sv


def add_seabed_depth(Sv: Dataset):
def add_seabed_depth(Sv: xr.Dataset):
"""
Given an existing Sv dataset with a seabed mask attached, it adds a
data variable called seabed depth containing the location of the seabed on
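The depth values assigned in add_depth follow depth = mult * echo_range * cos(tilt * pi / 180) + depth_offset, and the dataset is then trimmed to the min/max of those values. A small standalone sketch of that arithmetic; the mapping of downward to mult = +/-1 is an assumption, since that line sits outside this hunk:

import numpy as np

def depth_from_echo_range(echo_range, depth_offset=0.0, tilt=0.0, downward=True):
    # assumed convention: mult = +1 for a downward-facing transducer, -1 otherwise
    mult = 1 if downward else -1
    return mult * np.asarray(echo_range) * np.cos(np.deg2rad(tilt)) + depth_offset

# e.g. 0-100 m of echo range, 10 degree tilt, 5 m transducer draft
print(depth_from_echo_range([0.0, 50.0, 100.0], depth_offset=5.0, tilt=10.0))
# -> approximately [5.0, 54.24, 103.48]
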
19 changes: 9 additions & 10 deletions oceanstream/echodata/sv_interpolation.py
@@ -26,7 +26,7 @@ def linear_to_db(linear: xr.DataArray) -> xr.DataArray:


def interpolate_sv(
sv: Union[xr.Dataset, str, Path], method: str = "linear", with_edge_fill: bool = False
ds_Sv: xr.Dataset, method: str = "linear", with_edge_fill: bool = False
) -> xr.Dataset:
"""
Apply masks to the Sv DataArray in the dataset and interpolate over the resulting NaN values.
@@ -43,23 +43,22 @@
>> interpolate_sv(Sv, method)
Expected Output
"""
# Load the dataset
if isinstance(sv, xr.Dataset):
dataset = sv

# Initialize an empty list to store the processed channels
processed_channels = []

# Loop over each channel
for channel in sv_dataarray["channel"]:
channel_data = sv_dataarray.sel(channel=channel)
for channel in ds_Sv.Sv["channel"]:
channel_data = ds_Sv.Sv.sel(channel=channel)

# Convert from dB to linear scale
channel_data_linear = db_to_linear(channel_data)

# Perform interpolation to fill NaN values in linear scale using Xarray's interpolate_na
interpolated_channel_data_linear = channel_data_linear.interpolate_na(
dim="ping_time", method=method, use_coordinate=True
dim="ping_time",
method=method,
use_coordinate=True,
dask_gufunc_kwargs={"allow_rechunk": True},
)

if with_edge_fill:
@@ -76,9 +75,9 @@
interpolated_sv = xr.concat(processed_channels, dim="channel")

# Update the Sv DataArray in the dataset with the interpolated values
dataset["Sv"] = interpolated_sv
ds_Sv["Sv"] = interpolated_sv

return dataset
return ds_Sv


def find_impacted_variables(
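interpolate_sv converts each channel from dB to linear, interpolates NaNs along ping_time, and converts back. The conversion helpers are not shown in this hunk; the sketch below assumes the usual 10 * log10 convention and shows the round trip on a toy DataArray:

import numpy as np
import xarray as xr

def db_to_linear(db: xr.DataArray) -> xr.DataArray:
    # assumed 10 * log10 convention for Sv
    return 10 ** (db / 10)

def linear_to_db(linear: xr.DataArray) -> xr.DataArray:
    return 10 * np.log10(linear)

# round trip over a toy ping_time series with one gap
sv = xr.DataArray([-70.0, np.nan, -60.0], dims="ping_time", coords={"ping_time": [0, 1, 2]})
filled = linear_to_db(db_to_linear(sv).interpolate_na(dim="ping_time", method="linear"))
print(filled.values)  # the gap is filled in linear space, then converted back to dB
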
8 changes: 6 additions & 2 deletions oceanstream/exports/__init__.py
@@ -1,2 +1,6 @@
from .nasc import compute_and_write_nasc
from .shoals.shoals_handler import get_shoals_list, write_csv
from .csv.csv_export_from_Sv import create_location, export_location_json

__all__ = [
"create_location",
"export_location_json"
]
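With the updated __init__.py, the new location helpers are importable from the subpackage root, while __all__ lists only those two names (so a star-import pulls in just the location helpers). A minimal import sketch; the helpers' call signatures live in csv_export_from_Sv and are not shown in this diff:

from oceanstream.exports import create_location, export_location_json
# the pre-existing helpers remain importable explicitly, even though
# __all__ now lists only the two location helpers:
from oceanstream.exports import compute_and_write_nasc, get_shoals_list, write_csv
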