Skip to content

Commit

Permalink
Send Update Queries to seperate endpoint (#281)
Browse files Browse the repository at this point in the history
* Add QueryData class

* Check for update queries

* Move responsibility of QueryData to QueryHandler

* Remove unused methods

* Add tests

* Fix authentication

* Cleanup

* Remove unused import statements

* Add chained update request as test case for Update queries distinction

Closes #229
  • Loading branch information
nck-mlcnv authored Oct 11, 2024
1 parent c01273b commit 38cae67
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 14 deletions.
32 changes: 32 additions & 0 deletions src/main/java/org/aksw/iguana/cc/query/QueryData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.aksw.iguana.cc.query;

import org.apache.jena.update.UpdateFactory;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
* This class stores extra information about a query.
* At the moment, it only stores if the query is an update query or not.
*
* @param queryId The id of the query
* @param update If the query is an update query
*/
public record QueryData(int queryId, boolean update) {
public static List<QueryData> generate(Collection<InputStream> queries) {
final var queryData = new ArrayList<QueryData>();
int i = 0;
for (InputStream query : queries) {
boolean update = true;
try {
UpdateFactory.read(query); // Throws an exception if the query is not an update query
} catch (Exception e) {
update = false;
}
queryData.add(new QueryData(i++, update));
}
return queryData;
}
}
19 changes: 15 additions & 4 deletions src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import org.aksw.iguana.cc.query.QueryData;
import org.aksw.iguana.cc.query.list.impl.StringListQueryList;
import org.aksw.iguana.cc.query.selector.QuerySelector;
import org.aksw.iguana.cc.query.selector.impl.LinearQuerySelector;
Expand Down Expand Up @@ -149,8 +150,9 @@ public Template(URI endpoint, Long limit, Boolean save) {
}
}

public record QueryStringWrapper(int index, String query) {}
public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream> queryInputStreamSupplier) {}
public record QueryStringWrapper(int index, String query, boolean update) {}

public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream> queryInputStreamSupplier, boolean update) {}


protected static final Logger LOGGER = LoggerFactory.getLogger(QueryHandler.class);
Expand All @@ -159,6 +161,7 @@ public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream
final protected Config config;

final protected QueryList queryList;
final protected List<QueryData> queryData;

private int workerCount = 0; // give every worker inside the same worker config an offset seed

Expand All @@ -172,6 +175,7 @@ protected QueryHandler() {
config = null;
queryList = null;
hashCode = 0;
queryData = null;
}

@JsonCreator
Expand All @@ -188,6 +192,13 @@ public QueryHandler(Config config) throws IOException {
new FileReadingQueryList(querySource);
}
this.hashCode = queryList.hashCode();
this.queryData = QueryData.generate(IntStream.range(0, queryList.size()).mapToObj(i -> {
try {
return queryList.getQueryStream(i);
} catch (IOException e) {
throw new RuntimeException("Couldn't read query stream", e);
}
}).collect(Collectors.toList()));
}

private QueryList initializeTemplateQueryHandler(QuerySource templateSource) throws IOException {
Expand Down Expand Up @@ -251,7 +262,7 @@ public QuerySelector getQuerySelectorInstance() {

public QueryStringWrapper getNextQuery(QuerySelector querySelector) throws IOException {
final var queryIndex = querySelector.getNextIndex();
return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex));
return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex), queryData.get(queryIndex).update());
}

public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) {
Expand All @@ -262,7 +273,7 @@ public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) {
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}, queryData.get(queryIndex).update());
}

@Override
Expand Down
2 changes: 0 additions & 2 deletions src/main/java/org/aksw/iguana/cc/query/list/QueryList.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.aksw.iguana.cc.query.list;

import org.aksw.iguana.cc.query.source.QuerySource;

import java.io.IOException;
import java.io.InputStream;

Expand Down
27 changes: 21 additions & 6 deletions src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -113,9 +114,18 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que
throw new IOException(e);
}

// check if the query is an update query, if yes, change the request type to similar update request type
RequestType actualRequestType = requestType;
if (requestType == RequestType.GET_QUERY || requestType == RequestType.POST_QUERY)
actualRequestType = queryHandle.update() ? RequestType.POST_UPDATE : requestType;
if (requestType == RequestType.POST_URL_ENC_QUERY)
actualRequestType = queryHandle.update() ? RequestType.POST_URL_ENC_UPDATE : requestType;
// if only one endpoint is set, use it for both queries and updates
URI updateEndpoint = connectionConfig.updateEndpoint() != null ? connectionConfig.updateEndpoint() : connectionConfig.endpoint();

