Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate schema Tags to merlin.schema.Tags #632

Merged
merged 11 commits into from
Mar 9, 2023
2,114 changes: 1,057 additions & 1,057 deletions examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
log_parameters,
)
from merlin.io import Dataset
from merlin.schema import Tags
from transf_exp_args import DataArguments, ModelArguments, TrainingArguments
from transformers import HfArgumentParser, set_seed
from transformers.trainer_utils import is_main_process

import transformers4rec.torch as t4r
from merlin_standard_lib import Schema, Tag
from merlin_standard_lib import Schema
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.data_utils import MerlinDataLoader
from transformers4rec.torch.utils.examples_utils import wipe_memory
Expand All @@ -54,9 +55,9 @@ def main():
# Loading the schema of the dataset
schema = Schema().from_proto_text(data_args.features_schema_path)
if not data_args.use_side_information_features:
schema = schema.select_by_tag(Tag.ITEM_ID)
schema = schema.select_by_tag([Tags.ITEM_ID])

item_id_col = schema.select_by_tag(Tag.ITEM_ID).column_names[0]
item_id_col = schema.select_by_tag([Tags.ITEM_ID]).column_names[0]
col_names = schema.column_names
logger.info("Column names: {}".format(col_names))

Expand Down
2 changes: 0 additions & 2 deletions merlin_standard_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from .schema import schema
from .schema.schema import ColumnSchema, Schema, categorical_cardinalities
from .schema.tag import Tag

# Other monkey-patching
Message.HasField = proto_utils.has_field # type: ignore
Expand All @@ -29,7 +28,6 @@
"ColumnSchema",
"Schema",
"schema",
"Tag",
"Registry",
"RegistryMixin",
"categorical_cardinalities",
Expand Down
38 changes: 19 additions & 19 deletions merlin_standard_lib/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

from google.protobuf import json_format, text_format
from google.protobuf.message import Message as ProtoMessage
from merlin.schema import Tags, TagSet, TagsType
from merlin.schema.io import proto_utils

from .tag import TagsType

try:
from functools import cached_property # type: ignore
except ImportError:
Expand Down Expand Up @@ -79,11 +78,11 @@ def create_categorical(
tags: Optional[TagsType] = None,
**kwargs,
) -> "ColumnSchema":
_tags: List[str] = [str(t) for t in tags] if tags else []
_tags: List[str] = [t.value for t in TagSet(tags or [])]

extra = _parse_shape_and_value_count(shape, value_count)
int_domain = IntDomain(name=name, min=min_index, max=num_items, is_categorical=True)
_tags = list(set(_tags + ["categorical"]))
_tags = list(set(_tags + [Tags.CATEGORICAL.value]))
extra["type"] = FeatureType.INT

return cls(name=name, int_domain=int_domain, **extra, **kwargs).with_tags(_tags)
Expand All @@ -103,7 +102,7 @@ def create_continuous(
tags: Optional[TagsType] = None,
**kwargs,
) -> "ColumnSchema":
_tags: List[str] = [str(t) for t in tags] if tags else []
_tags: List[str] = [t.value for t in TagSet(tags or [])]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something I don't love: `_tags` is a `List[str]`, not a `TagsType`. We convert actual `Tags` members to their string values, but I don't think we convert them back anywhere. I need to understand it a little better now that the tests are (hopefully) passing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I think we converted to strings everywhere to make it easier for users to provide custom tags as strings. I'm not completely sure how we currently handle that across Merlin.


extra = _parse_shape_and_value_count(shape, value_count)
if min_value is not None and max_value is not None:
Expand All @@ -121,7 +120,7 @@ def create_continuous(
name=name, min=int(min_value), max=int(max_value), is_categorical=False
)
extra["type"] = FeatureType.FLOAT if is_float else FeatureType.INT
_tags = list(set(_tags + ["continuous"]))
_tags = list(set(_tags + [Tags.CONTINUOUS.value]))

return cls(name=name, **extra, **kwargs).with_tags(_tags)

