Skip to content

Commit

Permalink
feat: Support retrieval from multiple feature views with different jo…
Browse files Browse the repository at this point in the history
…in keys (#2835)

* feat: Support retrieving from multiple feature views

Signed-off-by: Yongheng Lin <yongheng.lin@gmail.com>

* group by join keys instead of feature view

Signed-off-by: Yongheng Lin <yongheng.lin@gmail.com>

* tolerate insufficient entities

Signed-off-by: Yongheng Lin <yongheng.lin@gmail.com>

* mock registry.getEntityJoinKey

Signed-off-by: Yongheng Lin <yongheng.lin@gmail.com>

* add integration test

Signed-off-by: Yongheng Lin <yongheng.lin@gmail.com>
  • Loading branch information
yongheng authored Jun 30, 2022
1 parent 86e9efd commit 056cfa1
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 16 deletions.
15 changes: 15 additions & 0 deletions java/serving/src/main/java/feast/serving/registry/Registry.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class Registry {
private Map<String, OnDemandFeatureViewProto.OnDemandFeatureViewSpec>
onDemandFeatureViewNameToSpec;
private final Map<String, FeatureServiceProto.FeatureServiceSpec> featureServiceNameToSpec;
private final Map<String, String> entityNameToJoinKey;

Registry(RegistryProto.Registry registry) {
this.registry = registry;
Expand Down Expand Up @@ -60,6 +61,12 @@ public class Registry {
.collect(
Collectors.toMap(
FeatureServiceProto.FeatureServiceSpec::getName, Function.identity()));
this.entityNameToJoinKey =
registry.getEntitiesList().stream()
.map(EntityProto.Entity::getSpec)
.collect(
Collectors.toMap(
EntityProto.EntitySpecV2::getName, EntityProto.EntitySpecV2::getJoinKey));
}

public RegistryProto.Registry getRegistry() {
Expand Down Expand Up @@ -115,4 +122,12 @@ public FeatureServiceProto.FeatureServiceSpec getFeatureServiceSpec(String name)
}
return spec;
}

public String getEntityJoinKey(String name) {
String joinKey = entityNameToJoinKey.get(name);
if (joinKey == null) {
throw new SpecRetrievalException(String.format("Unable to find entity with name: %s", name));
}
return joinKey;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,8 @@ public Duration getMaxAge(ServingAPIProto.FeatureReferenceV2 featureReference) {
public List<String> getEntitiesList(ServingAPIProto.FeatureReferenceV2 featureReference) {
return getFeatureViewSpec(featureReference).getEntitiesList();
}

public String getEntityJoinKey(String name) {
return this.registry.getEntityJoinKey(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import feast.serving.registry.RegistryRepository;
import feast.serving.util.Metrics;
import feast.storage.api.retriever.OnlineRetrieverV2;
import io.grpc.Status;
import io.opentracing.Span;
import io.opentracing.Tracer;
import java.util.*;
Expand All @@ -51,6 +50,11 @@ public class OnlineServingServiceV2 implements ServingServiceV2 {
private final OnlineTransformationService onlineTransformationService;
private final String project;

public static final String DUMMY_ENTITY_ID = "__dummy_id";
public static final String DUMMY_ENTITY_VAL = "";
public static final ValueProto.Value DUMMY_ENTITY_VALUE =
ValueProto.Value.newBuilder().setStringVal(DUMMY_ENTITY_VAL).build();

public OnlineServingServiceV2(
OnlineRetrieverV2 retriever,
Tracer tracer,
Expand Down Expand Up @@ -103,31 +107,18 @@ public ServingAPIProto.GetOnlineFeaturesResponse getOnlineFeatures(

List<Map<String, ValueProto.Value>> entityRows = getEntityRows(request);

List<String> entityNames;
if (retrievedFeatureReferences.size() > 0) {
entityNames = this.registryRepository.getEntitiesList(retrievedFeatureReferences.get(0));
} else {
throw new RuntimeException("Requested features list must not be empty");
}

Span storageRetrievalSpan = tracer.buildSpan("storageRetrieval").start();
if (storageRetrievalSpan != null) {
storageRetrievalSpan.setTag("entities", entityRows.size());
storageRetrievalSpan.setTag("features", retrievedFeatureReferences.size());
}

List<List<feast.storage.api.retriever.Feature>> features =
retriever.getOnlineFeatures(entityRows, retrievedFeatureReferences, entityNames);
retrieveFeatures(retrievedFeatureReferences, entityRows);

if (storageRetrievalSpan != null) {
storageRetrievalSpan.finish();
}
if (features.size() != entityRows.size()) {
throw Status.INTERNAL
.withDescription(
"The no. of FeatureRow obtained from OnlineRetriever"
+ "does not match no. of entityRow passed.")
.asRuntimeException();
}

Span postProcessingSpan = tracer.buildSpan("postProcessing").start();

Expand Down Expand Up @@ -255,6 +246,84 @@ private List<Map<String, ValueProto.Value>> getEntityRows(
return entityRows;
}

private List<List<feast.storage.api.retriever.Feature>> retrieveFeatures(
List<FeatureReferenceV2> featureReferences, List<Map<String, ValueProto.Value>> entityRows) {
// Prepare feature reference to index mapping. This mapping will be used to arrange the
// retrieved features to the same order as in the input.
if (featureReferences.isEmpty()) {
throw new RuntimeException("Requested features list must not be empty.");
}
Map<FeatureReferenceV2, Integer> featureReferenceToIndexMap =
new HashMap<>(featureReferences.size());
for (int i = 0; i < featureReferences.size(); i++) {
FeatureReferenceV2 featureReference = featureReferences.get(i);
if (featureReferenceToIndexMap.containsKey(featureReference)) {
throw new RuntimeException(
String.format(
"Found duplicate features %s:%s.",
featureReference.getFeatureViewName(), featureReference.getFeatureName()));
}
featureReferenceToIndexMap.put(featureReference, i);
}

// Create placeholders for retrieved features.
List<List<feast.storage.api.retriever.Feature>> features = new ArrayList<>(entityRows.size());
for (int i = 0; i < entityRows.size(); i++) {
List<feast.storage.api.retriever.Feature> featuresPerEntity =
new ArrayList<>(featureReferences.size());
for (int j = 0; j < featureReferences.size(); j++) {
featuresPerEntity.add(null);
}
features.add(featuresPerEntity);
}

// Group feature references by join keys.
Map<String, List<FeatureReferenceV2>> groupNameToFeatureReferencesMap =
featureReferences.stream()
.collect(
Collectors.groupingBy(
featureReference ->
this.registryRepository.getEntitiesList(featureReference).stream()
.map(this.registryRepository::getEntityJoinKey)
.sorted()
.collect(Collectors.joining(","))));

// Retrieve features one group at a time.
for (List<FeatureReferenceV2> featureReferencesPerGroup :
groupNameToFeatureReferencesMap.values()) {
List<String> entityNames =
this.registryRepository.getEntitiesList(featureReferencesPerGroup.get(0));
List<Map<String, ValueProto.Value>> entityRowsPerGroup = new ArrayList<>(entityRows.size());
for (Map<String, ValueProto.Value> entityRow : entityRows) {
Map<String, ValueProto.Value> entityRowPerGroup = new HashMap<>();
entityNames.stream()
.map(this.registryRepository::getEntityJoinKey)
.forEach(
joinKey -> {
if (joinKey.equals(DUMMY_ENTITY_ID)) {
entityRowPerGroup.put(joinKey, DUMMY_ENTITY_VALUE);
} else {
ValueProto.Value value = entityRow.get(joinKey);
if (value != null) {
entityRowPerGroup.put(joinKey, value);
}
}
});
entityRowsPerGroup.add(entityRowPerGroup);
}
List<List<feast.storage.api.retriever.Feature>> featuresPerGroup =
retriever.getOnlineFeatures(entityRowsPerGroup, featureReferencesPerGroup, entityNames);
for (int i = 0; i < featuresPerGroup.size(); i++) {
for (int j = 0; j < featureReferencesPerGroup.size(); j++) {
int k = featureReferenceToIndexMap.get(featureReferencesPerGroup.get(j));
features.get(i).set(k, featuresPerGroup.get(i).get(j));
}
}
}

return features;
}

private void populateOnDemandFeatures(
List<FeatureReferenceV2> onDemandFeatureReferences,
List<FeatureReferenceV2> onDemandFeatureSources,
Expand Down
30 changes: 30 additions & 0 deletions java/serving/src/test/java/feast/serving/it/ServingBaseTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,35 @@ public void shouldGetOnlineFeaturesWithStringEntity() {
}
}

@Test
public void shouldGetOnlineFeaturesFromAllFeatureViews() {
Map<String, ValueProto.RepeatedValue> entityRows =
ImmutableMap.of(
"entity",
ValueProto.RepeatedValue.newBuilder()
.addVal(DataGenerator.createStrValue("key-1"))
.build(),
"driver_id",
ValueProto.RepeatedValue.newBuilder()
.addVal(DataGenerator.createInt64Value(1005))
.build());

ImmutableList<String> featureReferences =
ImmutableList.of(
"feature_view_0:feature_0",
"feature_view_0:feature_1",
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:avg_daily_trips");

ServingAPIProto.GetOnlineFeaturesRequest req =
TestUtils.createOnlineFeatureRequest(featureReferences, entityRows);

ServingAPIProto.GetOnlineFeaturesResponse resp = servingStub.getOnlineFeatures(req);

for (final int featureIdx : List.of(0, 1, 2, 3)) {
assertEquals(FieldStatus.PRESENT, resp.getResults(featureIdx).getStatuses(0));
}
}

abstract void updateRegistryFile(RegistryProto.Registry registry);
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() {
.thenReturn(featureSpecs.get(0));
when(registry.getFeatureSpec(mockedFeatureRows.get(3).getFeatureReference()))
.thenReturn(featureSpecs.get(1));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down Expand Up @@ -237,6 +239,8 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() {
.thenReturn(featureSpecs.get(0));
when(registry.getFeatureSpec(mockedFeatureRows.get(1).getFeatureReference()))
.thenReturn(featureSpecs.get(1));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down Expand Up @@ -314,6 +318,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() {
.thenReturn(featureSpecs.get(1));
when(registry.getFeatureSpec(mockedFeatureRows.get(5).getFeatureReference()))
.thenReturn(featureSpecs.get(0));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down

0 comments on commit 056cfa1

Please sign in to comment.