diff --git a/mirar/database/base_model.py b/mirar/database/base_model.py
index 7706f1c8b..87c2a97c0 100644
--- a/mirar/database/base_model.py
+++ b/mirar/database/base_model.py
@@ -74,8 +74,8 @@ def validate_sql(cls, value: Any, info: FieldValidationInfo) -> Any:
 
     def _insert_entry(
         self,
+        duplicate_protocol: str,
         returning_key_names: str | list[str] = None,
-        duplicate_protocol: str = "ignore",
     ) -> pd.DataFrame:
         """
         Insert the pydantic-ified data into the corresponding sql database
@@ -186,15 +186,21 @@ def get_available_unique_keys(self) -> list[Column]:
         return [x for x in self.get_unique_keys() if x.name in self.model_fields]
 
     def insert_entry(
-        self, returning_key_names: str | list[str] | None = None
+        self,
+        duplicate_protocol: str,
+        returning_key_names: str | list[str] | None = None,
     ) -> pd.DataFrame:
         """
         Insert the pydantic-ified data into the corresponding sql database
 
+        :param duplicate_protocol: protocol to follow if duplicate entry is found
        :param returning_key_names: names of the keys to return
         :return: dataframe of the sequence keys
         """
-        result = self._insert_entry(returning_key_names=returning_key_names)
+        result = self._insert_entry(
+            duplicate_protocol=duplicate_protocol,
+            returning_key_names=returning_key_names,
+        )
         logger.debug(f"Return result {result}")
         return result
@@ -218,6 +224,9 @@ def _update_entry(self, update_key_names: list[str] | str | None = None):
 
         full_dict = self.model_dump()
 
+        if update_key_names is None:
+            update_key_names = full_dict.keys()
+
         update_dict = {key: full_dict[key] for key in update_key_names}
 
         _update_database_entry(
diff --git a/mirar/pipelines/summer/models/_exposures.py b/mirar/pipelines/summer/models/_exposures.py
index a2df633c2..d1a6c88f1 100644
--- a/mirar/pipelines/summer/models/_exposures.py
+++ b/mirar/pipelines/summer/models/_exposures.py
@@ -163,21 +163,28 @@ def validate_fid(cls, nightdate: str) -> datetime:
         """
         return datetime.strptime(nightdate, SUMMER_NIGHT_FORMAT)
 
-    def insert_entry(self, returning_key_names=None) -> pd.DataFrame:
+    def insert_entry(
+        self, duplicate_protocol: str, returning_key_names=None
+    ) -> pd.DataFrame:
         """
         Insert the pydantic-ified data into the corresponding sql database
 
-        :return: None
+        :param duplicate_protocol: protocol to follow if duplicate entry is found
+        :param returning_key_names: names of keys to return
+        :return: DataFrame of returning keys
         """
         night = Night(nightdate=self.nightdate)
         if not night.exists():
-            night.insert_entry()
+            night.insert_entry(duplicate_protocol="ignore")
 
         logger.debug(f"puid: {self.puid}")
         if not Program._exists(values=self.puid, keys="puid"):
             self.puid = 1
 
-        return self._insert_entry()
+        return self._insert_entry(
+            duplicate_protocol=duplicate_protocol,
+            returning_key_names=returning_key_names,
+        )
 
     def exists(self) -> bool:
         """
diff --git a/mirar/pipelines/summer/models/_filters.py b/mirar/pipelines/summer/models/_filters.py
index 35ba394b1..d3b92618a 100644
--- a/mirar/pipelines/summer/models/_filters.py
+++ b/mirar/pipelines/summer/models/_filters.py
@@ -71,4 +71,4 @@ def populate_filters(filter_map: dict = None):
     for filter_name, fid in filter_map.items():
         summer_filter = Filter(fid=fid, filtername=filter_name)
         if not summer_filter.exists():
-            summer_filter.insert_entry()
+            summer_filter.insert_entry(duplicate_protocol="ignore")
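The signature change above makes conflict handling explicit at every call site: insert_entry no longer silently defaults to "ignore". For orientation, a usage sketch covering the three protocol values exercised later in this patch (Exposure is the model from this repo, but header_values is a hypothetical stand-in for real header data):

# The three duplicate_protocol values used across this patch:
#   "fail"    -> raise on a duplicate key (new candidate rows)
#   "ignore"  -> keep the existing row and skip the insert (static lookup tables)
#   "replace" -> overwrite the existing row (refreshed source averages)
exposure = Exposure(**header_values)  # header_values: hypothetical dict of header fields
exposure.insert_entry(duplicate_protocol="fail")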
diff --git a/mirar/pipelines/summer/models/_img_type.py b/mirar/pipelines/summer/models/_img_type.py
index 2946fbf36..93785aef1 100644
--- a/mirar/pipelines/summer/models/_img_type.py
+++ b/mirar/pipelines/summer/models/_img_type.py
@@ -57,4 +57,4 @@ def populate_itid():
     for i, img_type in enumerate(ALL_ITID):
         itid = ImgType(itid=i + 1, imgtype=img_type)
         if not itid.exists():
-            itid.insert_entry()
+            itid.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/summer/models/_programs.py b/mirar/pipelines/summer/models/_programs.py
index 832334231..70221975b 100644
--- a/mirar/pipelines/summer/models/_programs.py
+++ b/mirar/pipelines/summer/models/_programs.py
@@ -132,4 +132,4 @@ def populate_programs():
     """
     if not default_program.exists():
-        default_program.insert_entry()
+        default_program.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/summer/models/_subdets.py b/mirar/pipelines/summer/models/_subdets.py
index b7a256a56..a2a4ecf5d 100644
--- a/mirar/pipelines/summer/models/_subdets.py
+++ b/mirar/pipelines/summer/models/_subdets.py
@@ -68,4 +68,4 @@ def populate_subdets(ndetectors: int = 1, nxtot: int = 1, nytot: int = 1):
             nxtot=nxtot,
             nytot=nytot,
         )
-        new.insert_entry()
+        new.insert_entry(duplicate_protocol="fail")
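All of the populate_* helpers now pass duplicate_protocol="ignore", making repeated population runs idempotent (populate_subdets is the exception, opting into "fail"). The patch does not show how _insert_entry branches on the protocol, but the semantics implied by the three values can be sketched with a toy in-memory table (illustrative only, not the mirar implementation):

# Toy model of the duplicate-protocol semantics, keyed on a primary-key string.
table: dict[str, dict] = {}

def insert_with_protocol(key: str, row: dict, duplicate_protocol: str) -> None:
    if key in table:
        if duplicate_protocol == "fail":
            raise KeyError(f"Duplicate entry: {key}")
        if duplicate_protocol == "ignore":
            return  # keep the existing row untouched
        if duplicate_protocol == "replace":
            table[key] = row  # overwrite the existing row
            return
        raise ValueError(f"Unknown duplicate protocol: {duplicate_protocol}")
    table[key] = row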
diff --git a/mirar/pipelines/winter/blocks.py b/mirar/pipelines/winter/blocks.py
index 203f55c32..b60f0ba0c 100644
--- a/mirar/pipelines/winter/blocks.py
+++ b/mirar/pipelines/winter/blocks.py
@@ -13,6 +13,7 @@
     MAX_DITHER_KEY,
     OBSCLASS_KEY,
     RAW_IMG_KEY,
+    SOURCE_HISTORY_KEY,
     SOURCE_NAME_KEY,
     TARGET_KEY,
     ZP_KEY,
@@ -44,12 +45,14 @@
     winter_candidate_quality_filterer,
     winter_fourier_filtered_image_generator,
     winter_history_deprecated_constraint,
+    winter_new_source_updater,
    winter_photometric_catalog_generator,
     winter_photometric_ref_catalog_namer,
     winter_reference_generator,
     winter_reference_image_resampler_for_zogy,
     winter_reference_psfex,
     winter_reference_sextractor,
+    winter_source_entry_updater,
     winter_stackid_annotator,
 )
 from mirar.pipelines.winter.load_winter_image import (
@@ -61,14 +64,15 @@
     load_winter_stack,
 )
 from mirar.pipelines.winter.models import (
-    CANDIDATE_PREFIX,
     DEFAULT_FIELD,
     NAME_START,
+    SOURCE_PREFIX,
     AstrometryStat,
     Candidate,
     Diff,
     Exposure,
     Raw,
+    Source,
     Stack,
 )
 from mirar.pipelines.winter.validator import (
@@ -94,8 +98,8 @@
     DatabaseSourceInserter,
 )
 from mirar.processors.database.database_selector import (
-    CrossmatchSourceWithDatabase,
-    DatabaseHistorySelector,
+    SelectSourcesWithMetadata,
+    SingleSpatialCrossmatchSource,
 )
 from mirar.processors.database.database_updater import ImageDatabaseMultiEntryUpdater
 from mirar.processors.flat import FlatCalibrator
@@ -550,28 +554,47 @@
         XMatch(catalog=TMASS(num_sources=3, search_radius_arcmin=0.5)),
         XMatch(catalog=PS1(num_sources=3, search_radius_arcmin=0.5)),
         SourceWriter(output_dir_name="kowalski"),
-        CrossmatchSourceWithDatabase(
-            db_table=Candidate,
-            db_output_columns=[SOURCE_NAME_KEY],
+        # Check if the source is already in the source table
+        SingleSpatialCrossmatchSource(
+            db_table=Source,
+            db_output_columns=["sourceid", SOURCE_NAME_KEY],
             crossmatch_radius_arcsec=2.0,
-            max_num_results=1,
+            ra_field_name="average_ra",
+            dec_field_name="average_dec",
         ),
+        # Assign names to the new sources
         CandidateNamer(
-            db_table=Candidate,
-            base_name=CANDIDATE_PREFIX,
+            db_table=Source,
+            base_name=SOURCE_PREFIX,
             name_start=NAME_START,
+            db_name_field=SOURCE_NAME_KEY,
+            db_order_field="sourceid",
         ),
-        DatabaseHistorySelector(
-            crossmatch_radius_arcsec=2.0,
-            time_field_name="jd",
-            history_duration_days=500.0,
+        # Add the new sources to the source table
+        CustomSourceTableModifier(modifier_function=winter_new_source_updater),
+        DatabaseSourceInserter(
+            db_table=Source,
+            duplicate_protocol="ignore",
+        ),
+        # Get all candidates associated with source
+        SelectSourcesWithMetadata(
+            db_query_columns=["sourceid"],
             db_table=Candidate,
             db_output_columns=prv_candidate_cols + [SOURCE_NAME_KEY],
+            base_output_column=SOURCE_HISTORY_KEY,
             additional_query_constraints=winter_history_deprecated_constraint,
         ),
         CustomSourceTableModifier(
             modifier_function=winter_candidate_avro_fields_calculator
         ),
+        # Update average ra and dec for source
+        CustomSourceTableModifier(modifier_function=winter_source_entry_updater),
+        # Update sources in the source table
+        DatabaseSourceInserter(
+            db_table=Source,
+            duplicate_protocol="replace",
+        ),
+        # Add candidates in the candidate table
         DatabaseSourceInserter(
             db_table=Candidate,
             duplicate_protocol="fail",
diff --git a/mirar/pipelines/winter/generator.py b/mirar/pipelines/winter/generator.py
index 097024e4b..befabe82c 100644
--- a/mirar/pipelines/winter/generator.py
+++ b/mirar/pipelines/winter/generator.py
@@ -27,6 +27,7 @@
     OBSCLASS_KEY,
     REF_CAT_PATH_KEY,
     SATURATE_KEY,
+    SOURCE_HISTORY_KEY,
     TIME_KEY,
     ZP_KEY,
     ZP_STD_KEY,
@@ -419,6 +420,62 @@ def winter_candidate_annotator_filterer(source_batch: SourceBatch) -> SourceBatch:
     return new_batch
 
 
+def winter_new_source_updater(source_table: SourceBatch) -> SourceBatch:
+    """
+    Function to add relevant fields for new sources
+
+    :param source_table: Original source table
+    :return: Updated source table
+    """
+    for source in source_table:
+        src_df = source.get_data()
+
+        src_df["ndet"] = 1
+        src_df["average_ra"] = src_df["ra"]
+        src_df["average_dec"] = src_df["dec"]
+
+        source.set_data(src_df)
+
+    return source_table
+
+
+def winter_source_entry_updater(source_table: SourceBatch) -> SourceBatch:
+    """
+    Function to update the source table with new source averages
+
+    :param source_table: Original source table
+    :return: Updated source table
+    """
+    for source in source_table:
+        src_df = source.get_data()
+
+        hist_dfs = [
+            pd.DataFrame(src_df[SOURCE_HISTORY_KEY].loc[x]) for x in range(len(src_df))
+        ]
+
+        src_df["ndet"] = [len(x) + 1 for x in hist_dfs]
+
+        average_ras, average_decs = [], []
+
+        for i, hist_df in enumerate(hist_dfs):
+            if len(hist_df) == 0:
+                average_ras.append(src_df["ra"].iloc[i])
+                average_decs.append(src_df["dec"].iloc[i])
+            else:
+                average_ras.append(
+                    np.mean(hist_df["ra"].tolist() + [src_df["ra"].iloc[i]])
+                )
+                average_decs.append(
+                    np.mean(hist_df["dec"].tolist() + [src_df["dec"].iloc[i]])
+                )
+
+        src_df["average_ra"] = average_ras
+        src_df["average_dec"] = average_decs
+        source.set_data(src_df)
+
+    return source_table
+
+
 def winter_candidate_avro_fields_calculator(source_table: SourceBatch) -> SourceBatch:
     """
     Function to calculate the AVRO fields for WINTER
@@ -433,7 +490,7 @@ def winter_candidate_avro_fields_calculator(source_table: SourceBatch) -> SourceBatch:
         src_df["magfromlim"] = source["diffmaglim"] - src_df["magpsf"]
 
         hist_dfs = [
-            pd.DataFrame(src_df["prv_candidates"].loc[x]) for x in range(len(src_df))
+            pd.DataFrame(src_df[SOURCE_HISTORY_KEY].loc[x]) for x in range(len(src_df))
         ]
 
         jdstarthists, jdendhists = [], []
@@ -591,10 +648,11 @@ def winter_reference_generator(image: Image):
             stack_image_annotator=winter_reference_stack_annotator,
         )
 
-    if filtername == "Y":
-        # Use PS1 references for Y-band
-        logger.debug("Will query reference image from PS1")
-        return PS1Ref(filter_name=filtername)
+    assert filtername == "Y", f"Filter {filtername} not recognized for WINTER"
+
+    # Use PS1 references for Y-band
+    logger.debug("Will query reference image from PS1")
+    return PS1Ref(filter_name=filtername)
 
 
 winter_history_deprecated_constraint = DBQueryConstraints(
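To make winter_source_entry_updater's bookkeeping concrete, here is a toy run of the same averaging logic on a single source with two prior detections ("source_history" stands in for whatever string SOURCE_HISTORY_KEY holds; standalone pandas/numpy, not mirar code):

import numpy as np
import pandas as pd

history = [{"ra": 210.01, "dec": 45.00}, {"ra": 210.03, "dec": 45.02}]
src_df = pd.DataFrame({"ra": [210.02], "dec": [45.01], "source_history": [history]})

hist_df = pd.DataFrame(src_df["source_history"].loc[0])
ndet = len(hist_df) + 1  # prior detections plus the new one -> 3
average_ra = np.mean(hist_df["ra"].tolist() + [src_df["ra"].iloc[0]])
average_dec = np.mean(hist_df["dec"].tolist() + [src_df["dec"].iloc[0]])
print(ndet, round(average_ra, 2), round(average_dec, 2))  # 3 210.02 45.01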
PS1") + return PS1Ref(filter_name=filtername) winter_history_deprecated_constraint = DBQueryConstraints( diff --git a/mirar/pipelines/winter/models/__init__.py b/mirar/pipelines/winter/models/__init__.py index 439e90335..3f20467ae 100644 --- a/mirar/pipelines/winter/models/__init__.py +++ b/mirar/pipelines/winter/models/__init__.py @@ -12,12 +12,7 @@ AstrometryStat, AstrometryStatsTable, ) -from mirar.pipelines.winter.models._candidates import ( - CANDIDATE_PREFIX, - NAME_START, - Candidate, - CandidatesTable, -) +from mirar.pipelines.winter.models._candidates import Candidate, CandidatesTable from mirar.pipelines.winter.models._diff import Diff, DiffsTable from mirar.pipelines.winter.models._exposures import Exposure, ExposuresTable from mirar.pipelines.winter.models._fields import ( @@ -55,6 +50,13 @@ ) from mirar.pipelines.winter.models._ref_queries import RefQueriesTable, RefQuery from mirar.pipelines.winter.models._ref_stacks import RefStack, RefStacksTable +from mirar.pipelines.winter.models._sources import ( + MIN_NAME_LENGTH, + NAME_START, + SOURCE_PREFIX, + Source, + SourcesTable, +) from mirar.pipelines.winter.models._stack import Stack, StacksTable from mirar.pipelines.winter.models._subdets import ( Subdet, @@ -85,7 +87,14 @@ def set_up_q3c(db_name: str, db_table: BaseTable): if DB_USER is not None: setup_database(db_base=WinterBase) - for table in [ExposuresTable, CandidatesTable, RefQueriesTable, StacksTable]: + for table in [ + ExposuresTable, + AstrometryStatsTable, + CandidatesTable, + RefQueriesTable, + StacksTable, + SourcesTable, + ]: set_up_q3c(db_name=WinterBase.db_name, db_table=table) populate_fields() diff --git a/mirar/pipelines/winter/models/_candidates.py b/mirar/pipelines/winter/models/_candidates.py index 86b3f744d..c22ad7016 100644 --- a/mirar/pipelines/winter/models/_candidates.py +++ b/mirar/pipelines/winter/models/_candidates.py @@ -59,10 +59,13 @@ class CandidatesTable(WinterBase): # pylint: disable=too-few-public-methods # Image properties - diffid: Mapped[int] = mapped_column(ForeignKey("diffs.diffid")) # FIXME + sourceid: Mapped[int] = mapped_column(ForeignKey("sources.sourceid")) + source: Mapped["SourcesTable"] = relationship(back_populates="candidates") + + diffid: Mapped[int] = mapped_column(ForeignKey("diffs.diffid")) diff_id: Mapped["DiffsTable"] = relationship(back_populates="candidates") - stackid: Mapped[int] = mapped_column(ForeignKey("stacks.stackid")) # FIXME + stackid: Mapped[int] = mapped_column(ForeignKey("stacks.stackid")) stack_id: Mapped["StacksTable"] = relationship(back_populates="candidates") fid: Mapped[int] = mapped_column(ForeignKey("filters.fid")) @@ -212,6 +215,8 @@ class Candidate(BaseDB): objectid: str = Field(min_length=MIN_NAME_LENGTH) deprecated: bool = Field(default=False) + sourceid: int = Field(ge=0) + jd: float = Field(ge=0) diffid: int | None = Field(ge=0, default=None) @@ -327,10 +332,14 @@ class Candidate(BaseDB): maggaia: float | None = Field(default=None) maggaiabright: float | None = Field(default=None) - def insert_entry(self, returning_key_names=None) -> pd.DataFrame: + def insert_entry( + self, duplicate_protocol, returning_key_names=None + ) -> pd.DataFrame: """ Insert the pydantic-ified data into the corresponding sql database + :param duplicate_protocol: protocol to follow if duplicate entry is found + :param returning_key_names: names of the keys to return :return: None """ @@ -345,4 +354,7 @@ def insert_entry(self, returning_key_names=None) -> pd.DataFrame: ) self.progname = default_program.progname - 
diff --git a/mirar/pipelines/winter/models/_diff.py b/mirar/pipelines/winter/models/_diff.py
index 0243ef6a4..75d239dfc 100644
--- a/mirar/pipelines/winter/models/_diff.py
+++ b/mirar/pipelines/winter/models/_diff.py
@@ -74,11 +74,15 @@ def validate_savepath(cls, savepath: str) -> str:
         return savepath
 
     def insert_entry(
-        self, returning_key_names: str | list[str] | None = None
+        self,
+        duplicate_protocol: str,
+        returning_key_names: str | list[str] | None = None,
     ) -> pd.DataFrame:
         """
         Insert entry into database
 
+        :param returning_key_names: names of keys to return
+        :param duplicate_protocol: protocol to follow if duplicate entry is found
         :return: dataframe of inserted entries
         """
         dbconstraints = DBQueryConstraints()
@@ -92,4 +96,7 @@ def insert_entry(
             db_constraints=dbconstraints,
         )
 
-        return self._insert_entry(returning_key_names=returning_key_names)
+        return self._insert_entry(
+            duplicate_protocol=duplicate_protocol,
+            returning_key_names=returning_key_names,
+        )
diff --git a/mirar/pipelines/winter/models/_exposures.py b/mirar/pipelines/winter/models/_exposures.py
index d6f846f79..19d525d2c 100644
--- a/mirar/pipelines/winter/models/_exposures.py
+++ b/mirar/pipelines/winter/models/_exposures.py
@@ -136,16 +136,20 @@ class Exposure(BaseDB):
     altitude: float = alt_field
     azimuth: float = az_field
 
-    def insert_entry(self, returning_key_names=None) -> pd.DataFrame:
+    def insert_entry(
+        self, duplicate_protocol: str, returning_key_names=None
+    ) -> pd.DataFrame:
         """
         Insert the pydantic-ified data into the corresponding sql database
 
+        :param duplicate_protocol: protocol to follow if duplicate entry is found
+        :param returning_key_names: names of the keys to return
         :return: None
         """
         night = Night(nightdate=self.nightdate)
         logger.debug(f"Searched for night {self.nightdate}")
         if not night.exists():
-            night.insert_entry()
+            night.insert_entry(duplicate_protocol="ignore")
 
         prog_match = select_from_table(
             DBQueryConstraints(columns="progname", accepted_values=self.progname),
@@ -158,7 +162,10 @@ def insert_entry(self, returning_key_names=None) -> pd.DataFrame:
         )
         self.progname = default_program.progname
 
-        return self._insert_entry(returning_key_names=returning_key_names)
+        return self._insert_entry(
+            duplicate_protocol=duplicate_protocol,
+            returning_key_names=returning_key_names,
+        )
 
     def exists(self) -> bool:
         """
diff --git a/mirar/pipelines/winter/models/_filters.py b/mirar/pipelines/winter/models/_filters.py
index 0e64d242b..202cb982a 100644
--- a/mirar/pipelines/winter/models/_filters.py
+++ b/mirar/pipelines/winter/models/_filters.py
@@ -70,4 +70,4 @@ def populate_filters(filter_map: dict = None):
     for filter_name, fid in filter_map.items():
         winter_filter = Filter(fid=fid, filtername=filter_name)
         if not winter_filter.exists():
-            winter_filter.insert_entry()
+            winter_filter.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/winter/models/_img_type.py b/mirar/pipelines/winter/models/_img_type.py
index eec483fa2..1f7079ba5 100644
--- a/mirar/pipelines/winter/models/_img_type.py
+++ b/mirar/pipelines/winter/models/_img_type.py
@@ -59,4 +59,4 @@ def populate_itid():
     for ind, imgtype in enumerate(ALL_ITID):
         itid = ImgType(itid=ind + 1, imgtype=imgtype)
         if not itid.exists():
-            itid.insert_entry()
+            itid.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/winter/models/_programs.py b/mirar/pipelines/winter/models/_programs.py
index 70118263b..baf7b08bd 100644
--- a/mirar/pipelines/winter/models/_programs.py
+++ b/mirar/pipelines/winter/models/_programs.py
@@ -134,4 +134,4 @@ def populate_programs():
     """
     if not default_program.exists():
-        default_program.insert_entry()
+        default_program.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/winter/models/_sources.py b/mirar/pipelines/winter/models/_sources.py
new file mode 100644
index 000000000..f042718d7
--- /dev/null
+++ b/mirar/pipelines/winter/models/_sources.py
@@ -0,0 +1,69 @@
+"""
+Models for the 'sources' table
+"""
+
+import logging
+from typing import ClassVar, List
+
+from pydantic import Field
+from sqlalchemy import VARCHAR, BigInteger, Boolean, Column, Float, Integer, Sequence
+from sqlalchemy.orm import Mapped, relationship
+
+from mirar.database.base_model import BaseDB, dec_field, ra_field
+from mirar.pipelines.winter.models.base_model import WinterBase
+
+logger = logging.getLogger(__name__)
+
+SOURCE_PREFIX = "WNTR"
+NAME_START = "aaaaa"
+
+MIN_NAME_LENGTH = len(SOURCE_PREFIX) + len(NAME_START) + 2
+
+
+class SourcesTable(WinterBase):  # pylint: disable=too-few-public-methods
+    """
+    Sources table in database
+    """
+
+    __tablename__ = "sources"
+    __table_args__ = {"extend_existing": True}
+
+    # extra avro_path, diff img foreign key etc
+
+    # Core fields
+    sourceid = Column(
+        BigInteger,
+        Sequence(name="sources_candid_seq", start=1, increment=1),
+        unique=True,
+        autoincrement=True,
+        primary_key=True,
+    )
+    objectid = Column(VARCHAR(40), nullable=False, unique=True)
+    deprecated = Column(Boolean, nullable=False, default=False)
+
+    # Positional properties
+
+    average_ra = Column(Float)
+    average_dec = Column(Float)
+    ra_column_name = "average_ra"
+    dec_column_name = "average_dec"
+
+    ndet = Column(Integer, nullable=False, default=1)
+
+    candidates: Mapped[List["CandidatesTable"]] = relationship(back_populates="source")
+
+
+class Source(BaseDB):
+    """
+    A pydantic model for a source database entry
+    """
+
+    sql_model: ClassVar = SourcesTable
+
+    objectid: str = Field(min_length=MIN_NAME_LENGTH)
+    deprecated: bool = Field(default=False)
+
+    average_ra: float = ra_field
+    average_dec: float = dec_field
+
+    ndet: int = Field(ge=1, description="Number of detections", default=1)
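MIN_NAME_LENGTH works out to len("WNTR") + len("aaaaa") + 2 = 11, consistent with names of the form prefix, two extra characters (presumably a two-digit year), then a five-letter suffix, e.g. WNTR24aaaaa; the exact format is set by the namer, which this patch does not show. A minimal construction sketch against the new pydantic model (values illustrative):

# "WNTR24aaaaa" is 11 characters, satisfying the objectid min_length check.
source = Source(objectid="WNTR24aaaaa", average_ra=210.02, average_dec=45.01, ndet=1)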
diff --git a/mirar/pipelines/winter/models/_subdets.py b/mirar/pipelines/winter/models/_subdets.py
index 6af789067..19acd2823 100644
--- a/mirar/pipelines/winter/models/_subdets.py
+++ b/mirar/pipelines/winter/models/_subdets.py
@@ -60,4 +60,4 @@ def populate_subdets():
             ny=row["ny"],
             nytot=row["nytot"],
         )
-        new.insert_entry()
+        new.insert_entry(duplicate_protocol="ignore")
diff --git a/mirar/pipelines/wirc/blocks.py b/mirar/pipelines/wirc/blocks.py
index c6825ccd9..eb4af9105 100644
--- a/mirar/pipelines/wirc/blocks.py
+++ b/mirar/pipelines/wirc/blocks.py
@@ -55,8 +55,8 @@
 from mirar.processors.dark import DarkCalibrator
 from mirar.processors.database.database_inserter import DatabaseSourceInserter
 from mirar.processors.database.database_selector import (
-    CrossmatchSourceWithDatabase,
     DatabaseHistorySelector,
+    SpatialCrossmatchSourceWithDatabase,
 )
 from mirar.processors.flat import SkyFlatCalibrator
 from mirar.processors.mask import (
@@ -252,7 +252,7 @@
     XMatch(catalog=TMASS(num_sources=3, search_radius_arcmin=0.5)),
     XMatch(catalog=PS1(num_sources=3, search_radius_arcmin=0.5)),
     SourceWriter(output_dir_name="kowalski"),
-    CrossmatchSourceWithDatabase(
+    SpatialCrossmatchSourceWithDatabase(
         db_table=Candidate,
         db_output_columns=[SOURCE_NAME_KEY],
         crossmatch_radius_arcsec=2.0,
diff --git a/mirar/processors/base_processor.py b/mirar/processors/base_processor.py
index 7a557392d..fd382cd1d 100644
--- a/mirar/processors/base_processor.py
+++ b/mirar/processors/base_processor.py
@@ -555,4 +555,5 @@ def generate_super_dict(metadata: dict, source_row: pd.Series) -> dict:
     super_dict.update(
         {key.lower(): val for key, val in source_row.to_dict().items()}
     )
+    super_dict.update({key.upper(): val for key, val in super_dict.items()})
     return super_dict
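The one-line generate_super_dict change duplicates every key in upper case, so downstream lookups succeed regardless of casing (that motivation is my reading; the patch itself does not say). A toy reproduction of the function's logic:

import pandas as pd

metadata = {"FILTER": "J"}
source_row = pd.Series({"Ra": 210.02})

super_dict = {key.lower(): val for key, val in metadata.items()}
super_dict.update({key.lower(): val for key, val in source_row.to_dict().items()})
super_dict.update({key.upper(): val for key, val in super_dict.items()})
print(super_dict)  # {'filter': 'J', 'ra': 210.02, 'FILTER': 'J', 'RA': 210.02}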
diff --git a/mirar/processors/database/__init__.py b/mirar/processors/database/__init__.py
index 09221f111..1507475e2 100644
--- a/mirar/processors/database/__init__.py
+++ b/mirar/processors/database/__init__.py
@@ -8,7 +8,9 @@
 )
 from mirar.processors.database.database_selector import (
     BaseDatabaseSourceSelector,
-    CrossmatchSourceWithDatabase,
     DatabaseHistorySelector,
+    SelectSourcesWithMetadata,
+    SingleSpatialCrossmatchSource,
+    SpatialCrossmatchSourceWithDatabase,
 )
 from mirar.processors.database.database_updater import ImageDatabaseUpdater
diff --git a/mirar/processors/database/base_database_processor.py b/mirar/processors/database/base_database_processor.py
index bf2872d80..805792ae3 100644
--- a/mirar/processors/database/base_database_processor.py
+++ b/mirar/processors/database/base_database_processor.py
@@ -7,7 +7,6 @@
 from typing import Type
 
 from mirar.database.base_model import BaseDB
-from mirar.database.constants import POSTGRES_DUPLICATE_PROTOCOLS
 from mirar.database.user import PostgresAdmin, PostgresUser
 from mirar.processors.base_processor import BaseProcessor
 
@@ -24,14 +23,10 @@ def __init__(
         db_table: Type[BaseDB],
         pg_user: PostgresUser = PostgresUser(),
         pg_admin: PostgresAdmin = PostgresAdmin(),
-        duplicate_protocol: str = "fail",
     ):
         super().__init__()
         self.db_table = db_table
         self.db_name = self.db_table.sql_model.db_name
-        self.duplicate_protocol = duplicate_protocol
         self.pg_user = pg_user
         self._pg_admin = pg_admin
-
-        assert self.duplicate_protocol in POSTGRES_DUPLICATE_PROTOCOLS
diff --git a/mirar/processors/database/database_inserter.py b/mirar/processors/database/database_inserter.py
index c7405c49d..0f062c07b 100644
--- a/mirar/processors/database/database_inserter.py
+++ b/mirar/processors/database/database_inserter.py
@@ -9,6 +9,7 @@
 import pandas as pd
 
 from mirar.data import Image, ImageBatch, SourceBatch
+from mirar.database.constants import POSTGRES_DUPLICATE_PROTOCOLS
 from mirar.errors.exceptions import BaseProcessorError
 from mirar.processors.base_processor import (
     BaseImageProcessor,
@@ -35,6 +36,14 @@ class BaseDatabaseInserter(BaseDatabaseProcessor, ABC):
     base_key = "dbinserter"
     max_n_cpu = 1
 
+    def __init__(self, *args, duplicate_protocol: str = "fail", **kwargs):
+        super().__init__(*args, **kwargs)
+        self.duplicate_protocol = duplicate_protocol
+
+        assert (
+            self.duplicate_protocol in POSTGRES_DUPLICATE_PROTOCOLS
+        ), f"Invalid duplicate protocol, must be one of {POSTGRES_DUPLICATE_PROTOCOLS}"
+
     def __str__(self):
         return (
             f"Processor to save "
@@ -54,7 +63,7 @@ def _apply_to_images(self, batch: ImageBatch) -> ImageBatch:
             val_dict = self.generate_value_dict(image)
 
             new = self.db_table(**val_dict)
-            res = new.insert_entry()
+            res = new.insert_entry(duplicate_protocol=self.duplicate_protocol)
 
             assert len(res) == 1
 
@@ -88,7 +97,7 @@ def _apply_to_sources(self, batch: SourceBatch) -> SourceBatch:
             super_dict = self.generate_super_dict(metadata, source_row)
 
             new = self.db_table(**super_dict)
-            res = new.insert_entry()
+            res = new.insert_entry(duplicate_protocol=self.duplicate_protocol)
 
             assert len(res) == 1
 
@@ -138,7 +147,7 @@ def _apply_to_images(self, batch: ImageBatch) -> ImageBatch:
             val_dict = {key.lower(): image[key] for key in image.keys()}
 
             new = self.db_table(**val_dict)
-            res = new.insert_entry()
+            res = new.insert_entry(duplicate_protocol=self.duplicate_protocol)
 
             assert len(res) == 1
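Moving duplicate_protocol from BaseDatabaseProcessor to BaseDatabaseInserter narrows it to the only processors that actually insert, and validates it eagerly against POSTGRES_DUPLICATE_PROTOCOLS. A configuration sketch matching the WINTER pipeline stages above (Source and Candidate are the models from this patch):

# Each pipeline stage picks its conflict policy at construction time; an
# invalid string now fails fast in __init__ rather than at insert time.
insert_new_sources = DatabaseSourceInserter(db_table=Source, duplicate_protocol="ignore")
update_source_averages = DatabaseSourceInserter(db_table=Source, duplicate_protocol="replace")
insert_candidates = DatabaseSourceInserter(db_table=Candidate, duplicate_protocol="fail")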
diff --git a/mirar/processors/database/database_selector.py b/mirar/processors/database/database_selector.py
index d4f3a9f2d..1fadc759d 100644
--- a/mirar/processors/database/database_selector.py
+++ b/mirar/processors/database/database_selector.py
@@ -12,7 +12,7 @@
 from mirar.data import DataBlock, Image, ImageBatch, SourceBatch
 from mirar.database.constraints import DBQueryConstraints
 from mirar.database.transactions import select_from_table
-from mirar.paths import SOURCE_HISTORY_KEY, SOURCE_XMATCH_KEY
+from mirar.paths import SOURCE_HISTORY_KEY
 from mirar.processors.base_processor import BaseImageProcessor, BaseSourceProcessor
 from mirar.processors.database.base_database_processor import BaseDatabaseProcessor
 
@@ -101,7 +101,7 @@ def _apply_to_images(
         return batch
 
 
-class CrossmatchDatabaseWithHeader(BaseImageDatabaseSelector):
+class BaseValuesCrossmatch(BaseDatabaseSelector, ABC):
     """Processor to crossmatch to a database"""
 
     def __init__(self, db_query_columns: str | list[str], *args, **kwargs):
@@ -136,6 +136,12 @@ def get_constraints(self, data: dict) -> DBQueryConstraints:
         return query_constraints
 
 
+class CrossmatchDatabaseWithHeader(BaseImageDatabaseSelector, BaseValuesCrossmatch):
+    """
+    Processor to crossmatch to a database using keys
+    """
+
+
 class BaseDatabaseSourceSelector(BaseDatabaseSelector, BaseSourceProcessor, ABC):
     """
     Base Class for dataframe DB importers
@@ -194,14 +200,79 @@ def _apply_to_sources(
             )
             results.append(res)
+
         new_table = self.update_dataframe(candidate_table, results)
         source_table.set_data(new_table)
 
         return batch
 
 
-class CrossmatchSourceWithDatabase(BaseDatabaseSourceSelector, BaseSourceProcessor):
+class DatabaseSingleMatchSelector(BaseDatabaseSourceSelector, ABC):
     """
-    Processor to crossmatch to sources in a database
+    Processor to import a single match from a database
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, max_num_results=1, **kwargs)
+
+    def update_dataframe(
+        self, candidate_table: pd.DataFrame, results: list[pd.DataFrame]
+    ) -> pd.DataFrame:
+        """
+        Update a dataframe with db results
+
+        :param candidate_table: pandas table
+        :param results: results from db query
+        :return: updated dataframe
+        """
+        assert len(results) == len(candidate_table)
+
+        new_cols = []
+        for res in results:
+            if len(res) > 0:
+                assert len(res) == 1
+                new_row = {x: res.iloc[0][x] for x in self.db_output_columns}
+
+            else:
+                new_row = {x: None for x in self.db_output_columns}
+
+            new_cols.append(new_row)
+
+        candidate_table = candidate_table.join(pd.DataFrame(new_cols))
+
+        return candidate_table
+
+
+class DatabaseMultimatchSelector(BaseDatabaseSourceSelector, ABC):
+    """
+    Processor to import multiple matches from a database
+    """
+
+    def __init__(self, *args, base_output_column: str = SOURCE_HISTORY_KEY, **kwargs):
+        self.base_output_column = base_output_column
+        super().__init__(*args, **kwargs)
+
+    def update_dataframe(
+        self,
+        candidate_table: pd.DataFrame,
+        results: list[pd.DataFrame],
+    ) -> pd.DataFrame:
+        """
+        Update a pandas dataframe with the number of matches
+
+        :param candidate_table: Pandas dataframe
+        :param results: db query results
+        :return: updated pandas dataframe
+        """
+        assert len(results) == len(candidate_table)
+        candidate_table[self.base_output_column] = [
+            x.to_dict(orient="records") for x in results
+        ]
+        return candidate_table
+
+
+class BaseSpatialCrossmatchSource(BaseDatabaseSourceSelector, ABC):
+    """
+    Processor to crossmatch to sources in a database using spatial search
     """
 
     def __init__(
@@ -212,7 +283,6 @@ def __init__(
         order_field_name: Optional[str] = None,
         order_ascending: bool = False,
         query_dist: bool = False,
-        output_df_colname: str = SOURCE_XMATCH_KEY,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -222,40 +292,18 @@ def __init__(
         self.order_field_name = order_field_name
         self.order_ascending = order_ascending
         self.query_dist = query_dist
-        self.output_df_colname = output_df_colname
-
-    def update_dataframe(
-        self,
-        candidate_table: pd.DataFrame,
-        results: list[pd.DataFrame],
-    ) -> pd.DataFrame:
-        """
-        Update a pandas dataframe with the number of previous detections
-
-        :param candidate_table: Pandas dataframe
-        :param results: db query results
-        :return: updated pandas dataframe
-        """
-        assert len(results) == len(candidate_table)
-        candidate_table[self.output_df_colname] = [
-            x.to_dict(orient="records") for x in results
-        ]
-
-        return candidate_table
-
-    def get_source_crossmatch_constraints(
-        self, source: pd.Series
-    ) -> DBQueryConstraints:
+    def get_source_crossmatch_constraints(self, data: dict) -> DBQueryConstraints:
         """
         Apply constraints to a single source, using q3c
 
-        :param source: Source
+        :param data: Dictionary containing source data
         :return: DBQueryConstraints
         """
         query_constraints = DBQueryConstraints()
         query_constraints.add_q3c_constraint(
-            ra=source["ra"],
-            dec=source["dec"],
+            ra=data["ra"],
+            dec=data["dec"],
             ra_field_name=self.ra_field_name,
             dec_field_name=self.dec_field_name,
             crossmatch_radius_arcsec=self.xmatch_radius_arcsec,
@@ -263,11 +311,33 @@ def get_source_crossmatch_constraints(
         )
 
         return query_constraints
 
-    def get_constraints(self, source: pd.Series) -> DBQueryConstraints:
-        return self.get_source_crossmatch_constraints(source)
+    def get_constraints(self, data: dict) -> DBQueryConstraints:
+        return self.get_source_crossmatch_constraints(data)
+
+
+class SingleSpatialCrossmatchSource(
+    BaseSpatialCrossmatchSource, DatabaseSingleMatchSelector
+):
+    """
+    Processor to import a single source from a database using spatial crossmatch
+    """
+
+
+class SpatialCrossmatchSourceWithDatabase(
+    BaseSpatialCrossmatchSource, DatabaseMultimatchSelector
+):
+    """
+    Processor to import multiple sources from a database using spatial crossmatch
+    """
+
+
+class SelectSourcesWithMetadata(DatabaseMultimatchSelector, BaseValuesCrossmatch):
+    """
+    Processor to import sources from a database using metadata values
+    """
 
 
-class DatabaseHistorySelector(CrossmatchSourceWithDatabase):
+class DatabaseHistorySelector(SpatialCrossmatchSourceWithDatabase):
     """
     Processor to import previous detections of a source from a database
     """
@@ -284,16 +354,16 @@ def __init__(
         self.output_df_colname = SOURCE_HISTORY_KEY
         logger.info(f"Update db is {self.update_dataframe}")
 
-    def get_constraints(self, source: pd.Series) -> DBQueryConstraints:
-        query_constraints = self.get_source_crossmatch_constraints(source)
+    def get_constraints(self, data: dict) -> DBQueryConstraints:
+        query_constraints = self.get_source_crossmatch_constraints(data)
         query_constraints.add_constraint(
             column=self.time_field_name,
             comparison_type="<",
-            accepted_values=source[self.time_field_name],
+            accepted_values=data[self.time_field_name],
         )
         query_constraints.add_constraint(
             column=self.time_field_name,
             comparison_type=">=",
-            accepted_values=source[self.time_field_name] - self.history_duration_days,
+            accepted_values=data[self.time_field_name] - self.history_duration_days,
         )
 
         return query_constraints
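The refactor above splits the old CrossmatchSourceWithDatabase along two axes: how constraints are built (BaseValuesCrossmatch on column values vs BaseSpatialCrossmatchSource on q3c cones) and how results are folded back into the source table (DatabaseSingleMatchSelector vs DatabaseMultimatchSelector); the concrete selectors are just mixin combinations. The two folding strategies, contrasted on toy data ("source_history" stands in for SOURCE_HISTORY_KEY):

import pandas as pd

candidates = pd.DataFrame({"ra": [210.02, 180.00]})
results = [pd.DataFrame({"sourceid": [7]}), pd.DataFrame({"sourceid": []})]

# DatabaseSingleMatchSelector-style: one column per output key, None when unmatched
single = candidates.join(
    pd.DataFrame(
        [
            {"sourceid": r.iloc[0]["sourceid"]} if len(r) else {"sourceid": None}
            for r in results
        ]
    )
)

# DatabaseMultimatchSelector-style: all matches packed into one list-of-dicts column
multi = candidates.copy()
multi["source_history"] = [r.to_dict(orient="records") for r in results]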
diff --git a/mirar/processors/sources/namer.py b/mirar/processors/sources/namer.py
index 75d5d66e8..33bcda327 100644
--- a/mirar/processors/sources/namer.py
+++ b/mirar/processors/sources/namer.py
@@ -4,15 +4,13 @@
 
 import logging
 
-import numpy as np
+import pandas as pd
 from astropy.time import Time
 from sqlalchemy import select, text
 
 from mirar.data import SourceBatch
 from mirar.database.transactions.select import run_select
-from mirar.paths import SOURCE_NAME_KEY, SOURCE_XMATCH_KEY, TIME_KEY
-from mirar.processors.base_processor import PrerequisiteError
-from mirar.processors.database import CrossmatchSourceWithDatabase
+from mirar.paths import SOURCE_NAME_KEY, TIME_KEY
 from mirar.processors.database.database_selector import BaseDatabaseSourceSelector
 
 logger = logging.getLogger(__name__)
@@ -119,40 +117,27 @@ def _apply_to_sources(
         for source_table in batch:
             sources = source_table.get_data()
 
-            assert (
-                SOURCE_XMATCH_KEY in sources.columns
-            ), "No candidate cross-match in source table"
-
             names = []
             detection_time = Time(source_table[TIME_KEY])
             for ind, source in sources.iterrows():
-                if len(source[SOURCE_XMATCH_KEY]) > 0:
-                    source_name = source[SOURCE_XMATCH_KEY][0][self.db_name_field]
-                else:
+
+                source_name = None
+
+                if SOURCE_NAME_KEY in source:
+                    source_name = source[SOURCE_NAME_KEY]
+
+                if pd.isnull(source_name):
                     source_name = self.get_next_name(
                         detection_time, last_name=self.lastname
                     )
                     self.lastname = source_name
-                logger.debug(f"Assigning name: {source_name} to source # {ind}.")
+                    logger.debug(f"Assigning name: {source_name} to source # {ind}.")
+                else:
+                    logger.debug(f"Source # {ind} already has a name: {source_name}.")
 
                 names.append(source_name)
 
             sources[self.db_name_field] = names
             source_table.set_data(sources)
 
         return batch
-
-    def check_prerequisites(
-        self,
-    ):
-        check = np.sum(
-            [isinstance(x, CrossmatchSourceWithDatabase) for x in self.preceding_steps]
-        )
-        if check < 1:
-            err = (
-                f"{self.__module__} requires {CrossmatchSourceWithDatabase} "
-                f"as a prerequisite. "
-                f"However, the following steps were found: {self.preceding_steps}."
-            )
-            logger.error(err)
-            raise PrerequisiteError(err)
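With naming now gated on whether the crossmatch already filled in SOURCE_NAME_KEY, the numpy-based check_prerequisites guard becomes unnecessary and is dropped. The new gate in miniature ("objectid" stands in for the configured db_name_field, and the literal new name is a stand-in for self.get_next_name):

import pandas as pd

sources = pd.DataFrame({"objectid": ["WNTR24aaaaa", None]})
names = []
for _, source in sources.iterrows():
    name = source["objectid"] if "objectid" in source else None
    if pd.isnull(name):
        name = "WNTR24aaaab"  # stand-in for self.get_next_name(detection_time, ...)
    names.append(name)
sources["objectid"] = names  # -> ["WNTR24aaaaa", "WNTR24aaaab"]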