Expand All @@ -144,21 +143,19 @@ def with_tags(self, tags: TagsType) -> "ColumnSchema":
def with_tags_based_on_properties(
self, using_value_count=True, using_domain=True
) -> "ColumnSchema":
from .tag import Tag

extra_tags = []

if using_value_count and proto_utils.has_field(self, "value_count"):
extra_tags.append(str(Tag.LIST))
extra_tags.append(str(Tags.LIST))

if using_domain and proto_utils.has_field(self, "int_domain"):
if self.int_domain.is_categorical:
extra_tags.append(str(Tag.CATEGORICAL))
extra_tags.append(str(Tags.CATEGORICAL))
else:
extra_tags.append(str(Tag.CONTINUOUS))
extra_tags.append(str(Tags.CONTINUOUS))

if using_domain and proto_utils.has_field(self, "float_domain"):
extra_tags.append(str(Tag.CONTINUOUS))
extra_tags.append(str(Tags.CONTINUOUS))

return self.with_tags(extra_tags) if extra_tags else self.copy()

Expand Down Expand Up @@ -306,14 +303,17 @@ def select_by_tag(self, to_select) -> "Schema":
if not isinstance(to_select, (list, tuple)) and not callable(to_select):
to_select = [to_select]

def collection_filter_fn(column_tags):
return all(x in column_tags for x in to_select)
if callable(to_select):
return self._filter_column_schemas(to_select, lambda x: False, lambda x: x.tags)
else:
# Schema.tags always returns a List[str] with the tag values, so if the user wants to
# filter using the Tags Enum, we need to convert those to their string value
to_select = [tag.value if isinstance(tag, Tags) else tag for tag in to_select]

output: Schema = self._filter_column_schemas(
to_select, collection_filter_fn, lambda x: x.tags
)
def collection_filter_fn(column_names: List[str]):
return all(x in column_names for x in to_select)

return output
return self._filter_column_schemas(to_select, collection_filter_fn, lambda x: x.tags)

def remove_by_tag(self, to_remove) -> "Schema":
if not isinstance(to_remove, (list, tuple)) and not callable(to_remove):
Expand Down Expand Up @@ -377,7 +377,7 @@ def column_schemas(self) -> Sequence[ColumnSchema]:

@cached_property
def item_id_column_name(self):
item_id_col = self.select_by_tag("item_id")
item_id_col = self.select_by_tag(Tags.ITEM_ID)
if len(item_id_col.column_names) == 0:
raise ValueError("There is no column tagged as item id.")

Expand Down
52 changes: 0 additions & 52 deletions merlin_standard_lib/schema/tag.py

This file was deleted.

7 changes: 4 additions & 3 deletions tests/config/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
# limitations under the License.
#
import pytest
from merlin.schema import Tags

from merlin_standard_lib import Tag, categorical_cardinalities
from merlin_standard_lib import categorical_cardinalities
from merlin_standard_lib.utils.embedding_utils import get_embedding_sizes_from_schema


def test_schema_from_yoochoose_schema(yoochoose_schema):
assert len(yoochoose_schema.column_names) == 22
assert len(yoochoose_schema.select_by_tag(Tag.CONTINUOUS).column_schemas) == 11
assert len(yoochoose_schema.select_by_tag(Tag.CATEGORICAL).column_schemas) == 3
assert len(yoochoose_schema.select_by_tag(Tags.CONTINUOUS).column_schemas) == 11
assert len(yoochoose_schema.select_by_tag(Tags.CATEGORICAL).column_schemas) == 3


def test_schema_cardinalities(yoochoose_schema):
Expand Down
26 changes: 16 additions & 10 deletions tests/data/test_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@ def test_generate_item_interactions():

assert isinstance(data, pd.DataFrame)
assert len(data) == 500
assert list(data.columns) == [
"session_id",
"item_id",
"day",
"purchase",
"price",
"category",
"item_recency",
]
print(sorted(list(data.columns)))
assert sorted(list(data.columns)) == sorted(
[
"session_id",
"item_id",
"day",
"purchase",
"price",
"category",
"item_recency",
]
)
expected_dtypes = {
"session_id": "int64",
"item_id": "int64",
Expand All @@ -31,5 +34,8 @@ def test_generate_item_interactions():
"category": "int64",
"item_recency": "float64",
}
for key, val in dict(data.dtypes).items():
print((key, val))
marcromeyn marked this conversation as resolved.
Show resolved Hide resolved
assert val == expected_dtypes[key]

