Skip to content

Commit

Permalink
Added seeds to lineage.
Browse files Browse the repository at this point in the history
  • Loading branch information
elongl committed Sep 5, 2024
1 parent 917432c commit 103ea5e
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 30 deletions.
2 changes: 2 additions & 0 deletions elementary/monitor/api/groups/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from elementary.monitor.api.models.schema import (
NormalizedExposureSchema,
NormalizedModelSchema,
NormalizedSeedSchema,
NormalizedSourceSchema,
)
from elementary.monitor.fetchers.tests.schema import NormalizedTestSchema
Expand All @@ -28,6 +29,7 @@
NormalizedSourceSchema,
NormalizedExposureSchema,
NormalizedTestSchema,
NormalizedSeedSchema,
]


Expand Down
13 changes: 1 addition & 12 deletions elementary/monitor/api/lineage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from elementary.utils.pydantic_shim import BaseModel, validator

NodeUniqueIdType = str
NodeType = Literal["model", "source", "exposure"]
NodeType = Literal["seed", "model", "source", "exposure"]
NodeSubType = Literal["table", "view"]


Expand Down Expand Up @@ -51,15 +51,4 @@ class NodeDependsOnNodesSchema(BaseModel):
@validator("depends_on_nodes", pre=True, always=True)
def set_depends_on_nodes(cls, depends_on_nodes):
formatted_depends_on = depends_on_nodes or []
formatted_depends_on = [
cls._format_node_id(node_id) for node_id in formatted_depends_on
]
return [node_id for node_id in formatted_depends_on if node_id]

@classmethod
def _format_node_id(cls, node_id: str):
# Currently we don't save seeds in our artifacts.
# We remove seeds from the lineage graph (as long as we don't support them).
if re.search(_SEED_PATH_PATTERN, node_id):
return None
return node_id
45 changes: 34 additions & 11 deletions elementary/monitor/api/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import statistics
from collections import defaultdict
from typing import Dict, List, Optional, Set, Union, overload
from typing import Dict, List, Optional, Set, Union, cast, overload

from elementary.clients.api.api_client import APIClient
from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
Expand All @@ -13,6 +13,7 @@
ModelRunsWithTotalsSchema,
NormalizedExposureSchema,
NormalizedModelSchema,
NormalizedSeedSchema,
NormalizedSourceSchema,
TotalsModelRunsSchema,
)
Expand All @@ -22,14 +23,19 @@
from elementary.monitor.fetchers.models.schema import (
ModelRunSchema as FetcherModelRunSchema,
)
from elementary.monitor.fetchers.models.schema import ModelSchema, SourceSchema
from elementary.monitor.fetchers.models.schema import (
ModelSchema,
SeedSchema,
SourceSchema,
)
from elementary.utils.log import get_logger

logger = get_logger(__name__)


class ModelsAPI(APIClient):
_ARTIFACT_TYPE_DIR_MAP = {
SeedSchema: "seeds",
SourceSchema: "sources",
ModelSchema: "models",
ExposureSchema: "exposures",
Expand Down Expand Up @@ -117,6 +123,16 @@ def _get_model_runs_totals(
success_runs = len([run for run in runs if run.status == "success"])
return TotalsModelRunsSchema(errors=error_runs, success=success_runs)

def get_seeds(self) -> Dict[str, NormalizedSeedSchema]:
seed_results = self.models_fetcher.get_seeds()
seeds = dict()
if seed_results:
for seed_result in seed_results:
normalized_seed = self._normalize_dbt_artifact_dict(seed_result)
seed_unique_id = cast(str, normalized_seed.unique_id)
seeds[seed_unique_id] = normalized_seed
return seeds

def get_models(
self, exclude_elementary_models: bool = False
) -> Dict[str, NormalizedModelSchema]:
Expand All @@ -127,12 +143,7 @@ def get_models(
if models_results:
for model_result in models_results:
normalized_model = self._normalize_dbt_artifact_dict(model_result)

model_unique_id = normalized_model.unique_id
if model_unique_id is None:
# Shouldn't happen, but handling this case for mypy
continue

model_unique_id = cast(str, normalized_model.unique_id)
models[model_unique_id] = normalized_model
return models

Expand Down Expand Up @@ -222,6 +233,12 @@ def _exposure_has_upstream_node(
for dep in exposure.depends_on_nodes
)

@overload
def _normalize_dbt_artifact_dict(
self, artifact: SeedSchema
) -> NormalizedSeedSchema:
...

@overload
def _normalize_dbt_artifact_dict(
self, artifact: ModelSchema
Expand All @@ -241,9 +258,15 @@ def _normalize_dbt_artifact_dict(
...

def _normalize_dbt_artifact_dict(
self, artifact: Union[ModelSchema, ExposureSchema, SourceSchema]
) -> Union[NormalizedModelSchema, NormalizedExposureSchema, NormalizedSourceSchema]:
self, artifact: Union[SeedSchema, ModelSchema, ExposureSchema, SourceSchema]
) -> Union[
NormalizedSeedSchema,
NormalizedModelSchema,
NormalizedExposureSchema,
NormalizedSourceSchema,
]:
schema_to_normalized_schema_map = {
SeedSchema: NormalizedSeedSchema,
ExposureSchema: NormalizedExposureSchema,
ModelSchema: NormalizedModelSchema,
SourceSchema: NormalizedSourceSchema,
Expand Down Expand Up @@ -285,7 +308,7 @@ def _normalize_artifact_path(cls, artifact: ArtifactSchemaType, fqn: str) -> str
@classmethod
def _fqn(
cls,
artifact: Union[ModelSchema, ExposureSchema, SourceSchema],
artifact: Union[ModelSchema, ExposureSchema, SourceSchema, SeedSchema],
) -> str:
if isinstance(artifact, ExposureSchema):
path = (artifact.meta or {}).get("path")
Expand Down
6 changes: 6 additions & 0 deletions elementary/monitor/api/models/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from elementary.monitor.fetchers.models.schema import (
ExposureSchema,
ModelSchema,
SeedSchema,
SourceSchema,
)
from elementary.utils.pydantic_shim import BaseModel, Field, validator
Expand Down Expand Up @@ -35,6 +36,11 @@ def format_normalized_full_path_sep(cls, normalized_full_path: str) -> str:
return posixpath.sep.join(normalized_full_path.split(os.path.sep))


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedSeedSchema(NormalizedArtifactSchema, SeedSchema):
artifact_type: str = Field("seed", const=True) # type: ignore # noqa


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedModelSchema(NormalizedArtifactSchema, ModelSchema):
artifact_type: str = Field("model", const=True) # type: ignore # noqa
Expand Down
19 changes: 15 additions & 4 deletions elementary/monitor/api/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ModelRunsSchema,
NormalizedExposureSchema,
NormalizedModelSchema,
NormalizedSeedSchema,
NormalizedSourceSchema,
)
from elementary.monitor.api.report.schema import ReportDataEnvSchema, ReportDataSchema
Expand Down Expand Up @@ -41,11 +42,12 @@ def _get_groups(
models: Iterable[NormalizedModelSchema],
sources: Iterable[NormalizedSourceSchema],
exposures: Iterable[NormalizedExposureSchema],
seeds: Iterable[NormalizedSeedSchema],
singular_tests: Iterable[NormalizedTestSchema],
) -> GroupsSchema:
groups_api = GroupsAPI(self.dbt_runner)
return groups_api.get_groups(
artifacts=[*models, *sources, *exposures, *singular_tests]
artifacts=[*models, *sources, *exposures, *seeds, *singular_tests]
)

def get_report_data(
Expand Down Expand Up @@ -78,6 +80,8 @@ def get_report_data(
invocations_api = InvocationsAPI(dbt_runner=self.dbt_runner)

lineage_node_ids: List[str] = []
seeds = models_api.get_seeds()
lineage_node_ids.extend(seeds.keys())
models = models_api.get_models(exclude_elementary_models)
lineage_node_ids.extend(models.keys())
sources = models_api.get_sources()
Expand All @@ -87,7 +91,11 @@ def get_report_data(
singular_tests = tests_api.get_singular_tests()

groups = self._get_groups(
models.values(), sources.values(), exposures.values(), singular_tests
models.values(),
sources.values(),
exposures.values(),
seeds.values(),
singular_tests,
)

models_runs = models_api.get_models_runs(
Expand Down Expand Up @@ -133,7 +141,9 @@ def get_report_data(
)

serializable_groups = groups.dict()
serializable_models = self._serialize_models(models, sources, exposures)
serializable_models = self._serialize_models(
models, sources, exposures, seeds
)
serializable_model_runs = self._serialize_models_runs(models_runs.runs)
serializable_model_runs_totals = models_runs.dict(include={"totals"})[
"totals"
Expand Down Expand Up @@ -191,8 +201,9 @@ def _serialize_models(
models: Dict[str, NormalizedModelSchema],
sources: Dict[str, NormalizedSourceSchema],
exposures: Dict[str, NormalizedExposureSchema],
seeds: Dict[str, NormalizedSeedSchema],
) -> Dict[str, dict]:
nodes = dict(**models, **sources, **exposures)
nodes = dict(**models, **sources, **exposures, **seeds)
serializable_nodes = dict()
for key in nodes.keys():
serializable_nodes[key] = dict(nodes[key])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
{% endif %}

{% set models_depends_on_nodes_query %}
select
unique_id,
null as depends_on_nodes,
null as materialization,
'seed' as type
from {{ ref('elementary', 'dbt_seeds') }}
union all
select
unique_id,
depends_on_nodes,
Expand Down
27 changes: 27 additions & 0 deletions elementary/monitor/dbt_project/macros/get_seeds.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{% macro get_seeds(exclude_elementary=false) %}
{% set dbt_seeds_relation = ref('elementary', 'dbt_seeds') %}
{%- if elementary.relation_exists(dbt_seeds_relation) -%}

{% set get_seeds_query %}
with dbt_artifacts_seeds as (
select
name,
unique_id,
database_name,
schema_name,
case when alias is not null then alias
else name end as table_name,
owner as owners,
tags,
package_name,
description,
original_path as full_path
from {{ dbt_seeds_relation }}
)

select * from dbt_artifacts_seeds
{% endset %}
{% set seeds_agate = run_query(get_seeds_query) %}
{% do return(elementary.agate_to_dicts(seeds_agate)) %}
{%- endif -%}
{% endmacro %}
17 changes: 14 additions & 3 deletions elementary/monitor/fetchers/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ModelRunSchema,
ModelSchema,
ModelTestCoverage,
SeedSchema,
SourceSchema,
)
from elementary.utils.log import get_logger
Expand Down Expand Up @@ -34,6 +35,14 @@ def get_models_runs(
model_runs = [ModelRunSchema(**model_run) for model_run in model_run_dicts]
return model_runs

def get_seeds(self) -> List[SeedSchema]:
run_operation_response = self.dbt_runner.run_operation(
macro_name="elementary_cli.get_seeds",
)
seeds = json.loads(run_operation_response[0]) if run_operation_response else []
seeds = [SeedSchema(**seed) for seed in seeds]
return seeds

def get_models(self, exclude_elementary_models: bool = False) -> List[ModelSchema]:
run_operation_response = self.dbt_runner.run_operation(
macro_name="elementary_cli.get_models",
Expand Down Expand Up @@ -63,9 +72,11 @@ def get_exposures(self) -> List[ExposureSchema]:
exposures = [
{
**exposure,
"raw_queries": json.loads(exposure["raw_queries"])
if exposure.get("raw_queries")
else None,
"raw_queries": (
json.loads(exposure["raw_queries"])
if exposure.get("raw_queries")
else None
),
}
for exposure in exposures
]
Expand Down
6 changes: 6 additions & 0 deletions elementary/monitor/fetchers/models/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def load_meta(cls, meta):
ArtifactSchemaType = TypeVar("ArtifactSchemaType", bound=ArtifactSchema)


class SeedSchema(ArtifactSchema):
database_name: Optional[str] = None
schema_name: str
table_name: str


class ModelSchema(ArtifactSchema):
database_name: Optional[str] = None
schema_name: str
Expand Down

0 comments on commit 103ea5e

Please sign in to comment.