Skip to content

Commit

Permalink
Merge branch 'feature/cardinality_callback_nn' into 'main'
Browse files Browse the repository at this point in the history
Add cardinality callback for data nn interfaces

See merge request ai-lab-pmo/mltools/recsys/RePlay!140
  • Loading branch information
OnlyDeniko committed Nov 30, 2023
2 parents a7e1af6 + dcfbd1b commit 96b0055
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
13 changes: 13 additions & 0 deletions replay/data/nn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Set,
Union,
ValuesView,
Callable
)

import torch
Expand Down Expand Up @@ -201,11 +202,23 @@ def cardinality(self) -> Optional[int]:
raise RuntimeError(
f"Can not get cardinality because feature type of {self.name} column is not categorical."
)
if hasattr(self, "_cardinality_callback") and self._cardinality is None:
self._set_cardinality(self._cardinality_callback(self._name))
return self._cardinality

# pylint: disable=attribute-defined-outside-init
def _set_cardinality_callback(self, callback: Callable) -> None:
self._cardinality_callback = callback

def _set_cardinality(self, cardinality: int) -> None:
    # Cache the computed number of distinct categorical values so the
    # cardinality callback is not invoked again on subsequent accesses.
    self._cardinality = cardinality

def reset_cardinality(self) -> None:
    """
    Reset cardinality of the feature to None.

    If a cardinality callback has been attached, the next access to
    ``cardinality`` recomputes the value through that callback.
    """
    self._cardinality = None

@property
def tensor_dim(self) -> Optional[int]:
"""
Expand Down
11 changes: 11 additions & 0 deletions replay/data/nn/sequential_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from pandas import DataFrame as PandasDataFrame

from replay.data.schema import FeatureType
from replay.data.nn.schema import TensorSchema


Expand Down Expand Up @@ -87,9 +88,19 @@ def __init__(

self._sequences = sequences

for feature in tensor_schema.all_features:
if feature.feature_type == FeatureType.CATEGORICAL:
# pylint: disable=protected-access
feature._set_cardinality_callback(self.cardinality_callback)

def __len__(self) -> int:
    # Number of rows in the underlying sequences DataFrame
    # (one row per query id in its index).
    return len(self._sequences)

def cardinality_callback(self, column: str) -> int:
    """
    Count the number of distinct values stored for ``column``.

    For the query-id column the distinct values live in the DataFrame
    index; for every other column each cell holds a sequence, so the
    distinct elements are collected across all sequences.

    :param column: name of the column to count distinct values for.
    :returns: number of unique values in the column.
    """
    if column == self._query_id_column:
        return self._sequences.index.nunique()
    unique_values = set()
    for sequence in self._sequences[column]:
        unique_values.update(sequence)
    return len(unique_values)

def get_query_id(self, index: int) -> int:
    # Map a positional row index to the query id stored in the
    # DataFrame index of the sequences.
    return self._sequences.index[index]

Expand Down
26 changes: 24 additions & 2 deletions tests/data/nn/test_sequential_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

pytest.importorskip("torch")

from replay.data import FeatureHint
from replay.data import FeatureHint, FeatureType
from replay.utils import TORCH_AVAILABLE

if TORCH_AVAILABLE:
from replay.data.nn import PandasSequentialDataset
from replay.data.nn import PandasSequentialDataset, TensorSchema, TensorFeatureInfo
from replay.experimental.nn.data.schema_builder import TensorSchemaBuilder


Expand All @@ -19,6 +19,28 @@ def test_can_create_sequential_dataset_with_valid_schema(sequential_info):
PandasSequentialDataset(**sequential_info)


@pytest.mark.torch
def test_callback_for_cardinality(sequential_info):
    """Cardinalities are unknown until the dataset attaches its callback."""
    schema = TensorSchema(
        [
            TensorFeatureInfo("user_id", feature_type=FeatureType.CATEGORICAL, is_seq=True),
            TensorFeatureInfo("item_id", feature_type=FeatureType.CATEGORICAL, is_seq=True),
            TensorFeatureInfo("some_user_feature", feature_type=FeatureType.CATEGORICAL),
            TensorFeatureInfo("some_item_feature", feature_type=FeatureType.CATEGORICAL, is_seq=True),
        ]
    )

    # No callback attached yet, so nothing is known about the features.
    for feature in schema.all_features:
        assert feature.cardinality is None

    # Building the dataset wires the cardinality callback into each
    # categorical feature of the schema.
    PandasSequentialDataset(schema, "user_id", "item_id", sequential_info["sequences"])

    expected_cardinalities = [4, 6, 4, 6]
    for feature, expected in zip(schema.all_features, expected_cardinalities):
        assert feature.cardinality == expected


@pytest.mark.torch
def test_cannot_create_sequential_dataset_with_invalid_schema(sequential_info):
corrupted_sequences = sequential_info["sequences"].drop(columns=["some_item_feature"])
Expand Down

0 comments on commit 96b0055

Please sign in to comment.