From 50d1d67dc7259a90128fff06754a50590ed2cb21 Mon Sep 17 00:00:00 2001
From: Theophile du Laz
Date: Thu, 29 Jun 2023 16:27:26 -0700
Subject: [PATCH] Catalog ingestion script: Parallel ingestion + multiple files at once (#222)

* ingest multiple files in parallel
---
 kowalski/ingesters/ingest_catalog.py | 133 +++++++++++++++++----------
 1 file changed, 86 insertions(+), 47 deletions(-)

diff --git a/kowalski/ingesters/ingest_catalog.py b/kowalski/ingesters/ingest_catalog.py
index 30f92df7..2aba231b 100644
--- a/kowalski/ingesters/ingest_catalog.py
+++ b/kowalski/ingesters/ingest_catalog.py
@@ -2,14 +2,19 @@
 This tool will ingest catalogs in different formats (fits, csv, and parquet) into Kowalski
 """
 
+import multiprocessing
+import os
 import pathlib
+import random
 import time
+from typing import Sequence
 
 import fire
 import numpy as np
 import pandas as pd
 import pyarrow.parquet as pq
 from astropy.io import fits
+from tqdm import tqdm
 
 import kowalski.tools.istarmap as istarmap  # noqa: F401
 from kowalski.config import load_config
@@ -23,33 +28,40 @@
 init_db_sync(config=config)
 
 
-def process_file(
-    file,
-    collection,
-    ra_col=None,
-    dec_col=None,
-    id_col=None,
-    batch_size=2048,
-    max_docs=None,
-    format="fits",
-):
+def process_file(argument_list: Sequence):
+    (
+        file,
+        collection,
+        ra_col,
+        dec_col,
+        id_col,
+        batch_size,
+        max_docs,
+        rm,
+        format,
+    ) = argument_list
+
+    log(f"Processing {file}")
     if format not in ("fits", "csv", "parquet"):
         log("Format not supported")
         return
 
-    # connect to MongoDB:
-    log("Connecting to DB")
-    mongo = Mongo(
-        host=config["database"]["host"],
-        port=config["database"]["port"],
-        replica_set=config["database"]["replica_set"],
-        username=config["database"]["username"],
-        password=config["database"]["password"],
-        db=config["database"]["db"],
-        srv=config["database"]["srv"],
-        verbose=0,
-    )
-    log("Successfully connected")
+    mongo = None
+    while True:
+        try:
+            mongo = Mongo(
+                host=config["database"]["host"],
+                port=config["database"]["port"],
+                replica_set=config["database"]["replica_set"],
+                username=config["database"]["username"],
+                password=config["database"]["password"],
+                db=config["database"]["db"],
+                srv=config["database"]["srv"],
+                verbose=0,
+            )
+            break
+        except Exception as e:
+            log(str(e))
 
     # if the file is not an url
     if not file.startswith("http"):
@@ -59,8 +71,6 @@ def process_file(
             log(f"File {file} not found")
             return
 
-    log(f"Processing {file}")
-
     total_good_documents = 0
     total_bad_documents = 0
 
@@ -117,8 +127,6 @@ def process_file(
                 np.arange(len(dataframe)) // batch_size
             ):
 
-                log(f"{file}: processing batch # {chunk_index + 1}")
-
                 for col, dtype in dataframe_chunk.dtypes.items():
                     if dtype == object:
                         dataframe_chunk[col] = dataframe_chunk[col].apply(
@@ -169,7 +177,6 @@ def process_file(
                 total_good_documents += len(batch) - len(bad_document_indexes)
                 if len(bad_document_indexes) > 0:
                     total_bad_documents += len(bad_document_indexes)
-                    log("Removing bad docs")
                     for index in sorted(bad_document_indexes, reverse=True):
                         del batch[index]
 
@@ -189,7 +196,6 @@ def process_file(
             last_chunk = True
 
             names = list(dataframe_chunk.columns)
-            log(f"{file}: processing batch # {chunk_index + 1}")
 
             if id_col is not None:
                 if id_col not in names:
@@ -268,7 +274,6 @@ def process_file(
             total_good_documents += len(batch) - len(bad_document_indexes)
             if len(bad_document_indexes) > 0:
                 total_bad_documents += len(bad_document_indexes)
-                log("Removing bad docs")
                 for index in sorted(bad_document_indexes, reverse=True):
                     del batch[index]
 
@@ -407,8 +412,8 @@ def convert_nparray_to_list(value):
                     )
                     # flush:
                     batch = []
-                except Exception as exception:
-                    log(str(exception))
+                except Exception as e:
+                    log(str(e))
 
     while len(batch) > 0:
         try:
@@ -416,7 +421,7 @@ def convert_nparray_to_list(value):
             # flush:
             batch = []
         except Exception as e:
-            log(e)
+            log(str(e))
             log("Failed, waiting 5 seconds to retry")
             time.sleep(5)
 
@@ -426,11 +431,19 @@ def convert_nparray_to_list(value):
     # disconnect from db:
     try:
         mongo.client.close()
-    finally:
-        log("Successfully disconnected from db")
+    except Exception as e:
+        log(f"Failed to disconnect from db: {e}")
 
-    log(f"Total good documents: {total_good_documents}")
-    log(f"Total bad documents: {total_bad_documents}")
+    if total_good_documents + total_bad_documents == 0:
+        log("No documents ingested")
+    if total_bad_documents > 0:
+        log(f"Failed to ingest {total_bad_documents} documents")
+
+    if total_bad_documents == 0 and rm is True:
+        try:
+            os.remove(pathlib.Path(file))
+        except Exception as e:
+            log(f"Failed to remove original file: {e}")
 
     return total_good_documents, total_bad_documents
 
@@ -440,8 +453,10 @@ def run(
     ra_col: str = None,
     dec_col: str = None,
     id_col: str = None,
+    num_proc: int = multiprocessing.cpu_count(),
     batch_size: int = 2048,
     max_docs: int = None,
+    rm: bool = False,
     format: str = "fits",
 ):
     """Pre-process and ingest catalog from fits files into Kowalski
@@ -469,16 +484,40 @@ def run(
         [("coordinates.radec_geojson", "2dsphere"), ("_id", 1)], background=True
     )
 
-    total_good_documents, total_bad_documents = process_file(
-        path,
-        catalog_name,
-        ra_col,
-        dec_col,
-        id_col,
-        batch_size,
-        max_docs,
-        format,
-    )
+    # grab all the files in the directory and subdirectories:
+    files = []
+    if os.path.isfile(path):
+        files.append(path)
+    else:
+        for root, dirnames, filenames in os.walk(path):
+            files += [os.path.join(root, f) for f in filenames if f.endswith(format)]
+
+    input_list = [
+        (
+            file,
+            catalog_name,
+            ra_col,
+            dec_col,
+            id_col,
+            batch_size,
+            max_docs,
+            rm,
+            format,
+        )
+        for file in sorted(files)
+    ]
+    random.shuffle(input_list)
+
+    total_good_documents, total_bad_documents = 0, 0
+    log(f"Processing {len(files)} files with {num_proc} processes")
+    with multiprocessing.Pool(processes=num_proc) as pool:
+        for result in tqdm(pool.imap(process_file, input_list), total=len(files)):
+            total_good_documents += result[0]
+            total_bad_documents += result[1]
+
+    log(f"Successfully ingested {total_good_documents} documents")
+    log(f"Failed to ingest {total_bad_documents} documents")
+
     return total_good_documents, total_bad_documents
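
Editor's note: the following usage sketch is not part of the patch. It only illustrates how the reworked run() entry point might be called with the new num_proc and rm options added in this PR. The directory path, catalog name, and column names are hypothetical, and it assumes run() is importable from the module as shown; the fire-based CLI wiring is outside this excerpt.

# Usage sketch under the assumptions above (hypothetical values throughout).
from kowalski.ingesters.ingest_catalog import run

run(
    path="data/catalogs/",      # hypothetical directory; walked recursively for files matching `format`
    catalog_name="my_catalog",  # hypothetical destination collection
    ra_col="ra",                # hypothetical column names
    dec_col="dec",
    id_col="objectid",
    num_proc=4,                 # new in this patch: number of worker processes
    batch_size=2048,
    rm=False,                   # new in this patch: delete each source file after a clean ingest
    format="parquet",
)

With a fire-based CLI the same call would typically translate to flags such as --num_proc=4 and --rm, but since the entry point is not shown in this excerpt, treat that as an assumption.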