Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cassandra online store, concurrent fetching for multiple entities #3356

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/reference/online-stores/cassandra.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ online_store:
load_balancing: # optional
local_dc: 'datacenter1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```
{% endcode %}

Expand All @@ -52,7 +53,7 @@ online_store:
load_balancing: # optional
local_dc: 'eu-central-1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional

read_concurrency: 100 # optional
```
{% endcode %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ online_store:
load_balancing: # optional
local_dc: 'datacenter1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```

#### Astra DB setup:
Expand All @@ -84,6 +85,7 @@ online_store:
load_balancing: # optional
local_dc: 'eu-central-1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```

#### Protocol version and load-balancing settings
Expand Down Expand Up @@ -111,6 +113,14 @@ The former parameter is a region name for Astra DB instances (as can be verified
See the source code of the online store integration for the allowed values of
the latter parameter.

#### Read concurrency value

You can optionally specify the value of `read_concurrency`, which will be
passed to the Cassandra driver function handling
[concurrent reading of multiple entities](https://docs.datastax.com/en/developer/python-driver/3.25/api/cassandra/concurrent/#module-cassandra.concurrent).
Consult the reference for guidance on this parameter (which in most cases can be left to its default value of 100).
This is relevant only for retrieval of several entities at once.

### More info

For a more detailed walkthrough, please see the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ResultSet,
Session,
)
from cassandra.concurrent import execute_concurrent_with_args
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import PreparedStatement
from pydantic import StrictFloat, StrictInt, StrictStr
Expand Down Expand Up @@ -166,6 +167,14 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
wrapped into an execution profile if present.
"""

read_concurrency: Optional[StrictInt] = 100
"""
Value of the `concurrency` parameter internally passed to Cassandra driver's
`execute_concurrent_with_args ` call.
See https://docs.datastax.com/en/developer/python-driver/3.25/api/cassandra/concurrent/#module-cassandra.concurrent .
Default: 100.
"""


class CassandraOnlineStore(OnlineStore):
"""
Expand Down Expand Up @@ -358,32 +367,36 @@ def online_read(

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

for entity_key in entity_keys:
entity_key_bin = serialize_entity_key(
entity_key_bins = [
serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for entity_key in entity_keys
]

with tracing_span(name="remote_call"):
feature_rows_sequence = self._read_rows_by_entity_keys(
config,
project,
table,
entity_key_bins,
columns=["feature_name", "value", "event_ts"],
)

with tracing_span(name="remote_call"):
feature_rows = self._read_rows_by_entity_key(
config,
project,
table,
entity_key_bin,
columns=["feature_name", "value", "event_ts"],
)

for entity_key_bin, feature_rows in zip(entity_key_bins, feature_rows_sequence):
res = {}
res_ts = None
for feature_row in feature_rows:
if (
requested_features is None
or feature_row.feature_name in requested_features
):
val = ValueProto()
val.ParseFromString(feature_row.value)
res[feature_row.feature_name] = val
res_ts = feature_row.event_ts
if feature_rows:
for feature_row in feature_rows:
if (
requested_features is None
or feature_row.feature_name in requested_features
):
val = ValueProto()
val.ParseFromString(feature_row.value)
res[feature_row.feature_name] = val
res_ts = feature_row.event_ts
if not res:
result.append((None, None))
else:
Expand Down Expand Up @@ -479,12 +492,12 @@ def _write_rows(
params,
)

def _read_rows_by_entity_key(
def _read_rows_by_entity_keys(
self,
config: RepoConfig,
project: str,
table: FeatureView,
entity_key_bin: str,
entity_key_bins: List[str],
columns: Optional[List[str]] = None,
) -> ResultSet:
"""
Expand All @@ -500,7 +513,25 @@ def _read_rows_by_entity_key(
fqtable=fqtable,
columns=projection_columns,
)
return session.execute(select_cql, [entity_key_bin])
retrieval_results = execute_concurrent_with_args(
session,
select_cql,
((entity_key_bin,) for entity_key_bin in entity_key_bins),
concurrency=config.online_store.read_concurrency,
)
# execute_concurrent_with_args return a sequence
# of (success, result_or_exception) pairs:
returned_sequence = []
for success, result_or_exception in retrieval_results:
if success:
returned_sequence.append(result_or_exception)
else:
# an exception
logger.error(
f"Cassandra online store exception during concurrent fetching: {str(result_or_exception)}"
)
returned_sequence.append(None)
return returned_sequence

def _drop_table(
self,
Expand Down
27 changes: 21 additions & 6 deletions sdk/python/feast/templates/cassandra/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ def collect_cassandra_store_settings():
sys.exit(1)
needs_port = click.confirm("Need to specify port?", default=False)
if needs_port:
c_port = click.prompt("Port to use", default=9042, type=int)
c_port = click.prompt(" Port to use", default=9042, type=int)
else:
c_port = None
use_auth = click.confirm(
"Do you need username/password?",
default=False,
)
if use_auth:
c_username = click.prompt("Database username")
c_password = click.prompt("Database password", hide_input=True)
c_username = click.prompt(" Database username")
c_password = click.prompt(" Database password", hide_input=True)
else:
c_username = None
c_password = None
Expand All @@ -95,7 +95,7 @@ def collect_cassandra_store_settings():
)
if specify_protocol_version:
c_protocol_version = click.prompt(
"Protocol version",
" Protocol version",
default={"A": 4, "C": 5}.get(db_type, 5),
type=int,
)
Expand All @@ -105,11 +105,11 @@ def collect_cassandra_store_settings():
specify_lb = click.confirm("Specify load-balancing?", default=False)
if specify_lb:
c_local_dc = click.prompt(
"Local datacenter (for load-balancing)",
" Local datacenter (for load-balancing)",
default="datacenter1" if db_type == "C" else None,
)
c_load_balancing_policy = click.prompt(
"Load-balancing policy",
" Load-balancing policy",
type=click.Choice(
[
"TokenAwarePolicy(DCAwareRoundRobinPolicy)",
Expand All @@ -122,6 +122,12 @@ def collect_cassandra_store_settings():
c_local_dc = None
c_load_balancing_policy = None

needs_concurrency = click.confirm("Specify read concurrency level?", default=False)
if needs_concurrency:
c_concurrency = click.prompt(" Concurrency level?", default=100, type=int)
else:
c_concurrency = None

return {
"c_secure_bundle_path": c_secure_bundle_path,
"c_hosts": c_hosts,
Expand All @@ -132,6 +138,7 @@ def collect_cassandra_store_settings():
"c_protocol_version": c_protocol_version,
"c_local_dc": c_local_dc,
"c_load_balancing_policy": c_load_balancing_policy,
"c_concurrency": c_concurrency,
}


Expand All @@ -149,6 +156,7 @@ def apply_cassandra_store_settings(config_file, settings):
'c_protocol_version'
'c_local_dc'
'c_load_balancing_policy'
'c_concurrency'
"""
write_setting_or_remove(
config_file,
Expand Down Expand Up @@ -216,6 +224,13 @@ def apply_cassandra_store_settings(config_file, settings):
remove_lines_from_file(config_file, "load_balancing:")
remove_lines_from_file(config_file, "local_dc:")
remove_lines_from_file(config_file, "load_balancing_policy:")
#
write_setting_or_remove(
config_file,
settings["c_concurrency"],
"read_concurrency",
"100",
)


def bootstrap():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ online_store:
load_balancing:
local_dc: c_local_dc
load_balancing_policy: c_load_balancing_policy
read_concurrency: 100
entity_key_serialization_version: 2