Catalog ingestion script: Parallel ingestion + multiple files at once (#222)

* ingest multiple files in parallel
Theodlz committed Jun 29, 2023
1 parent 8166f56 commit 50d1d67
Showing 1 changed file with 86 additions and 47 deletions.
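
The change replaces the single process_file call with a multiprocessing fan-out: each file becomes one argument tuple, the list is shuffled, and a worker pool maps process_file over it, summing the per-file (good, bad) document counts. A minimal standalone sketch of that pattern, with a stand-in worker and illustrative file names (the real worker is process_file in ingest_catalog.py):

import multiprocessing
import random


def process_file(args):
    # stand-in for the real worker: unpack the per-file tuple, "ingest" it,
    # and return (good_documents, bad_documents) counts
    file, collection = args
    return 1, 0


if __name__ == "__main__":
    files = ["a.parquet", "b.parquet", "c.parquet"]  # illustrative paths
    input_list = [(f, "my_catalog") for f in sorted(files)]
    random.shuffle(input_list)  # avoid dispatching files in sorted-path order

    total_good, total_bad = 0, 0
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        for good, bad in pool.imap(process_file, input_list):
            total_good += good
            total_bad += bad
    print(total_good, total_bad)
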
133 changes: 86 additions & 47 deletions kowalski/ingesters/ingest_catalog.py
@@ -2,14 +2,19 @@
This tool will ingest catalogs with different formats (fits, csv, and parquet) into Kowalski
"""

import multiprocessing
import os
import pathlib
import random
import time
from typing import Sequence

import fire
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from astropy.io import fits
from tqdm import tqdm

import kowalski.tools.istarmap as istarmap # noqa: F401
from kowalski.config import load_config
@@ -23,33 +28,40 @@
init_db_sync(config=config)


def process_file(
file,
collection,
ra_col=None,
dec_col=None,
id_col=None,
batch_size=2048,
max_docs=None,
format="fits",
):
def process_file(argument_list: Sequence):
(
file,
collection,
ra_col,
dec_col,
id_col,
batch_size,
max_docs,
rm,
format,
) = argument_list

log(f"Processing {file}")
if format not in ("fits", "csv", "parquet"):
log("Format not supported")
return

# connect to MongoDB:
log("Connecting to DB")
mongo = Mongo(
host=config["database"]["host"],
port=config["database"]["port"],
replica_set=config["database"]["replica_set"],
username=config["database"]["username"],
password=config["database"]["password"],
db=config["database"]["db"],
srv=config["database"]["srv"],
verbose=0,
)
log("Successfully connected")
mongo = None
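# retry until a connection to MongoDB succeeds; each worker process opens its own client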
while True:
try:
mongo = Mongo(
host=config["database"]["host"],
port=config["database"]["port"],
replica_set=config["database"]["replica_set"],
username=config["database"]["username"],
password=config["database"]["password"],
db=config["database"]["db"],
srv=config["database"]["srv"],
verbose=0,
)
break
except Exception as e:
log(str(e))

# if the file is not a URL
if not file.startswith("http"):
@@ -59,8 +71,6 @@ def process_file(
log(f"File {file} not found")
return

log(f"Processing {file}")

total_good_documents = 0
total_bad_documents = 0

@@ -117,8 +127,6 @@ def process_file(
np.arange(len(dataframe)) // batch_size
):

log(f"{file}: processing batch # {chunk_index + 1}")

for col, dtype in dataframe_chunk.dtypes.items():
if dtype == object:
dataframe_chunk[col] = dataframe_chunk[col].apply(
@@ -169,7 +177,6 @@ def process_file(
total_good_documents += len(batch) - len(bad_document_indexes)
if len(bad_document_indexes) > 0:
total_bad_documents += len(bad_document_indexes)
log("Removing bad docs")
for index in sorted(bad_document_indexes, reverse=True):
del batch[index]

@@ -189,7 +196,6 @@ def process_file(
last_chunk = True

names = list(dataframe_chunk.columns)
log(f"{file}: processing batch # {chunk_index + 1}")

if id_col is not None:
if id_col not in names:
@@ -268,7 +274,6 @@ def process_file(
total_good_documents += len(batch) - len(bad_document_indexes)
if len(bad_document_indexes) > 0:
total_bad_documents += len(bad_document_indexes)
log("Removing bad docs")
for index in sorted(bad_document_indexes, reverse=True):
del batch[index]

@@ -407,16 +412,16 @@ def convert_nparray_to_list(value):
)
# flush:
batch = []
except Exception as exception:
log(str(exception))
except Exception as e:
log(str(e))

while len(batch) > 0:
try:
mongo.insert_many(collection=collection, documents=batch)
# flush:
batch = []
except Exception as e:
log(e)
log(str(e))
log("Failed, waiting 5 seconds to retry")
time.sleep(5)

@@ -426,11 +431,19 @@ def convert_nparray_to_list(value):
# disconnect from db:
try:
mongo.client.close()
finally:
log("Successfully disconnected from db")
except Exception as e:
log(f"Failed to disconnect from db: {e}")

log(f"Total good documents: {total_good_documents}")
log(f"Total bad documents: {total_bad_documents}")
if total_good_documents + total_bad_documents == 0:
log("No documents ingested")
if total_bad_documents > 0:
log(f"Failed to ingest {total_bad_documents} documents")

if total_bad_documents == 0 and rm is True:
try:
os.remove(pathlib.Path(file))
except Exception as e:
log(f"Failed to remove original file: {e}")
return total_good_documents, total_bad_documents


@@ -440,8 +453,10 @@ def run(
ra_col: str = None,
dec_col: str = None,
id_col: str = None,
num_proc: int = multiprocessing.cpu_count(),
batch_size: int = 2048,
max_docs: int = None,
rm: bool = False,
format: str = "fits",
):
"""Pre-process and ingest catalog from fits files into Kowalski
@@ -469,16 +484,40 @@ def run(
[("coordinates.radec_geojson", "2dsphere"), ("_id", 1)], background=True
)

total_good_documents, total_bad_documents = process_file(
path,
catalog_name,
ra_col,
dec_col,
id_col,
batch_size,
max_docs,
format,
)
# grab all the files in the directory and subdirectories:
files = []
if os.path.isfile(path):
files.append(path)
else:
for root, dirnames, filenames in os.walk(path):
files += [os.path.join(root, f) for f in filenames if f.endswith(format)]

input_list = [
(
file,
catalog_name,
ra_col,
dec_col,
id_col,
batch_size,
max_docs,
rm,
format,
)
for file in sorted(files)
]
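# shuffle the work list so files are not dispatched to workers in sorted-path order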
random.shuffle(input_list)

total_good_documents, total_bad_documents = 0, 0
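# fan the files out across num_proc worker processes; each process_file call returns (good, bad) counts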
log(f"Processing {len(files)} files with {num_proc} processes")
with multiprocessing.Pool(processes=num_proc) as pool:
for result in tqdm(pool.imap(process_file, input_list), total=len(files)):
total_good_documents += result[0]
total_bad_documents += result[1]

log(f"Successfully ingested {total_good_documents} documents")
log(f"Failed to ingest {total_bad_documents} documents")

return total_good_documents, total_bad_documents
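
With these changes, run can be pointed at either a single file or a directory, and every matching file is ingested in parallel. A hypothetical invocation (the path, catalog name, and column names are illustrative, not from the commit; the parameter names are taken from the diff):

from kowalski.ingesters.ingest_catalog import run

run(
    path="data/my_catalog",   # a file, or a directory walked for files ending in `format`
    catalog_name="My_Catalog",
    ra_col="ra",
    dec_col="dec",
    id_col="objid",
    num_proc=8,               # worker processes; defaults to multiprocessing.cpu_count()
    batch_size=2048,
    rm=True,                  # delete each source file after a fully clean ingest
    format="parquet",
)

On the command line the same arguments would presumably map to flags such as --num_proc and --rm through the fire import, assuming the usual fire.Fire entry point at the bottom of the script.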

