Skip to content

Commit

Permalink
Use AIProvider to extract token details from response
Browse files Browse the repository at this point in the history
  • Loading branch information
Tharsanan1 committed Sep 11, 2024
1 parent afc47b2 commit b0d9944
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 40 deletions.
1 change: 1 addition & 0 deletions adapter/internal/oasparser/envoyconf/internal_dtos.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type routeCreateParams struct {
environment string
envType string
mirrorClusterNames map[string][]string
isAiAPI bool
}

// RatelimitCriteria criterias of rate limiting
Expand Down
23 changes: 12 additions & 11 deletions adapter/internal/oasparser/envoyconf/routes_with_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
LuaLocal: luaFilter,
wellknown.CORS: corsFilter,
}
if !resource.GetEnableBackendBasedAIRatelimit() && !resource.GetEnableSubscriptionBasedAIRatelimit() {
if !params.isAiAPI || (!resource.GetEnableBackendBasedAIRatelimit() && !resource.GetEnableSubscriptionBasedAIRatelimit()) {
perFilterConfigExtProc := extProcessorv3.ExtProcPerRoute{
Override: &extProcessorv3.ExtProcPerRoute_Disabled{
Disabled: true,
Expand All @@ -878,7 +878,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
}
perRouteFilterConfigs[HTTPExternalProcessor] = filterExtProc
}

logger.LoggerOasparser.Debugf("adding route : %s for API : %s", resourcePath, title)

rateLimitPolicyLevel := ""
Expand Down Expand Up @@ -931,7 +931,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
routeConfig := resource.GetEndpoints().Config
metaData := &corev3.Metadata{}
logger.LoggerAPI.Infof("Is backend based rl enabled: %+v, Is subs based rl enabled: %+v", resource.GetEnableBackendBasedAIRatelimit(), resource.GetEnableSubscriptionBasedAIRatelimit())
if resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit() {
if params.isAiAPI && (resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit()) {
metaData = &corev3.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.filters.http.ext_proc": &structpb.Struct{
Expand Down Expand Up @@ -1092,8 +1092,8 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
metadataValue := operation.GetMethod() + "_to_" + newMethod
match2.DynamicMetadata = generateMetadataMatcherForInternalRoutes(metadataValue)

action1 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue())
action2 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue())
action1 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue())
action2 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue())

// Create route1 for current method.
// Do not add policies to route config. Send via enforcer
Expand All @@ -1116,7 +1116,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
} else {
var action *routev3.Route_Route
if requestRedirectAction == nil {
action = generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue())
action = generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue())
}
logger.LoggerOasparser.Debug("Creating routes for resource with policies", resourcePath, operation.GetMethod())
// create route for current method. Add policies to route config. Send via enforcer
Expand Down Expand Up @@ -1145,7 +1145,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
action := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, nil, resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue())
rewritePath := generateRoutePathForReWrite(basePath, resourcePath, pathMatchType)
action.Route.RegexRewrite = generateRegexMatchAndSubstitute(rewritePath, resourcePath, pathMatchType)

