Skip to content

Commit

Permalink
ML: changing JobResultsProvider.getForecastRequestStats to support > …
Browse files Browse the repository at this point in the history
…1 index (#37157)

* ML: changing JobResultsProvider.getForecastRequestStats to support more than one index

* moving to use idsQuery()
  • Loading branch information
benwtrent authored Jan 7, 2019
1 parent 9602d79 commit 1780ced
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.Client;
Expand Down Expand Up @@ -348,17 +347,19 @@ protected void waitForecastToFinish(String jobId, String forecastId) throws Exce
}

protected ForecastRequestStats getForecastStats(String jobId, String forecastId) {
GetResponse getResponse = client().prepareGet()
.setIndex(AnomalyDetectorsIndex.jobResultsAliasedName(jobId))
.setId(ForecastRequestStats.documentId(jobId, forecastId))
.execute().actionGet();
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobResultsAliasedName(jobId))
.setQuery(QueryBuilders.idsQuery().addIds(ForecastRequestStats.documentId(jobId, forecastId)))
.get();

if (getResponse.isExists() == false) {
if (searchResponse.getHits().getHits().length == 0) {
return null;
}

assertThat(searchResponse.getHits().getHits().length, equalTo(1));

try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
getResponse.getSourceAsBytesRef().streamInput())) {
searchResponse.getHits().getHits()[0].getSourceRef().streamInput())) {
return ForecastRequestStats.STRICT_PARSER.apply(parser, null);
} catch (IOException e) {
throw new IllegalStateException(e);
Expand Down Expand Up @@ -398,7 +399,6 @@ protected long countForecastDocs(String jobId, String forecastId) {

protected List<Forecast> getForecasts(String jobId, ForecastRequestStats forecastRequestStats) {
List<Forecast> forecasts = new ArrayList<>();

SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*")
.setSize((int) forecastRequestStats.getRecordCount())
.setQuery(QueryBuilders.boolQuery()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,20 +490,6 @@ private <T, U> T parseSearchHit(SearchHit hit, BiFunction<XContentParser, U, T>
}
}

private <T, U> T parseGetHit(GetResponse getResponse, BiFunction<XContentParser, U, T> objectParser,
Consumer<Exception> errorHandler) {
BytesReference source = getResponse.getSourceAsBytesRef();

try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) {
return objectParser.apply(parser, null);
} catch (IOException e) {
errorHandler.accept(new ElasticsearchParseException("failed to parse " + getResponse.getType(), e));
return null;
}
}

/**
* Search for buckets with the parameters in the {@link BucketsQueryBuilder}
* Uses the internal client, so runs as the _xpack user
Expand Down Expand Up @@ -957,19 +943,6 @@ private <U, T> void searchSingleResult(String jobId, String resultDescription, S
), client::search);
}

private <U, T> void getResult(String jobId, String resultDescription, GetRequest get, BiFunction<XContentParser, U, T> objectParser,
Consumer<Result<T>> handler, Consumer<Exception> errorHandler, Supplier<T> notFoundSupplier) {

executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, get, ActionListener.<GetResponse>wrap(getDocResponse -> {
if (getDocResponse.isExists()) {
handler.accept(new Result<>(getDocResponse.getIndex(), parseGetHit(getDocResponse, objectParser, errorHandler)));
} else {
LOGGER.trace("No {} for job with id {}", resultDescription, jobId);
handler.accept(new Result<>(null, notFoundSupplier.get()));
}
}, errorHandler), client::get);
}

private SearchRequestBuilder createLatestModelSizeStatsSearch(String indexName) {
return client.prepareSearch(indexName)
.setSize(1)
Expand Down Expand Up @@ -1115,11 +1088,14 @@ public void scheduledEvents(ScheduledEventsQueryBuilder query, ActionListener<Qu
public void getForecastRequestStats(String jobId, String forecastId, Consumer<ForecastRequestStats> handler,
Consumer<Exception> errorHandler) {
String indexName = AnomalyDetectorsIndex.jobResultsAliasedName(jobId);
GetRequest getRequest = new GetRequest(indexName, ElasticsearchMappings.DOC_TYPE,
ForecastRequestStats.documentId(jobId, forecastId));

getResult(jobId, ForecastRequestStats.RESULTS_FIELD.getPreferredName(), getRequest, ForecastRequestStats.LENIENT_PARSER,
result -> handler.accept(result.result), errorHandler, () -> null);
SearchRequestBuilder forecastSearch = client.prepareSearch(indexName)
.setQuery(QueryBuilders.idsQuery().addIds(ForecastRequestStats.documentId(jobId, forecastId)));

searchSingleResult(jobId,
ForecastRequestStats.RESULTS_FIELD.getPreferredName(),
forecastSearch,
ForecastRequestStats.LENIENT_PARSER,result -> handler.accept(result.result),
errorHandler, () -> null);
}

public void getForecastStats(String jobId, Consumer<ForecastStats> handler, Consumer<Exception> errorHandler) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequest;
Expand Down Expand Up @@ -834,13 +833,6 @@ private JobResultsProvider createProvider(Client client) {
return new JobResultsProvider(client, Settings.EMPTY);
}

private static GetResponse createGetResponse(boolean exists, Map<String, Object> source) throws IOException {
GetResponse getResponse = mock(GetResponse.class);
when(getResponse.isExists()).thenReturn(exists);
when(getResponse.getSourceAsBytesRef()).thenReturn(BytesReference.bytes(XContentFactory.jsonBuilder().map(source)));
return getResponse;
}

private static SearchResponse createSearchResponse(List<Map<String, Object>> source) throws IOException {
SearchResponse response = mock(SearchResponse.class);
List<SearchHit> list = new ArrayList<>();
Expand Down

0 comments on commit 1780ced

Please sign in to comment.