diff --git a/src/main/java/org/aksw/iguana/cc/query/QueryData.java b/src/main/java/org/aksw/iguana/cc/query/QueryData.java new file mode 100644 index 00000000..02858cb8 --- /dev/null +++ b/src/main/java/org/aksw/iguana/cc/query/QueryData.java @@ -0,0 +1,32 @@ +package org.aksw.iguana.cc.query; + +import org.apache.jena.update.UpdateFactory; + +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * This class stores extra information about a query. + * At the moment, it only stores if the query is an update query or not. + * + * @param queryId The id of the query + * @param update If the query is an update query + */ +public record QueryData(int queryId, boolean update) { + public static List generate(Collection queries) { + final var queryData = new ArrayList(); + int i = 0; + for (InputStream query : queries) { + boolean update = true; + try { + UpdateFactory.read(query); // Throws an exception if the query is not an update query + } catch (Exception e) { + update = false; + } + queryData.add(new QueryData(i++, update)); + } + return queryData; + } +} diff --git a/src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java b/src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java index 6930d3f1..be30268f 100644 --- a/src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java +++ b/src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import org.aksw.iguana.cc.query.QueryData; import org.aksw.iguana.cc.query.list.impl.StringListQueryList; import org.aksw.iguana.cc.query.selector.QuerySelector; import org.aksw.iguana.cc.query.selector.impl.LinearQuerySelector; @@ -145,8 +146,9 @@ public Template(URI endpoint, Long limit, Boolean save) { } } - public record QueryStringWrapper(int index, String query) {} - public record QueryStreamWrapper(int index, boolean cached, Supplier queryInputStreamSupplier) {} + public record QueryStringWrapper(int index, String query, boolean update) {} + + public record QueryStreamWrapper(int index, boolean cached, Supplier queryInputStreamSupplier, boolean update) {} protected static final Logger LOGGER = LoggerFactory.getLogger(QueryHandler.class); @@ -155,6 +157,7 @@ public record QueryStreamWrapper(int index, boolean cached, Supplier queryData; private int workerCount = 0; // give every worker inside the same worker config an offset seed @@ -168,6 +171,7 @@ protected QueryHandler() { config = null; queryList = null; hashCode = 0; + queryData = null; } @JsonCreator @@ -184,6 +188,13 @@ public QueryHandler(Config config) throws IOException { new FileReadingQueryList(querySource); } this.hashCode = queryList.hashCode(); + this.queryData = QueryData.generate(IntStream.range(0, queryList.size()).mapToObj(i -> { + try { + return queryList.getQueryStream(i); + } catch (IOException e) { + throw new RuntimeException("Couldn't read query stream", e); + } + }).collect(Collectors.toList())); } private QueryList initializeTemplateQueryHandler(QuerySource templateSource) throws IOException { @@ -247,7 +258,7 @@ public QuerySelector getQuerySelectorInstance() { public QueryStringWrapper getNextQuery(QuerySelector querySelector) throws IOException { final var queryIndex = querySelector.getNextIndex(); - return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex)); + return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex), queryData.get(queryIndex).update()); } public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) { @@ -258,7 +269,7 @@ public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) { } catch (IOException e) { throw new RuntimeException(e); } - }); + }, queryData.get(queryIndex).update()); } @Override diff --git a/src/main/java/org/aksw/iguana/cc/query/list/QueryList.java b/src/main/java/org/aksw/iguana/cc/query/list/QueryList.java index 623a8c67..3f9f2a78 100644 --- a/src/main/java/org/aksw/iguana/cc/query/list/QueryList.java +++ b/src/main/java/org/aksw/iguana/cc/query/list/QueryList.java @@ -1,7 +1,5 @@ package org.aksw.iguana.cc.query.list; -import org.aksw.iguana.cc.query.source.QuerySource; - import java.io.IOException; import java.io.InputStream; diff --git a/src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java b/src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java index e29fc533..e0853166 100644 --- a/src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java +++ b/src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.io.InputStream; +import java.net.URI; import java.net.URISyntaxException; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; @@ -113,9 +114,18 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que throw new IOException(e); } + // check if the query is an update query, if yes, change the request type to similar update request type + RequestType actualRequestType = requestType; + if (requestType == RequestType.GET_QUERY || requestType == RequestType.POST_QUERY) + actualRequestType = queryHandle.update() ? RequestType.POST_UPDATE : requestType; + if (requestType == RequestType.POST_URL_ENC_QUERY) + actualRequestType = queryHandle.update() ? RequestType.POST_URL_ENC_UPDATE : requestType; + // if only one endpoint is set, use it for both queries and updates + URI updateEndpoint = connectionConfig.updateEndpoint() != null ? connectionConfig.updateEndpoint() : connectionConfig.endpoint(); + // If the query is bigger than 2^31 bytes (2GB) and the request type is set to GET_QUERY, POST_URL_ENC_QUERY or // POST_URL_ENC_UPDATE, the following code will throw an exception. - switch (requestType) { + switch (actualRequestType) { case GET_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.get(new URIBuilder(connectionConfig.endpoint()) .addParameter("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)) .build() @@ -127,10 +137,10 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que .setEntity(new BasicAsyncEntityProducer(urlEncode("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false)); case POST_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint()) .setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-query")); - case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint()) + case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint) .setHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded") .setEntity(new BasicAsyncEntityProducer(urlEncode("update", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false)); - case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint()) + case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint) .setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-update")); default -> throw new IllegalStateException("Unexpected value: " + requestType); } @@ -138,10 +148,15 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que // set additional headers if (acceptHeader != null) asyncRequestBuilder.addHeader("Accept", acceptHeader); - if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null) + if (queryHandle.update() && connectionConfig.updateAuthentication() != null && connectionConfig.updateAuthentication().user() != null) { asyncRequestBuilder.addHeader("Authorization", - HttpWorker.basicAuth(connectionConfig.authentication().user(), - Optional.ofNullable(connectionConfig.authentication().password()).orElse(""))); + HttpWorker.basicAuth(connectionConfig.updateAuthentication().user(), + Optional.ofNullable(connectionConfig.updateAuthentication().password()).orElse(""))); + } else if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null) { + asyncRequestBuilder.addHeader("Authorization", + HttpWorker.basicAuth(connectionConfig.authentication().user(), + Optional.ofNullable(connectionConfig.authentication().password()).orElse(""))); + } // cache request if (caching) diff --git a/src/test/java/org/aksw/iguana/cc/query/QueryDataTest.java b/src/test/java/org/aksw/iguana/cc/query/QueryDataTest.java new file mode 100644 index 00000000..63b56d24 --- /dev/null +++ b/src/test/java/org/aksw/iguana/cc/query/QueryDataTest.java @@ -0,0 +1,72 @@ +package org.aksw.iguana.cc.query; + +import org.aksw.iguana.cc.query.source.QuerySource; +import org.aksw.iguana.cc.query.source.impl.FileSeparatorQuerySource; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class QueryDataTest { + + private static Path tempFile = null; + + @BeforeAll + public static void setup() throws IOException { + tempFile = Files.createTempFile("test", "txt"); + Files.writeString(tempFile, """ + SELECT ?s ?p ?o WHERE { + ?s ?p ?o + } + + INSERT DATA { + + } + + DELETE DATA { + + } + + SELECT ?s ?p ?o WHERE { + ?s ?p ?o + } + + INSERT DATA { + + }; INSERT DATA { + + } + """); + } + + @AfterAll + public static void teardown() throws IOException { + Files.deleteIfExists(tempFile); + } + + @Test + void testGeneration() throws IOException { + final QuerySource querySource = new FileSeparatorQuerySource(tempFile, ""); + final var testStrings = querySource.getAllQueries(); + + List> generations = List.of( + QueryData.generate(testStrings.stream().map(s -> (InputStream) new ByteArrayInputStream(s.getBytes())).toList()) + ); + for (List generation : generations) { + assertEquals(5, generation.size()); + assertFalse(generation.get(0).update()); + assertTrue(generation.get(1).update()); + assertTrue(generation.get(2).update()); + assertFalse(generation.get(3).update()); + assertTrue(generation.get(4).update()); + } + } +} \ No newline at end of file diff --git a/src/test/java/org/aksw/iguana/cc/worker/impl/SPARQLProtocolWorkerTest.java b/src/test/java/org/aksw/iguana/cc/worker/impl/SPARQLProtocolWorkerTest.java index b7d4daf7..6d9842fa 100644 --- a/src/test/java/org/aksw/iguana/cc/worker/impl/SPARQLProtocolWorkerTest.java +++ b/src/test/java/org/aksw/iguana/cc/worker/impl/SPARQLProtocolWorkerTest.java @@ -57,15 +57,19 @@ public class SPARQLProtocolWorkerTest { .build(); private final static String QUERY = "SELECT * WHERE { ?s ?p ?o }"; + private final static String UPDATE_QUERY = "INSERT DATA { }"; private final static int QUERY_MIXES = 1; private static Path queryFile; + private static Path updateFile; private static final Logger LOGGER = LoggerFactory.getLogger(SPARQLProtocolWorker.class); @BeforeAll public static void setup() throws IOException { queryFile = Files.createTempFile("iguana-test-queries", ".tmp"); + updateFile = Files.createTempFile("iguana-test-updates", ".tmp"); Files.writeString(queryFile, QUERY, StandardCharsets.UTF_8); + Files.writeString(updateFile, QUERY + "\n\n" + UPDATE_QUERY, StandardCharsets.UTF_8); } @BeforeEach @@ -77,6 +81,7 @@ public void reset() { @AfterAll public static void cleanup() throws IOException { Files.deleteIfExists(queryFile); + Files.deleteIfExists(updateFile); SPARQLProtocolWorker.closeHttpClient(); } @@ -120,6 +125,31 @@ public static Stream requestFactoryData() throws URISyntaxException { return workers.stream(); } + public static Stream updateWorkerData() throws IOException { + final var normalEndpoint = URI.create("http://localhost:" + wm.getPort() + "/ds/query"); + final var updateEndpoint = URI.create("http://localhost:" + wm.getPort() + "/ds/update"); + final var processor = new ResponseBodyProcessor("application/sparql-results+json"); + final var format = QueryHandler.Config.Format.SEPARATOR; + final var queryHandler = new QueryHandler(new QueryHandler.Config(updateFile.toAbsolutePath().toString(), format, null, true, QueryHandler.Config.Order.LINEAR, 0L, QueryHandler.Config.Language.SPARQL)); + final var datasetConfig = new DatasetConfig("TestDS", null); + final var connection = new ConnectionConfig("TestConn", "1", datasetConfig, normalEndpoint, new ConnectionConfig.Authentication("testUser", "password"), updateEndpoint, new ConnectionConfig.Authentication("updateUser", "password")); + final var workers = new ArrayDeque(); + for (var requestType : List.of(RequestFactory.RequestType.GET_QUERY, RequestFactory.RequestType.POST_URL_ENC_QUERY, RequestFactory.RequestType.POST_QUERY)) { + final var config = new SPARQLProtocolWorker.Config( + 1, + queryHandler, + new HttpWorker.QueryMixes(QUERY_MIXES), + connection, + Duration.parse("PT6S"), + "application/sparql-results+json", + requestType, + true + ); + workers.add(Arguments.of(Named.of(requestType.name(), new SPARQLProtocolWorker(0, processor, config)))); + } + return workers.stream(); + } + public static List completionTargets() { final var out = new ArrayList(); final var queryMixesAmount = List.of(1, 2, 5, 10, 100, 200); @@ -204,10 +234,63 @@ public void testRequestFactory(SPARQLProtocolWorker worker, boolean cached) { assertNotEquals(Duration.ZERO, result.executionStats().get(0).duration(), "Worker returned zero duration"); } + @ParameterizedTest + @MethodSource("updateWorkerData") + public void testSeparateUpdateEndpoint(SPARQLProtocolWorker worker) { + final var workerConfig = worker.config(); + switch (workerConfig.requestType()) { + case GET_QUERY -> { + wm.stubFor(get(urlPathEqualTo("/ds/query")) + .withQueryParam("query", equalTo(QUERY)) + .withBasicAuth("testUser", "password") + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + wm.stubFor(post(urlPathEqualTo("/ds/update")) + .withHeader("Content-Type", equalTo("application/sparql-update")) + .withBasicAuth("updateUser", "password") + .withRequestBody(equalTo(UPDATE_QUERY)) + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + } + case POST_URL_ENC_QUERY -> { + wm.stubFor(post(urlPathEqualTo("/ds/query")) + .withHeader("Content-Type", equalTo("application/x-www-form-urlencoded")) + .withBasicAuth("testUser", "password") + .withRequestBody(equalTo("query=" + URLEncoder.encode(QUERY, StandardCharsets.UTF_8))) + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + wm.stubFor(post(urlPathEqualTo("/ds/update")) + .withHeader("Content-Type", equalTo("application/x-www-form-urlencoded")) + .withBasicAuth("updateUser", "password") + .withRequestBody(equalTo("update=" + URLEncoder.encode(UPDATE_QUERY, StandardCharsets.UTF_8))) + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + } + case POST_QUERY -> { + wm.stubFor(post(urlPathEqualTo("/ds/query")) + .withHeader("Content-Type", equalTo("application/sparql-query")) + .withBasicAuth("testUser", "password") + .withRequestBody(equalTo(QUERY)) + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + wm.stubFor(post(urlPathEqualTo("/ds/update")) + .withHeader("Content-Type", equalTo("application/sparql-update")) + .withBasicAuth("updateUser", "password") + .withRequestBody(equalTo(UPDATE_QUERY)) + .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body"))); + } + } + final HttpWorker.Result result = worker.start().join(); + assertEquals(result.executionStats().size(), QUERY_MIXES * 2, "Worker should have executed only 1 query"); + for (var res : result.executionStats()) { + assertNull(res.error().orElse(null), "Worker threw an exception, during execution"); + assertEquals(200, res.httpStatusCode().get(), "Worker returned wrong status code"); + assertNotEquals(0, res.responseBodyHash().getAsLong(), "Worker didn't return a response body hash"); + assertEquals("Non-Empty-Body".getBytes(StandardCharsets.UTF_8).length, res.contentLength().getAsLong(), "Worker returned wrong content length"); + assertNotEquals(Duration.ZERO, res.duration(), "Worker returned zero duration"); + } + + } + @DisplayName("Test Malformed Response Processing") @ParameterizedTest(name = "[{index}] fault = {0}") @EnumSource(Fault.class) - public void testMalformedResponseProcessing(Fault fault) throws IOException, URISyntaxException { + public void testMalformedResponseProcessing(Fault fault) throws URISyntaxException { SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named)requestFactoryData().toList().get(0).get()[0]).getPayload(); wm.stubFor(get(urlPathEqualTo("/ds/query")) .willReturn(aResponse().withFault(fault))); @@ -217,7 +300,7 @@ public void testMalformedResponseProcessing(Fault fault) throws IOException, URI } @Test - public void testBadHttpCodeResponse() throws IOException, URISyntaxException { + public void testBadHttpCodeResponse() throws URISyntaxException { SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named)requestFactoryData().toList().get(0).get()[0]).getPayload(); wm.stubFor(get(urlPathEqualTo("/ds/query")) .willReturn(aResponse().withStatus(404)));