route := generateRouteConfig(xWso2Basepath, match, action, nil, metaData, decorator, perRouteFilterConfigs,
nil, nil, nil, nil) // general headers to add and remove are included in this methods
routes = append(routes, route)
Expand Down Expand Up @@ -1284,7 +1284,7 @@ func CreateAPIDefinitionRoute(basePath string, vHost string, methods []string, i
Decorator: decorator,
TypedPerFilterConfig: map[string]*any.Any{
wellknown.HTTPExternalAuthorization: filter,
HTTPExternalProcessor : filterExtProc,
HTTPExternalProcessor: filterExtProc,
},
}
return &router
Expand Down Expand Up @@ -1378,7 +1378,7 @@ func CreateAPIDefinitionEndpoint(adapterInternalAPI *model.AdapterInternalAPI, v
Decorator: decorator,
TypedPerFilterConfig: map[string]*any.Any{
wellknown.HTTPExternalAuthorization: filter,
HTTPExternalProcessor : filterExtProc,
HTTPExternalProcessor: filterExtProc,
},
}
return router
Expand Down Expand Up @@ -1443,7 +1443,7 @@ func CreateHealthEndpoint() *routev3.Route {
Decorator: decorator,
TypedPerFilterConfig: map[string]*any.Any{
wellknown.HTTPExternalAuthorization: filter,
HTTPExternalProcessor : filterExtProc,
HTTPExternalProcessor: filterExtProc,
},
}
return &router
Expand Down Expand Up @@ -1493,7 +1493,7 @@ func CreateReadyEndpoint() *routev3.Route {
Metadata: nil,
Decorator: decorator,
TypedPerFilterConfig: map[string]*any.Any{
HTTPExternalProcessor : filterExtProc,
HTTPExternalProcessor: filterExtProc,
},
}
return &router
Expand Down Expand Up @@ -1691,6 +1691,7 @@ func genRouteCreateParams(swagger *model.AdapterInternalAPI, resource *model.Res
environment: swagger.GetEnvironment(),
envType: swagger.EnvType,
mirrorClusterNames: mirrorClusterNames,
isAiAPI: swagger.AIProvider.Enabled,
}
return params
}
Expand Down
6 changes: 2 additions & 4 deletions adapter/internal/operator/controllers/dp/api_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,14 +869,12 @@ func (apiReconciler *APIReconciler) resolveAiSubscriptionRatelimitPolicies(ctx c
Name: subscription.Spec.RatelimitRef.Name,
Namespace: subscription.GetNamespace(),
}
if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err != nil {
loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription))
continue
} else {
if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err == nil {
loggers.LoggerAPKOperator.Infof("API state set as AI subscription enabled")
apiState.IsAiSubscriptionRatelimitEnabled = true
break
}
loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription))
}
}

Expand Down
14 changes: 14 additions & 0 deletions adapter/internal/operator/synchronizer/data_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ func (ods *OperatorDataStore) processAPIState(apiNamespacedName types.Namespaced
events = append(events, "Subscription based AI RatelimitPolicy")
}

if cachedAPI.AIProvider == nil && apiState.AIProvider != nil {
cachedAPI.AIProvider = apiState.AIProvider
updated = true
events = append(events, "API provider")
} else if cachedAPI.AIProvider != nil && apiState.AIProvider == nil{
cachedAPI.AIProvider = nil
updated = true
events = append(events, "API provider")
} else if cachedAPI.AIProvider.Generation != apiState.AIProvider.Generation {
cachedAPI.AIProvider = apiState.AIProvider
updated = true
events = append(events, "API provider")
}