// If the query is bigger than 2^31 bytes (2GB) and the request type is set to GET_QUERY, POST_URL_ENC_QUERY or
// POST_URL_ENC_UPDATE, the following code will throw an exception.
switch (requestType) {
switch (actualRequestType) {
case GET_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.get(new URIBuilder(connectionConfig.endpoint())
.addParameter("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8))
.build()
Expand All @@ -127,21 +137,26 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que
.setEntity(new BasicAsyncEntityProducer(urlEncode("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false));
case POST_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
.setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-query"));
case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint)
.setHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded")
.setEntity(new BasicAsyncEntityProducer(urlEncode("update", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false));
case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint)
.setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-update"));
default -> throw new IllegalStateException("Unexpected value: " + requestType);
}

// set additional headers
if (acceptHeader != null)
asyncRequestBuilder.addHeader("Accept", acceptHeader);
if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null)
if (queryHandle.update() && connectionConfig.updateAuthentication() != null && connectionConfig.updateAuthentication().user() != null) {
asyncRequestBuilder.addHeader("Authorization",
HttpWorker.basicAuth(connectionConfig.authentication().user(),
Optional.ofNullable(connectionConfig.authentication().password()).orElse("")));
HttpWorker.basicAuth(connectionConfig.updateAuthentication().user(),
Optional.ofNullable(connectionConfig.updateAuthentication().password()).orElse("")));
} else if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null) {
asyncRequestBuilder.addHeader("Authorization",
HttpWorker.basicAuth(connectionConfig.authentication().user(),
Optional.ofNullable(connectionConfig.authentication().password()).orElse("")));
}

// cache request
if (caching)
Expand Down
72 changes: 72 additions & 0 deletions src/test/java/org/aksw/iguana/cc/query/QueryDataTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package org.aksw.iguana.cc.query;

import org.aksw.iguana.cc.query.source.QuerySource;
import org.aksw.iguana.cc.query.source.impl.FileSeparatorQuerySource;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

class QueryDataTest {

private static Path tempFile = null;

@BeforeAll
public static void setup() throws IOException {
tempFile = Files.createTempFile("test", "txt");
Files.writeString(tempFile, """
SELECT ?s ?p ?o WHERE {
?s ?p ?o
}
INSERT DATA {
<http://example.org/s> <http://example.org/p> <http://example.org/o>
}
DELETE DATA {
<http://example.org/s> <http://example.org/p> <http://example.org/o>
}
SELECT ?s ?p ?o WHERE {
?s ?p ?o
}
INSERT DATA {
<http://example.org/s> <http://example.org/p> <http://example.org/o>
}; INSERT DATA {
<http://example.org/s> <http://example.org/p> <http://example.org/o>
}
""");
}

@AfterAll
public static void teardown() throws IOException {
Files.deleteIfExists(tempFile);
}

@Test
void testGeneration() throws IOException {
final QuerySource querySource = new FileSeparatorQuerySource(tempFile, "");
final var testStrings = querySource.getAllQueries();

List<List<QueryData>> generations = List.of(
QueryData.generate(testStrings.stream().map(s -> (InputStream) new ByteArrayInputStream(s.getBytes())).toList())
);
for (List<QueryData> generation : generations) {
assertEquals(5, generation.size());
assertFalse(generation.get(0).update());
assertTrue(generation.get(1).update());
assertTrue(generation.get(2).update());
assertFalse(generation.get(3).update());
assertTrue(generation.get(4).update());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,19 @@ public class SPARQLProtocolWorkerTest {
.build();

private final static String QUERY = "SELECT * WHERE { ?s ?p ?o }";
private final static String UPDATE_QUERY = "INSERT DATA { <http://example.org/s> <http://example.org/p> <http://example.org/o> }";
private final static int QUERY_MIXES = 1;
private static Path queryFile;
private static Path updateFile;

private static final Logger LOGGER = LoggerFactory.getLogger(SPARQLProtocolWorker.class);

@BeforeAll
public static void setup() throws IOException {
queryFile = Files.createTempFile("iguana-test-queries", ".tmp");
updateFile = Files.createTempFile("iguana-test-updates", ".tmp");
Files.writeString(queryFile, QUERY, StandardCharsets.UTF_8);
Files.writeString(updateFile, QUERY + "\n\n" + UPDATE_QUERY, StandardCharsets.UTF_8);
}

@BeforeEach
Expand All @@ -77,6 +81,7 @@ public void reset() {
@AfterAll
public static void cleanup() throws IOException {
Files.deleteIfExists(queryFile);
Files.deleteIfExists(updateFile);
SPARQLProtocolWorker.closeHttpClient();
}

Expand Down Expand Up @@ -120,6 +125,31 @@ public static Stream<Arguments> requestFactoryData() throws URISyntaxException {
return workers.stream();
}

public static Stream<Arguments> updateWorkerData() throws IOException {
final var normalEndpoint = URI.create("http://localhost:" + wm.getPort() + "/ds/query");
final var updateEndpoint = URI.create("http://localhost:" + wm.getPort() + "/ds/update");
final var processor = new ResponseBodyProcessor("application/sparql-results+json");
final var format = QueryHandler.Config.Format.SEPARATOR;
final var queryHandler = new QueryHandler(new QueryHandler.Config(updateFile.toAbsolutePath().toString(), format, null, true, QueryHandler.Config.Order.LINEAR, 0L, QueryHandler.Config.Language.SPARQL));
final var datasetConfig = new DatasetConfig("TestDS", null);
final var connection = new ConnectionConfig("TestConn", "1", datasetConfig, normalEndpoint, new ConnectionConfig.Authentication("testUser", "password"), updateEndpoint, new ConnectionConfig.Authentication("updateUser", "password"));
final var workers = new ArrayDeque<Arguments>();
for (var requestType : List.of(RequestFactory.RequestType.GET_QUERY, RequestFactory.RequestType.POST_URL_ENC_QUERY, RequestFactory.RequestType.POST_QUERY)) {
final var config = new SPARQLProtocolWorker.Config(
1,
queryHandler,
new HttpWorker.QueryMixes(QUERY_MIXES),
connection,
Duration.parse("PT6S"),
"application/sparql-results+json",
requestType,
true
);
workers.add(Arguments.of(Named.of(requestType.name(), new SPARQLProtocolWorker(0, processor, config))));
}
return workers.stream();
}

public static List<Arguments> completionTargets() {
final var out = new ArrayList<Arguments>();
final var queryMixesAmount = List.of(1, 2, 5, 10, 100, 200);
Expand Down Expand Up @@ -204,10 +234,63 @@ public void testRequestFactory(SPARQLProtocolWorker worker, boolean cached) {
assertNotEquals(Duration.ZERO, result.executionStats().get(0).duration(), "Worker returned zero duration");
}

@ParameterizedTest
@MethodSource("updateWorkerData")
public void testSeparateUpdateEndpoint(SPARQLProtocolWorker worker) {
final var workerConfig = worker.config();
switch (workerConfig.requestType()) {
case GET_QUERY -> {
wm.stubFor(get(urlPathEqualTo("/ds/query"))
.withQueryParam("query", equalTo(QUERY))
.withBasicAuth("testUser", "password")
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
wm.stubFor(post(urlPathEqualTo("/ds/update"))
.withHeader("Content-Type", equalTo("application/sparql-update"))
.withBasicAuth("updateUser", "password")
.withRequestBody(equalTo(UPDATE_QUERY))
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
}
case POST_URL_ENC_QUERY -> {
wm.stubFor(post(urlPathEqualTo("/ds/query"))
.withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
.withBasicAuth("testUser", "password")
.withRequestBody(equalTo("query=" + URLEncoder.encode(QUERY, StandardCharsets.UTF_8)))
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
wm.stubFor(post(urlPathEqualTo("/ds/update"))
.withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
.withBasicAuth("updateUser", "password")
.withRequestBody(equalTo("update=" + URLEncoder.encode(UPDATE_QUERY, StandardCharsets.UTF_8)))
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
}
case POST_QUERY -> {
wm.stubFor(post(urlPathEqualTo("/ds/query"))
.withHeader("Content-Type", equalTo("application/sparql-query"))
.withBasicAuth("testUser", "password")
.withRequestBody(equalTo(QUERY))
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
wm.stubFor(post(urlPathEqualTo("/ds/update"))
.withHeader("Content-Type", equalTo("application/sparql-update"))
.withBasicAuth("updateUser", "password")
.withRequestBody(equalTo(UPDATE_QUERY))
.willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
}
}
final HttpWorker.Result result = worker.start().join();
assertEquals(result.executionStats().size(), QUERY_MIXES * 2, "Worker should have executed only 1 query");
for (var res : result.executionStats()) {
assertNull(res.error().orElse(null), "Worker threw an exception, during execution");
assertEquals(200, res.httpStatusCode().get(), "Worker returned wrong status code");
assertNotEquals(0, res.responseBodyHash().getAsLong(), "Worker didn't return a response body hash");
assertEquals("Non-Empty-Body".getBytes(StandardCharsets.UTF_8).length, res.contentLength().getAsLong(), "Worker returned wrong content length");
assertNotEquals(Duration.ZERO, res.duration(), "Worker returned zero duration");
}

}

@DisplayName("Test Malformed Response Processing")
@ParameterizedTest(name = "[{index}] fault = {0}")
@EnumSource(Fault.class)
public void testMalformedResponseProcessing(Fault fault) throws IOException, URISyntaxException {
public void testMalformedResponseProcessing(Fault fault) throws URISyntaxException {
SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named<?>)requestFactoryData().toList().get(0).get()[0]).getPayload();
wm.stubFor(get(urlPathEqualTo("/ds/query"))
.willReturn(aResponse().withFault(fault)));
Expand All @@ -217,7 +300,7 @@ public void testMalformedResponseProcessing(Fault fault) throws IOException, URI
}

@Test
public void testBadHttpCodeResponse() throws IOException, URISyntaxException {
public void testBadHttpCodeResponse() throws URISyntaxException {
SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named<?>)requestFactoryData().toList().get(0).get()[0]).getPayload();
wm.stubFor(get(urlPathEqualTo("/ds/query"))
.willReturn(aResponse().withStatus(404)));
Expand Down

0 comments on commit 38cae67

Please sign in to comment.