diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py index d34bb2fa8b..503aaf8688 100644 --- a/sdk/python/feast/infra/registry_stores/sql.py +++ b/sdk/python/feast/infra/registry_stores/sql.py @@ -472,7 +472,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True): pass def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - pass + return Infra() def apply_user_metadata( self, @@ -550,7 +550,8 @@ def proto(self) -> RegistryProto: (self.list_validation_references, r.validation_references), ]: objs: List[Any] = lister(project) # type: ignore - registry_proto_field.extend([obj.to_proto() for obj in objs]) + if objs: + registry_proto_field.extend([obj.to_proto() for obj in objs]) return r diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index e993533c8b..c8b00befc6 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -663,6 +663,75 @@ def commit(self): def refresh(self): """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" + @staticmethod + def _message_to_sorted_dict(message: Message) -> Dict[str, Any]: + return json.loads(MessageToJson(message, sort_keys=True)) + + def to_dict(self, project: str) -> Dict[str, List[Any]]: + """Returns a dictionary representation of the registry contents for the specified project. + + For each list in the dictionary, the elements are sorted by name, so this + method can be used to compare two registries. + + Args: + project: Feast project to convert to a dict + """ + registry_dict: Dict[str, Any] = defaultdict(list) + registry_dict["project"] = project + for data_source in sorted( + self.list_data_sources(project=project), key=lambda ds: ds.name + ): + registry_dict["dataSources"].append( + self._message_to_sorted_dict(data_source.to_proto()) + ) + for entity in sorted( + self.list_entities(project=project), key=lambda entity: entity.name + ): + registry_dict["entities"].append( + self._message_to_sorted_dict(entity.to_proto()) + ) + for feature_view in sorted( + self.list_feature_views(project=project), + key=lambda feature_view: feature_view.name, + ): + registry_dict["featureViews"].append( + self._message_to_sorted_dict(feature_view.to_proto()) + ) + for feature_service in sorted( + self.list_feature_services(project=project), + key=lambda feature_service: feature_service.name, + ): + registry_dict["featureServices"].append( + self._message_to_sorted_dict(feature_service.to_proto()) + ) + for on_demand_feature_view in sorted( + self.list_on_demand_feature_views(project=project), + key=lambda on_demand_feature_view: on_demand_feature_view.name, + ): + odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto()) + odfv_dict["spec"]["userDefinedFunction"]["body"] = dill.source.getsource( + on_demand_feature_view.udf + ) + registry_dict["onDemandFeatureViews"].append(odfv_dict) + for request_feature_view in sorted( + self.list_request_feature_views(project=project), + key=lambda request_feature_view: request_feature_view.name, + ): + registry_dict["requestFeatureViews"].append( + self._message_to_sorted_dict(request_feature_view.to_proto()) + ) + for saved_dataset in sorted( + self.list_saved_datasets(project=project), key=lambda item: item.name + ): + registry_dict["savedDatasets"].append( + self._message_to_sorted_dict(saved_dataset.to_proto()) + ) + for infra_object in sorted(self.get_infra(project=project).infra_objects): + registry_dict["infra"].append( + self._message_to_sorted_dict(infra_object.to_proto()) + ) + return registry_dict + class Registry(BaseRegistry): """ @@ -689,6 +758,18 @@ def get_user_metadata( cached_registry_proto_created: Optional[datetime] = None cached_registry_proto_ttl: timedelta + def __new__( + cls, registry_config: Optional[RegistryConfig], repo_path: Optional[Path] + ): + # We override __new__ so that we can inspect registry_config and create a SqlRegistry without callers + # needing to make any changes. + if registry_config and registry_config.registry_type == "sql": + from feast.infra.registry_stores.sql import SqlRegistry + + return SqlRegistry(registry_config, repo_path) + else: + return super(Registry, cls).__new__(cls) + def __init__( self, registry_config: Optional[RegistryConfig], repo_path: Optional[Path] ): @@ -1587,75 +1668,6 @@ def teardown(self): def proto(self) -> RegistryProto: return self.cached_registry_proto or RegistryProto() - def to_dict(self, project: str) -> Dict[str, List[Any]]: - """Returns a dictionary representation of the registry contents for the specified project. - - For each list in the dictionary, the elements are sorted by name, so this - method can be used to compare two registries. - - Args: - project: Feast project to convert to a dict - """ - registry_dict: Dict[str, Any] = defaultdict(list) - registry_dict["project"] = project - for data_source in sorted( - self.list_data_sources(project=project), key=lambda ds: ds.name - ): - registry_dict["dataSources"].append( - self._message_to_sorted_dict(data_source.to_proto()) - ) - for entity in sorted( - self.list_entities(project=project), key=lambda entity: entity.name - ): - registry_dict["entities"].append( - self._message_to_sorted_dict(entity.to_proto()) - ) - for feature_view in sorted( - self.list_feature_views(project=project), - key=lambda feature_view: feature_view.name, - ): - registry_dict["featureViews"].append( - self._message_to_sorted_dict(feature_view.to_proto()) - ) - for feature_service in sorted( - self.list_feature_services(project=project), - key=lambda feature_service: feature_service.name, - ): - registry_dict["featureServices"].append( - self._message_to_sorted_dict(feature_service.to_proto()) - ) - for on_demand_feature_view in sorted( - self.list_on_demand_feature_views(project=project), - key=lambda on_demand_feature_view: on_demand_feature_view.name, - ): - odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto()) - odfv_dict["spec"]["userDefinedFunction"]["body"] = dill.source.getsource( - on_demand_feature_view.udf - ) - registry_dict["onDemandFeatureViews"].append(odfv_dict) - for request_feature_view in sorted( - self.list_request_feature_views(project=project), - key=lambda request_feature_view: request_feature_view.name, - ): - registry_dict["requestFeatureViews"].append( - self._message_to_sorted_dict(request_feature_view.to_proto()) - ) - for saved_dataset in sorted( - self.list_saved_datasets(project=project), key=lambda item: item.name - ): - registry_dict["savedDatasets"].append( - self._message_to_sorted_dict(saved_dataset.to_proto()) - ) - for infra_object in sorted(self.get_infra(project=project).infra_objects): - registry_dict["infra"].append( - self._message_to_sorted_dict(infra_object.to_proto()) - ) - return registry_dict - - @staticmethod - def _message_to_sorted_dict(message: Message) -> Dict[str, Any]: - return json.loads(MessageToJson(message, sort_keys=True)) - def _prepare_registry_for_changes(self): """Prepares the Registry for changes by refreshing the cache if necessary.""" try: diff --git a/sdk/python/tests/integration/registration/test_sql_registry.py b/sdk/python/tests/integration/registration/test_sql_registry.py index c96d83ce0a..1fe9ff5cec 100644 --- a/sdk/python/tests/integration/registration/test_sql_registry.py +++ b/sdk/python/tests/integration/registration/test_sql_registry.py @@ -85,7 +85,7 @@ def mysql_registry(): log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '8.0.29' socket: '/var/run/mysqld/mysqld.sock' port: 3306" waited = wait_for_logs( - container=container, predicate=log_string_to_wait_for, timeout=30, interval=10, + container=container, predicate=log_string_to_wait_for, timeout=60, interval=10, ) logger.info("Waited for %s seconds until mysql container was up", waited) container_port = container.get_exposed_port(3306)