diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index 406bee525f..50709fa3d4 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -221,6 +221,9 @@ def online_read( entity_ids_iter = iter(entity_ids) while True: batch = list(itertools.islice(entity_ids_iter, batch_size)) + batch_result: List[ + Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + ] = [] # No more items to insert if len(batch) == 0: break @@ -243,20 +246,23 @@ def online_read( for tbl_res in table_responses: entity_id = tbl_res["entity_id"] while entity_id != batch[entity_idx]: - result.append((None, None)) + batch_result.append((None, None)) entity_idx += 1 res = {} for feature_name, value_bin in tbl_res["values"].items(): val = ValueProto() val.ParseFromString(value_bin.value) res[feature_name] = val - result.append((datetime.fromisoformat(tbl_res["event_ts"]), res)) + batch_result.append( + (datetime.fromisoformat(tbl_res["event_ts"]), res) + ) entity_idx += 1 # Not all entities in a batch may have responses # Pad with remaining values in batch that were not found - batch_size_nones = ((None, None),) * (len(batch) - len(result)) - result.extend(batch_size_nones) + batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) + batch_result.extend(batch_size_nones) + result.extend(batch_result) return result def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py index 6275a177e0..25eb061930 100644 --- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py @@ -318,3 +318,42 @@ def test_write_batch_non_duplicates(repo_config, dynamodb_online_store): returned_items = response.get("Items", None) assert returned_items is not None assert len(returned_items) == len(data) + + +@mock_dynamodb2 +def test_dynamodb_online_store_online_read_unknown_entity_end_of_batch( + repo_config, dynamodb_online_store +): + """ + Test DynamoDBOnlineStore online_read method with unknown entities at + the end of the batch. + """ + batch_size = repo_config.online_store.batch_size + n_samples = batch_size + _create_test_table(PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}", REGION) + data = _create_n_customer_test_samples(n=n_samples) + _insert_data_test_table( + data, PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}", REGION + ) + + entity_keys, features, *rest = zip(*data) + entity_keys = list(entity_keys) + features = list(features) + + # Append a nonsensical entity to search for as the only item in the 2nd batch + entity_keys.append( + EntityKeyProto( + join_keys=["customer"], entity_values=[ValueProto(string_val="12359")] + ) + ) + features.append(None) + + returned_items = dynamodb_online_store.online_read( + config=repo_config, + table=MockFeatureView(name=f"{TABLE_NAME}_unknown_entity_{n_samples}"), + entity_keys=entity_keys, + ) + + # ensure the entity is not dropped + assert len(returned_items) == len(entity_keys) + assert returned_items[-1] == (None, None)