diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 5c64499c6f..3fbc16d15f 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -90,12 +90,13 @@ User could reuse the session by using ``sessionId`` query parameters. Sample Request:: - curl --location 'http://localhost:9200/_plugins/_async_query?sessionId=1Giy65ZnzNlmsPAm' \ + curl --location 'http://localhost:9200/_plugins/_async_query' \ --header 'Content-Type: application/json' \ --data '{ "datasource" : "my_glue", "lang" : "sql", - "query" : "select * from my_glue.default.http_logs limit 10" + "query" : "select * from my_glue.default.http_logs limit 10", + "sessionId" : "1Giy65ZnzNlmsPAm" }' Sample Response:: diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 1683ff7483..741501cd18 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -43,8 +43,6 @@ public class RestAsyncQueryManagementAction extends BaseRestHandler { public static final String ASYNC_QUERY_ACTIONS = "async_query_actions"; public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query"; - public static final String PARAMS_SESSION_ID = "sessionId"; - private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class); @Override @@ -114,7 +112,6 @@ private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClie throws IOException { CreateAsyncQueryRequest submitJobRequest = CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); - submitJobRequest.setSessionId(restRequest.param(PARAMS_SESSION_ID, null)); return restChannel -> Scheduler.schedule( nodeClient, diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index ba2a0cf2ee..6acf6bc9a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.rest.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_ID; import java.io.IOException; import lombok.Data; @@ -27,11 +28,19 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.lang = Validate.notNull(lang, "lang can't be null"); } + public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) { + this.query = Validate.notNull(query, "Query can't be null"); + this.datasource = Validate.notNull(datasource, "Datasource can't be null"); + this.lang = Validate.notNull(lang, "lang can't be null"); + this.sessionId = sessionId; + } + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; String datasource = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -43,10 +52,12 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) lang = LangType.fromString(langString); } else if (fieldName.equals("datasource")) { datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateAsyncQueryRequest(query, datasource, lang); + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java new file mode 100644 index 0000000000..dd634d6055 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import java.io.IOException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +public class CreateAsyncQueryRequestTest { + + @Test + public void fromXContent() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("my_glue", queryRequest.getDatasource()); + Assertions.assertEquals(LangType.SQL, queryRequest.getLang()); + Assertions.assertEquals("select 1", queryRequest.getQuery()); + } + + @Test + public void fromXContentWithSessionId() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\",\n" + + " \"sessionId\": \"00fdjevgkf12s00q\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId()); + } + + private XContentParser xContentParser(String request) throws IOException { + return XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, request); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 2e48597593..36060d3850 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -74,7 +74,8 @@ public void testDoExecute() { @Test public void testDoExecuteWithSessionId() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "my_glue", LangType.SQL); + new CreateAsyncQueryRequest( + "source = my_glue.default.alb_logs", "my_glue", LangType.SQL, MOCK_SESSION_ID); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest))