Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send Update Queries to separate endpoint #281

Merged
merged 10 commits into from
Oct 11, 2024
Merged
32 changes: 32 additions & 0 deletions src/main/java/org/aksw/iguana/cc/query/QueryData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.aksw.iguana.cc.query;

import org.apache.jena.update.UpdateFactory;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * This class stores extra information about a query.
 * At the moment, it only stores if the query is an update query or not.
 *
 * @param queryId The id of the query; {@link #generate(Collection)} assigns ids in iteration order, starting at 0
 * @param update  If the query is an update query
 */
public record QueryData(int queryId, boolean update) {

    /**
     * Classifies every given query as update or non-update query.
     * <p>
     * A query counts as an update query exactly when {@code UpdateFactory.read} accepts it without throwing.
     * Every stream is consumed and closed by this method, so callers must not reuse the streams afterwards.
     *
     * @param queries the query streams to classify
     * @return one {@code QueryData} per stream, in the iteration order of {@code queries}
     */
    public static List<QueryData> generate(Collection<InputStream> queries) {
        final var queryData = new ArrayList<QueryData>(queries.size());
        int id = 0;
        for (InputStream query : queries) {
            queryData.add(new QueryData(id++, isUpdate(query)));
        }
        return queryData;
    }

    /**
     * Tries to parse the stream as an update request and closes the stream afterwards.
     *
     * @param query the stream containing a single query
     * @return {@code true} if parsing succeeded, {@code false} if any exception was thrown while parsing
     */
    private static boolean isUpdate(InputStream query) {
        boolean update;
        try {
            UpdateFactory.read(query); // Throws an exception if the query is not an update query
            update = true;
        } catch (Exception e) {
            // Deliberately broad: whatever goes wrong while parsing, the query is treated as a non-update query.
            update = false;
        } finally {
            if (query != null) {
                try {
                    query.close();
                } catch (java.io.IOException ignored) {
                    // The classification is already decided; a failing close must not change it.
                }
            }
        }
        return update;
    }
}
19 changes: 15 additions & 4 deletions src/main/java/org/aksw/iguana/cc/query/handler/QueryHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import org.aksw.iguana.cc.query.QueryData;
import org.aksw.iguana.cc.query.list.impl.StringListQueryList;
import org.aksw.iguana.cc.query.selector.QuerySelector;
import org.aksw.iguana.cc.query.selector.impl.LinearQuerySelector;
Expand Down Expand Up @@ -145,8 +146,9 @@ public Template(URI endpoint, Long limit, Boolean save) {
}
}

public record QueryStringWrapper(int index, String query) {}
public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream> queryInputStreamSupplier) {}
/**
 * A query handed out as a string.
 *
 * @param index  the index of the query, as returned by the query selector
 * @param query  the query string
 * @param update whether the query is flagged as an update query in the matching {@code QueryData} entry
 */
public record QueryStringWrapper(int index, String query, boolean update) {}

/**
 * A query handed out as a stream supplier.
 *
 * @param index                    the index of the query, as returned by the query selector
 * @param cached                   NOTE(review): presumably whether the query list keeps its queries cached — confirm against the call site
 * @param queryInputStreamSupplier supplies an input stream over the query
 * @param update                   whether the query is flagged as an update query in the matching {@code QueryData} entry
 */
public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream> queryInputStreamSupplier, boolean update) {}


protected static final Logger LOGGER = LoggerFactory.getLogger(QueryHandler.class);
Expand All @@ -155,6 +157,7 @@ public record QueryStreamWrapper(int index, boolean cached, Supplier<InputStream
final protected Config config;

final protected QueryList queryList;
final protected List<QueryData> queryData;

private int workerCount = 0; // give every worker inside the same worker config an offset seed

Expand All @@ -168,6 +171,7 @@ protected QueryHandler() {
config = null;
queryList = null;
hashCode = 0;
queryData = null;
}

@JsonCreator
Expand All @@ -184,6 +188,13 @@ public QueryHandler(Config config) throws IOException {
new FileReadingQueryList(querySource);
}
this.hashCode = queryList.hashCode();
this.queryData = QueryData.generate(IntStream.range(0, queryList.size()).mapToObj(i -> {
try {
return queryList.getQueryStream(i);
} catch (IOException e) {
throw new RuntimeException("Couldn't read query stream", e);
}
}).collect(Collectors.toList()));
}

private QueryList initializeTemplateQueryHandler(QuerySource templateSource) throws IOException {
Expand Down Expand Up @@ -247,7 +258,7 @@ public QuerySelector getQuerySelectorInstance() {

/**
 * Fetches the next query index from the given selector and returns the query at that index as a string,
 * wrapped together with the index and the update flag stored in {@code queryData}.
 *
 * @param querySelector the selector that decides which query comes next
 * @return the wrapped query
 * @throws IOException declared by this method; presumably raised when the query list cannot read the query — confirm
 */
public QueryStringWrapper getNextQuery(QuerySelector querySelector) throws IOException {
final var queryIndex = querySelector.getNextIndex();
// NOTE(review): the two-argument return below appears to be the pre-change line of this diff,
// the three-argument return after it the replacement — only one of them can exist in the real file.
return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex));
return new QueryStringWrapper(queryIndex, queryList.getQuery(queryIndex), queryData.get(queryIndex).update());
}

public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) {
Expand All @@ -258,7 +269,7 @@ public QueryStreamWrapper getNextQueryStream(QuerySelector querySelector) {
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}, queryData.get(queryIndex).update());
}

@Override
Expand Down
2 changes: 0 additions & 2 deletions src/main/java/org/aksw/iguana/cc/query/list/QueryList.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.aksw.iguana.cc.query.list;

import org.aksw.iguana.cc.query.source.QuerySource;

import java.io.IOException;
import java.io.InputStream;

Expand Down
27 changes: 21 additions & 6 deletions src/main/java/org/aksw/iguana/cc/utils/http/RequestFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -113,9 +114,18 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que
throw new IOException(e);
}

// check if the query is an update query, if yes, change the request type to similar update request type
RequestType actualRequestType = requestType;
if (requestType == RequestType.GET_QUERY || requestType == RequestType.POST_QUERY)
actualRequestType = queryHandle.update() ? RequestType.POST_UPDATE : requestType;
if (requestType == RequestType.POST_URL_ENC_QUERY)
actualRequestType = queryHandle.update() ? RequestType.POST_URL_ENC_UPDATE : requestType;
// if only one endpoint is set, use it for both queries and updates
URI updateEndpoint = connectionConfig.updateEndpoint() != null ? connectionConfig.updateEndpoint() : connectionConfig.endpoint();

// If the query is bigger than 2^31 bytes (2GB) and the request type is set to GET_QUERY, POST_URL_ENC_QUERY or
// POST_URL_ENC_UPDATE, the following code will throw an exception.
switch (requestType) {
switch (actualRequestType) {
case GET_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.get(new URIBuilder(connectionConfig.endpoint())
.addParameter("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8))
.build()
Expand All @@ -127,21 +137,26 @@ public AsyncRequestProducer buildHttpRequest(QueryHandler.QueryStreamWrapper que
.setEntity(new BasicAsyncEntityProducer(urlEncode("query", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false));
case POST_QUERY -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
.setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-query"));
case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
case POST_URL_ENC_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint)
.setHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded")
.setEntity(new BasicAsyncEntityProducer(urlEncode("update", new String(queryStream.readAllBytes(), StandardCharsets.UTF_8)), null, false));
case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(connectionConfig.endpoint())
case POST_UPDATE -> asyncRequestBuilder = AsyncRequestBuilder.post(updateEndpoint)
.setEntity(new StreamEntityProducer(queryStreamSupplier, !caching, "application/sparql-update"));
default -> throw new IllegalStateException("Unexpected value: " + requestType);
}

// set additional headers
if (acceptHeader != null)
asyncRequestBuilder.addHeader("Accept", acceptHeader);
if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null)
if (queryHandle.update() && connectionConfig.updateAuthentication() != null && connectionConfig.updateAuthentication().user() != null) {
asyncRequestBuilder.addHeader("Authorization",
HttpWorker.basicAuth(connectionConfig.authentication().user(),
Optional.ofNullable(connectionConfig.authentication().password()).orElse("")));
HttpWorker.basicAuth(connectionConfig.updateAuthentication().user(),
Optional.ofNullable(connectionConfig.updateAuthentication().password()).orElse("")));
} else if (connectionConfig.authentication() != null && connectionConfig.authentication().user() != null) {
asyncRequestBuilder.addHeader("Authorization",
HttpWorker.basicAuth(connectionConfig.authentication().user(),
Optional.ofNullable(connectionConfig.authentication().password()).orElse("")));
}

// cache request
if (caching)
Expand Down
65 changes: 65 additions & 0 deletions src/test/java/org/aksw/iguana/cc/query/QueryDataTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package org.aksw.iguana.cc.query;

import org.aksw.iguana.cc.query.source.QuerySource;
import org.aksw.iguana.cc.query.source.impl.FileSeparatorQuerySource;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests that {@link QueryData#generate} tells update queries apart from other queries.
 */
class QueryDataTest {

    private static Path tempFile = null;

    /**
     * Writes four queries, separated by empty lines, into a temporary file.
     * The order is: SELECT, INSERT DATA, DELETE DATA, SELECT.
     */
    @BeforeAll
    public static void setup() throws IOException {
        tempFile = Files.createTempFile("test", ".txt");
        Files.writeString(tempFile, """
                SELECT ?s ?p ?o WHERE {
                ?s ?p ?o
                }

                INSERT DATA {
                <http://example.org/s> <http://example.org/p> <http://example.org/o>
                }

                DELETE DATA {
                <http://example.org/s> <http://example.org/p> <http://example.org/o>
                }

                SELECT ?s ?p ?o WHERE {
                ?s ?p ?o
                }
                """);
    }

    /** Removes the temporary query file again. */
    @AfterAll
    public static void teardown() throws IOException {
        Files.deleteIfExists(tempFile);
    }

    /**
     * Reads the queries back with an empty-line separator and checks that exactly the
     * second and third query are flagged as update queries and that ids follow the input order.
     */
    @Test
    void testGeneration() throws IOException {
        final QuerySource querySource = new FileSeparatorQuerySource(tempFile, "");
        final var testStrings = querySource.getAllQueries();

        final List<InputStream> queryStreams = testStrings.stream()
                .map(s -> (InputStream) new ByteArrayInputStream(s.getBytes(java.nio.charset.StandardCharsets.UTF_8)))
                .toList();
        final List<QueryData> generation = QueryData.generate(queryStreams);

        assertEquals(4, generation.size());
        for (int i = 0; i < generation.size(); i++) {
            assertEquals(i, generation.get(i).queryId());
        }
        assertFalse(generation.get(0).update());
        assertTrue(generation.get(1).update());
        assertTrue(generation.get(2).update());
        assertFalse(generation.get(3).update());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,19 @@ public class SPARQLProtocolWorkerTest {
.build();

// Non-update query that is written to both temporary query files.
private final static String QUERY = "SELECT * WHERE { ?s ?p ?o }";
// Update query that is written, after QUERY, to the update file only.
private final static String UPDATE_QUERY = "INSERT DATA { <http://example.org/s> <http://example.org/p> <http://example.org/o> }";
// Number of query mixes the workers are configured with.
private final static int QUERY_MIXES = 1;
// Temporary file containing only QUERY.
private static Path queryFile;
// Temporary file containing QUERY and UPDATE_QUERY, separated by an empty line.
private static Path updateFile;

// NOTE(review): the logger is created for SPARQLProtocolWorker, not for this test class — confirm that this is intended.
private static final Logger LOGGER = LoggerFactory.getLogger(SPARQLProtocolWorker.class);

/**
 * Creates the two temporary query files used by the tests:
 * one holding only the plain query, one holding the plain query followed by the update query.
 */
@BeforeAll
public static void setup() throws IOException {
    // An empty line separates the two queries in the mixed file.
    final String mixedQueries = String.join("\n\n", QUERY, UPDATE_QUERY);

    queryFile = Files.createTempFile("iguana-test-queries", ".tmp");
    updateFile = Files.createTempFile("iguana-test-updates", ".tmp");

    Files.writeString(queryFile, QUERY, StandardCharsets.UTF_8);
    Files.writeString(updateFile, mixedQueries, StandardCharsets.UTF_8);
}

@BeforeEach
Expand All @@ -77,6 +81,7 @@ public void reset() {
/**
 * Deletes both temporary query files (if they still exist) and then calls
 * {@code SPARQLProtocolWorker.closeHttpClient()}.
 *
 * @throws IOException if deleting one of the files fails
 */
@AfterAll
public static void cleanup() throws IOException {
Files.deleteIfExists(queryFile);
Files.deleteIfExists(updateFile);
SPARQLProtocolWorker.closeHttpClient();
}

Expand Down Expand Up @@ -120,6 +125,31 @@ public static Stream<Arguments> requestFactoryData() throws URISyntaxException {
return workers.stream();
}

/**
 * Supplies the workers for the separate-update-endpoint test.
 * <p>
 * All workers share one query handler (reading the mixed query/update file) and one connection that
 * defines distinct endpoints and credentials for queries and updates. One worker is created for each
 * of the request types GET_QUERY, POST_URL_ENC_QUERY and POST_QUERY, named after its request type.
 *
 * @return one argument set per request type, in the order listed above
 * @throws IOException if the query handler cannot be created
 */
public static Stream<Arguments> updateWorkerData() throws IOException {
    final var baseUrl = "http://localhost:" + wm.getPort();
    final var processor = new ResponseBodyProcessor("application/sparql-results+json");

    final var handlerConfig = new QueryHandler.Config(
            updateFile.toAbsolutePath().toString(),
            QueryHandler.Config.Format.SEPARATOR,
            null,
            true,
            QueryHandler.Config.Order.LINEAR,
            0L,
            QueryHandler.Config.Language.SPARQL);
    final var queryHandler = new QueryHandler(handlerConfig);

    final var connection = new ConnectionConfig(
            "TestConn",
            "1",
            new DatasetConfig("TestDS", null),
            URI.create(baseUrl + "/ds/query"),
            new ConnectionConfig.Authentication("testUser", "password"),
            URI.create(baseUrl + "/ds/update"),
            new ConnectionConfig.Authentication("updateUser", "password"));

    final var requestTypes = new RequestFactory.RequestType[]{
            RequestFactory.RequestType.GET_QUERY,
            RequestFactory.RequestType.POST_URL_ENC_QUERY,
            RequestFactory.RequestType.POST_QUERY
    };

    final var arguments = new ArrayList<Arguments>();
    for (final var type : requestTypes) {
        final var workerConfig = new SPARQLProtocolWorker.Config(
                1,
                queryHandler,
                new HttpWorker.QueryMixes(QUERY_MIXES),
                connection,
                Duration.parse("PT6S"),
                "application/sparql-results+json",
                type,
                true
        );
        final var worker = new SPARQLProtocolWorker(0, processor, workerConfig);
        arguments.add(Arguments.of(Named.of(type.name(), worker)));
    }
    return arguments.stream();
}

public static List<Arguments> completionTargets() {
final var out = new ArrayList<Arguments>();
final var queryMixesAmount = List.of(1, 2, 5, 10, 100, 200);
Expand Down Expand Up @@ -204,10 +234,63 @@ public void testRequestFactory(SPARQLProtocolWorker worker, boolean cached) {
assertNotEquals(Duration.ZERO, result.executionStats().get(0).duration(), "Worker returned zero duration");
}

/**
 * Checks that a worker sends non-update queries to the query endpoint and update queries to the
 * update endpoint, each with the credentials configured for that endpoint.
 * <p>
 * For every request type two stubs are registered: one on {@code /ds/query} that only matches the
 * plain query sent with the query credentials, and one on {@code /ds/update} that only matches the
 * update sent with the update credentials. A request that goes to the wrong endpoint, uses the wrong
 * credentials or the wrong encoding matches no stub and therefore does not yield the expected
 * 200 response.
 *
 * @param worker the worker under test, configured with one of the query request types
 */
@ParameterizedTest
@MethodSource("updateWorkerData")
public void testSeparateUpdateEndpoint(SPARQLProtocolWorker worker) {
    final var workerConfig = worker.config();
    switch (workerConfig.requestType()) {
        case GET_QUERY -> {
            wm.stubFor(get(urlPathEqualTo("/ds/query"))
                    .withQueryParam("query", equalTo(QUERY))
                    .withBasicAuth("testUser", "password")
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
            // An update cannot be sent via GET, so a direct POST with the update content type is expected.
            wm.stubFor(post(urlPathEqualTo("/ds/update"))
                    .withHeader("Content-Type", equalTo("application/sparql-update"))
                    .withBasicAuth("updateUser", "password")
                    .withRequestBody(equalTo(UPDATE_QUERY))
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
        }
        case POST_URL_ENC_QUERY -> {
            wm.stubFor(post(urlPathEqualTo("/ds/query"))
                    .withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
                    .withBasicAuth("testUser", "password")
                    .withRequestBody(equalTo("query=" + URLEncoder.encode(QUERY, StandardCharsets.UTF_8)))
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
            wm.stubFor(post(urlPathEqualTo("/ds/update"))
                    .withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
                    .withBasicAuth("updateUser", "password")
                    .withRequestBody(equalTo("update=" + URLEncoder.encode(UPDATE_QUERY, StandardCharsets.UTF_8)))
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
        }
        case POST_QUERY -> {
            wm.stubFor(post(urlPathEqualTo("/ds/query"))
                    .withHeader("Content-Type", equalTo("application/sparql-query"))
                    .withBasicAuth("testUser", "password")
                    .withRequestBody(equalTo(QUERY))
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
            wm.stubFor(post(urlPathEqualTo("/ds/update"))
                    .withHeader("Content-Type", equalTo("application/sparql-update"))
                    .withBasicAuth("updateUser", "password")
                    .withRequestBody(equalTo(UPDATE_QUERY))
                    .willReturn(aResponse().withStatus(200).withBody("Non-Empty-Body")));
        }
    }

    final HttpWorker.Result result = worker.start().join();

    // The query file holds two entries (the plain query and the update), so each query mix yields two executions.
    assertEquals(QUERY_MIXES * 2, result.executionStats().size(),
            "Worker should have executed the query and the update once per query mix");
    for (var res : result.executionStats()) {
        assertNull(res.error().orElse(null), "Worker threw an exception, during execution");
        assertEquals(200, res.httpStatusCode().get(), "Worker returned wrong status code");
        assertNotEquals(0, res.responseBodyHash().getAsLong(), "Worker didn't return a response body hash");
        assertEquals("Non-Empty-Body".getBytes(StandardCharsets.UTF_8).length, res.contentLength().getAsLong(), "Worker returned wrong content length");
        assertNotEquals(Duration.ZERO, res.duration(), "Worker returned zero duration");
    }
}

@DisplayName("Test Malformed Response Processing")
@ParameterizedTest(name = "[{index}] fault = {0}")
@EnumSource(Fault.class)
public void testMalformedResponseProcessing(Fault fault) throws IOException, URISyntaxException {
public void testMalformedResponseProcessing(Fault fault) throws URISyntaxException {
SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named<?>)requestFactoryData().toList().get(0).get()[0]).getPayload();
wm.stubFor(get(urlPathEqualTo("/ds/query"))
.willReturn(aResponse().withFault(fault)));
Expand All @@ -217,7 +300,7 @@ public void testMalformedResponseProcessing(Fault fault) throws IOException, URI
}

@Test
public void testBadHttpCodeResponse() throws IOException, URISyntaxException {
public void testBadHttpCodeResponse() throws URISyntaxException {
SPARQLProtocolWorker worker = (SPARQLProtocolWorker) ((Named<?>)requestFactoryData().toList().get(0).get()[0]).getPayload();
wm.stubFor(get(urlPathEqualTo("/ds/query"))
.willReturn(aResponse().withStatus(404)));
Expand Down
Loading