if apiState.APIDefinition.Generation > cachedAPI.APIDefinition.Generation {
cachedAPI.APIDefinition = apiState.APIDefinition
updated = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public class ExternalProcessorService extends ExternalProcessorGrpc.ExternalProc
private static final String DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION = "subscription";
private static final String DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY = "ratelimit:organization-and-rlpolicy";
private static final String DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION = "ratelimit:subscription";
private static final String DYNAMIC_METADATA_KEY_FOR_EXTRACT_TOKEN_FROM = "aitoken:extracttokenfrom";
private static final String DYNAMIC_METADATA_KEY_FOR_PROMPT_TOKEN_ID = "aitoken:prompttokenid";
private static final String DYNAMIC_METADATA_KEY_FOR_COMPLETION_TOKEN_ID = "aitoken:completiontokenid";
private static final String DYNAMIC_METADATA_KEY_FOR_TOTAL_TOKEN_ID = "aitoken:totaltokenid";
RatelimitClient ratelimitClient = new RatelimitClient();
@Override
public StreamObserver<ProcessingRequest> process(
Expand Down Expand Up @@ -92,34 +96,47 @@ public void onNext(ProcessingRequest request) {
System.out.println("In the response flow metadata descirtor:" + filterMetadata.backendBasedAIRatelimitDescriptorValue);
if (request.hasResponseBody()) {
String body = request.getResponseBody().getBody().toStringUtf8();
// System.out.println("Body: " + body);
Usage usage = extractUsageFromBody(body, "usage.completion_tokens", "usage.prompt_tokens", "usage.total_tokens");
if (usage == null) {
logger.error("Usage details not found..");
System.out.println("Usage details not found..");
responseObserver.onCompleted();
return;
}
System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8());
List<RatelimitClient.KeyValueHitsAddend> configs = new ArrayList<>();
if (filterMetadata.enableBackendBasedAIRatelimit) {
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens()));
}
if (filterMetadata.enableSubscriptionBasedAIRatelimit) {
if (request.hasMetadataContext()) {
Struct filterMetadataFromAuthZ = request.getMetadataContext().getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null);
if (filterMetadataFromAuthZ != null) {
String orgAndAIRLPolicyValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY).getStringValue();
String aiRLSubsValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION).getStringValue();
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens())));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getCompletion_tokens())));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getTotal_tokens())));
Struct filterMetadataFromAuthZ = request.getMetadataContext().getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null);
if (filterMetadataFromAuthZ != null) {
String extractTokenFrom = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_EXTRACT_TOKEN_FROM).getStringValue();
System.out.println("Extract Token From: " + extractTokenFrom);

String promptTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_PROMPT_TOKEN_ID).getStringValue();
System.out.println("Prompt Token ID: " + promptTokenID);

String completionTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_COMPLETION_TOKEN_ID).getStringValue();
System.out.println("Completion Token ID: " + completionTokenID);

String totalTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_TOTAL_TOKEN_ID).getStringValue();
System.out.println("Total Token ID: " + totalTokenID);

Usage usage = extractUsageFromBody(body, completionTokenID, promptTokenID, totalTokenID);
if (usage == null) {
logger.error("Usage details not found..");
System.out.println("Usage details not found..");
responseObserver.onCompleted();
return;
}
System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8());
List<RatelimitClient.KeyValueHitsAddend> configs = new ArrayList<>();
if (filterMetadata.enableBackendBasedAIRatelimit) {
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens()));
}
if (filterMetadata.enableSubscriptionBasedAIRatelimit) {
if (request.hasMetadataContext()) {
if (filterMetadataFromAuthZ != null) {
String orgAndAIRLPolicyValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY).getStringValue();
String aiRLSubsValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION).getStringValue();
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens())));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getCompletion_tokens())));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getTotal_tokens())));
}
}
}
ratelimitClient.shouldRatelimit(configs);
}
ratelimitClient.shouldRatelimit(configs);
responseObserver.onCompleted();
} else {
System.out.println("Request does not have response body");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ public boolean handleRequest(RequestContext requestContext) {
boolean authenticated = false;
// Any auth token has been provided for application-level security or not
boolean canAuthenticated = false;
if (requestContext.getMatchedAPI() != null && requestContext.getMatchedAPI().getAiProvider() != null) {
requestContext.addMetadataToMap("aitoken:prompttokenid", requestContext.getMatchedAPI().getAiProvider().getPromptTokens().getValue());
requestContext.addMetadataToMap("aitoken:completiontokenid", requestContext.getMatchedAPI().getAiProvider().getCompletionToken().getValue());
requestContext.addMetadataToMap("aitoken:totaltokenid", requestContext.getMatchedAPI().getAiProvider().getTotalToken().getValue());
requestContext.addMetadataToMap("aitoken:extracttokenfrom", requestContext.getMatchedAPI().getAiProvider().getCompletionToken().getIn());
}
for (Authenticator authenticator : authenticators) {
if (authenticator.canAuthenticate(requestContext)) {
// For transport level securities (mTLS), canAuthenticated will not be applied
Expand Down

0 comments on commit b0d9944

Please sign in to comment.