Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Adds REST APIs for creating and provisioning a workflow (#63) #96

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,53 @@
package org.opensearch.flowframework;

import com.google.common.collect.ImmutableList;
import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.flowframework.indices.GlobalContextHandler;
import org.opensearch.flowframework.rest.RestCreateWorkflowAction;
import org.opensearch.flowframework.rest.RestProvisionWorkflowAction;
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.CreateWorkflowTransportAction;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction;
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import java.util.Collection;
import java.util.List;
import java.util.function.Supplier;

import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;

/**
* An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch.
*/
public class FlowFrameworkPlugin extends Plugin {
public class FlowFrameworkPlugin extends Plugin implements ActionPlugin {

/**
* Instantiate this plugin.
Expand All @@ -54,6 +79,45 @@ public Collection<Object> createComponents(
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter);
// TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep
GlobalContextHandler globalContextHandler = new GlobalContextHandler(client, new CreateIndexStep(clusterService, client));

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, globalContextHandler);
}

@Override
public List<RestHandler> getRestHandlers(
Settings settings,
RestController restController,
ClusterSettings clusterSettings,
IndexScopedSettings indexScopedSettings,
SettingsFilter settingsFilter,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster
) {
return ImmutableList.of(new RestCreateWorkflowAction(), new RestProvisionWorkflowAction());
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList.of(
new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class),
new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class)
);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
// TODO : Determine final size/queueSize values for the provision thread pool
return ImmutableList.of(
new FixedExecutorBuilder(
settings,
PROVISION_THREAD_POOL,
OpenSearchExecutors.allocatedProcessors(settings),
10,
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL
)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,51 @@
*/
public class CommonValue {

/** Default value for no schema version */
public static Integer NO_SCHEMA_VERSION = 0;
/** Index mapping meta field name*/
public static final String META = "_meta";
/** Schema Version field name */
public static final String SCHEMA_VERSION_FIELD = "schema_version";
/** Global Context Index Name */
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context";
/** Global Context index mapping file path */
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;

/** The transport action name prefix */
public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/";
/** The base URI for this plugin's rest actions */
public static final String FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework";
/** The URI for this plugin's workflow rest actions */
public static final String WORKFLOW_URI = FLOW_FRAMEWORK_BASE_URI + "/workflow";
/** Field name for workflow Id, the document Id of the indexed use case template */
public static final String WORKFLOW_ID = "workflow_id";
/** The field name for provision workflow within a use case template*/
public static final String PROVISION_WORKFLOW = "provision";

/** Flow Framework plugin thread pool name prefix */
public static final String FLOW_FRAMEWORK_THREAD_POOL_PREFIX = "thread_pool.flow_framework.";
/** The provision workflow thread pool name */
public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision";

/** Model Id field */
public static final String MODEL_ID = "model_id";
/** Function Name field */
public static final String FUNCTION_NAME = "function_name";
/** Model Name field */
public static final String MODEL_NAME = "name";
/** Model Version field */
public static final String MODEL_VERSION = "model_version";
/** Model Group Id field */
public static final String MODEL_GROUP_ID = "model_group_id";
/** Description field */
public static final String DESCRIPTION = "description";
/** Connector Id field */
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
public static final String MODEL_FORMAT = "model_format";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.common;

import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* Utility methods for Template parsing
*/
public class TemplateUtil {

/**
* Converts a JSON string into an XContentParser
*
* @param json the json string
* @return The XContent parser for the json string
* @throws IOException on failure to create the parser
*/
public static XContentParser jsonToParser(String json) throws IOException {
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
LoggingDeprecationHandler.INSTANCE,
json
);
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
return parser;
}

/**
* Builds an XContent object representing a map of String keys to String values.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param map A map as key-value String pairs.
* @throws IOException on a build failure
*/
public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?, ?> map) throws IOException {
xContentBuilder.startObject();
for (Entry<?, ?> e : map.entrySet()) {
xContentBuilder.field((String) e.getKey(), (String) e.getValue());
}
xContentBuilder.endObject();
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return A map as identified by the key-value pairs in the XContent
* @throws IOException on a parse failure
*/
public static Map<String, String> parseStringToStringMap(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
Map<String, String> map = new HashMap<>();
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
map.put(fieldName, parser.text());
}
return map;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class FlowFrameworkException extends RuntimeException {

private static final long serialVersionUID = 1L;

/** The rest status code of this exception */
private final RestStatus restStatus;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
* An enumeration of Flow Framework indices
*/
public enum FlowFrameworkIndex {
/**
* Global Context Index
*/
GLOBAL_CONTEXT(
GLOBAL_CONTEXT_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(GlobalContextHandler::getGlobalContextMappings),
Expand All @@ -35,14 +38,26 @@ public enum FlowFrameworkIndex {
this.version = version;
}

/**
* Retrieves the index name
* @return the index name
*/
public String getIndexName() {
return indexName;
}

/**
* Retrieves the index mapping
* @return the index mapping
*/
public String getMapping() {
return mapping;
}

/**
* Retrieves the index version
* @return the index version
*/
public Integer getVersion() {
return version;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand All @@ -94,6 +94,35 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
}));
}

/**
* Replaces a document in the global context index
* @param documentId the document Id
* @param template the use-case template
* @param listener action listener
*/
public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener<IndexResponse> listener) {
if (!createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX)) {
String exceptionMessage = "Failed to update template for workflow_id : "
+ documentId
+ ", global_context index does not exist.";
logger.error(exceptionMessage);
listener.onFailure(new Exception(exceptionMessage));
} else {
IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId);
try (
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage());
listener.onFailure(e);
}
}
}

/**
* Update global context index for specific fields
* @param documentId global context index document id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.TemplateUtil.buildStringToStringMap;
import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap;

/**
* This represents a processor associated with search and ingest pipelines in the {@link Template}.
Expand Down Expand Up @@ -46,7 +48,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
XContentBuilder xContentBuilder = builder.startObject();
xContentBuilder.field(TYPE_FIELD, this.type);
xContentBuilder.field(PARAMS_FIELD);
Template.buildStringToStringMap(xContentBuilder, this.params);
buildStringToStringMap(xContentBuilder, this.params);
return xContentBuilder.endObject();
}

Expand All @@ -70,7 +72,7 @@ public static PipelineProcessor parse(XContentParser parser) throws IOException
type = parser.text();
break;
case PARAMS_FIELD:
params = Template.parseStringToStringMap(parser);
params = parseStringToStringMap(parser);
break;
default:
throw new IOException("Unable to parse field [" + fieldName + "] in a pipeline processor object.");
Expand Down
Loading
Loading