diff --git a/sdk/python/feast/infra/online_stores/redis.py b/sdk/python/feast/infra/online_stores/redis.py index bb85a8e853..a226c5cd18 100644 --- a/sdk/python/feast/infra/online_stores/redis.py +++ b/sdk/python/feast/infra/online_stores/redis.py @@ -102,7 +102,7 @@ def _parse_connection_string(connection_string: str): params = {} for c in connection_string.split(","): if "=" in c: - kv = c.split("=") + kv = c.split("=", 1) try: kv[1] = json.loads(kv[1]) except json.JSONDecodeError: diff --git a/sdk/python/tests/test_cli_redis.py b/sdk/python/tests/test_cli_redis.py index e948bffd25..7aeafad83f 100644 --- a/sdk/python/tests/test_cli_redis.py +++ b/sdk/python/tests/test_cli_redis.py @@ -5,6 +5,7 @@ from textwrap import dedent import pytest +import redis from feast.feature_store import FeatureStore from tests.cli_utils import CliRunner @@ -58,3 +59,47 @@ def test_basic() -> None: result = runner.run(["teardown"], cwd=repo_path) assert result.returncode == 0 + + +@pytest.mark.integration +def test_connection_error() -> None: + project_id = "".join( + random.choice(string.ascii_lowercase + string.digits) for _ in range(10) + ) + runner = CliRunner() + with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name: + + repo_path = Path(repo_dir_name) + data_path = Path(data_dir_name) + + repo_config = repo_path / "feature_store.yaml" + + repo_config.write_text( + dedent( + f""" + project: {project_id} + registry: {data_path / "registry.db"} + provider: local + offline_store: + type: file + online_store: + type: redis + connection_string: localhost:6379,db=0= + """ + ) + ) + + repo_example = repo_path / "example.py" + repo_example.write_text( + (Path(__file__).parent / "example_feature_repo_2.py").read_text() + ) + + result = runner.run(["apply"], cwd=repo_path) + assert result.returncode == 0 + + # Redis does not support names for its databases. + with pytest.raises(redis.exceptions.ResponseError): + basic_rw_test( + FeatureStore(repo_path=str(repo_path), config=None), + view_name="driver_hourly_stats", + )