assert all(val == expected_dtypes[key] for key, val in dict(data.dtypes).items())
# assert all(val == expected_dtypes[key] for key, val in dict(data.dtypes).items())
marcromeyn marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 0 additions & 26 deletions tests/merlin_standard_lib/schema/test_tag.py

This file was deleted.

5 changes: 3 additions & 2 deletions tests/torch/features/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# limitations under the License.
#

from merlin.schema import Tags

import transformers4rec.torch as tr
from merlin_standard_lib import Tag


def test_continuous_features(torch_con_features):
Expand All @@ -27,7 +28,7 @@ def test_continuous_features(torch_con_features):

def test_continuous_features_yoochoose(yoochoose_schema, torch_yoochoose_like):
schema = yoochoose_schema
cont_cols = schema.select_by_tag(Tag.CONTINUOUS)
cont_cols = schema.select_by_tag(Tags.CONTINUOUS)

con = tr.ContinuousFeatures.from_schema(cont_cols)
outputs = con(torch_yoochoose_like)
Expand Down
13 changes: 6 additions & 7 deletions tests/torch/features/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import numpy as np
import pytest
import torch
from merlin.schema import Tags

import transformers4rec.torch as tr
from merlin_standard_lib import Tag


def test_embedding_features(torch_cat_features):
Expand Down Expand Up @@ -75,8 +75,7 @@ def test_table_config_invalid_embedding_initializer():


def test_embedding_features_yoochoose(yoochoose_schema, torch_yoochoose_like):
schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)

schema = yoochoose_schema.select_by_tag(Tags.CATEGORICAL)
emb_module = tr.EmbeddingFeatures.from_schema(schema)
embeddings = emb_module(torch_yoochoose_like)

Expand All @@ -89,7 +88,7 @@ def test_embedding_features_yoochoose(yoochoose_schema, torch_yoochoose_like):


def test_embedding_features_yoochoose_custom_dims(yoochoose_schema, torch_yoochoose_like):
schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
schema = yoochoose_schema.select_by_tag(Tags.CATEGORICAL)

emb_module = tr.EmbeddingFeatures.from_schema(
schema, embedding_dims={"item_id/list": 100}, embedding_dim_default=64
Expand All @@ -105,7 +104,7 @@ def test_embedding_features_yoochoose_custom_dims(yoochoose_schema, torch_yoocho


def test_embedding_features_yoochoose_infer_embedding_sizes(yoochoose_schema, torch_yoochoose_like):
schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
schema = yoochoose_schema.select_by_tag(Tags.CATEGORICAL)

emb_module = tr.EmbeddingFeatures.from_schema(
schema, infer_embedding_sizes=True, infer_embedding_sizes_multiplier=3.0
Expand All @@ -127,7 +126,7 @@ def test_embedding_features_yoochoose_custom_initializers(yoochoose_schema, torc
CATEGORY_MEAN = 2.0
CATEGORY_STD = 0.1

schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
schema = yoochoose_schema.select_by_tag(Tags.CATEGORICAL)
emb_module = tr.EmbeddingFeatures.from_schema(
schema,
layer_norm=False,
Expand Down Expand Up @@ -158,7 +157,7 @@ def test_pre_trained_embeddings_initializer(yoochoose_schema, torch_yoochoose_li
embedding_dim = 64
pre_trained_item_embeddings = np.random.rand(item_id_cardinality, embedding_dim)

schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
schema = yoochoose_schema.select_by_tag(Tags.CATEGORICAL)
emb_module = tr.EmbeddingFeatures.from_schema(
schema,
embedding_dims={"item_id/list": embedding_dim},
Expand Down
Loading