diff --git a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc index b23a683b05610..5b68fa7be451f 100644 --- a/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/java-api/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -13,8 +13,8 @@ Here is an example on how to create the aggregation request: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")); -------------------------------------------------- You can also specify a `combine` script which will be executed on each shard: @@ -23,9 +23,9 @@ You can also specify a `combine` script which will be executed on each shard: -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum")); -------------------------------------------------- You can also specify a `reduce` script which will be executed on the node which gets the request: @@ -34,10 +34,10 @@ You can also specify a `reduce` script which will be executed on the node which -------------------------------------------------- ScriptedMetricAggregationBuilder aggregation = AggregationBuilders .scriptedMetric("agg") - .initScript(new Script("params._agg.heights = []")) - .mapScript(new Script("params._agg.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) - .combineScript(new Script("double heights_sum = 0.0; for (t in params._agg.heights) { heights_sum += t } return heights_sum")) - .reduceScript(new Script("double heights_sum = 0.0; for (a in params._aggs) { heights_sum += a } return heights_sum")); + .initScript(new Script("state.heights = []")) + .mapScript(new Script("state.heights.add(doc.gender.value == 'male' ? doc.height.value : -1.0 * doc.height.value)")) + .combineScript(new Script("double heights_sum = 0.0; for (t in state.heights) { heights_sum += t } return heights_sum")) + .reduceScript(new Script("double heights_sum = 0.0; for (a in states) { heights_sum += a } return heights_sum")); -------------------------------------------------- diff --git a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc index 1a4d6d4774c49..c4857699f9805 100644 --- a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -15,10 +15,10 @@ POST ledger/_search?size=0 "aggs": { "profit": { "scripted_metric": { - "init_script" : "params._agg.transactions = []", - "map_script" : "params._agg.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> - "combine_script" : "double profit = 0; for (t in params._agg.transactions) { profit += t } return profit", - "reduce_script" : "double profit = 0; for (a in params._aggs) { profit += a } return profit" + "init_script" : "state.transactions = []", + "map_script" : "state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)", <1> + "combine_script" : "double profit = 0; for (t in state.transactions) { profit += t } return profit", + "reduce_script" : "double profit = 0; for (a in states) { profit += a } return profit" } } } @@ -67,8 +67,7 @@ POST ledger/_search?size=0 "id": "my_combine_script" }, "params": { - "field": "amount", <1> - "_agg": {} <2> + "field": "amount" <1> }, "reduce_script" : { "id": "my_reduce_script" @@ -82,8 +81,7 @@ POST ledger/_search?size=0 // TEST[setup:ledger,stored_scripted_metric_script] <1> script parameters for `init`, `map` and `combine` scripts must be specified -in a global `params` object so that it can be share between the scripts. -<2> if you specify script parameters then you must specify `"_agg": {}`. +in a global `params` object so that it can be shared between the scripts. //// Verify this response as well but in a hidden block. @@ -108,7 +106,7 @@ For more details on specifying scripts see <, List> scriptContexts() { + Map, List> contexts = new HashMap<>(); + contexts.put(ScriptedMetricAggContexts.InitScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.MapScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.CombineScript.CONTEXT, Whitelist.BASE_WHITELISTS); + contexts.put(ScriptedMetricAggContexts.ReduceScript.CONTEXT, Whitelist.BASE_WHITELISTS); + return contexts; + } + + public void testInitBasic() { + ScriptedMetricAggContexts.InitScript.Factory factory = scriptEngine.compile("test", + "state.testField = params.initialVal", ScriptedMetricAggContexts.InitScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + Map state = new HashMap<>(); + + params.put("initialVal", 10); + + ScriptedMetricAggContexts.InitScript script = factory.newInstance(params, state); + script.execute(); + + assert(state.containsKey("testField")); + assertEquals(10, state.get("testField")); + } + + public void testMapBasic() { + ScriptedMetricAggContexts.MapScript.Factory factory = scriptEngine.compile("test", + "state.testField = 2*_score", ScriptedMetricAggContexts.MapScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + Map state = new HashMap<>(); + + Scorer scorer = new Scorer(null) { + @Override + public int docID() { return 0; } + + @Override + public float score() { return 0.5f; } + + @Override + public DocIdSetIterator iterator() { return null; } + }; + + ScriptedMetricAggContexts.MapScript.LeafFactory leafFactory = factory.newFactory(params, state, null); + ScriptedMetricAggContexts.MapScript script = leafFactory.newInstance(null); + + script.setScorer(scorer); + script.execute(); + + assert(state.containsKey("testField")); + assertEquals(1.0, state.get("testField")); + } + + public void testCombineBasic() { + ScriptedMetricAggContexts.CombineScript.Factory factory = scriptEngine.compile("test", + "state.testField = params.initialVal; return state.testField + params.inc", ScriptedMetricAggContexts.CombineScript.CONTEXT, + Collections.emptyMap()); + + Map params = new HashMap<>(); + Map state = new HashMap<>(); + + params.put("initialVal", 10); + params.put("inc", 2); + + ScriptedMetricAggContexts.CombineScript script = factory.newInstance(params, state); + Object res = script.execute(); + + assert(state.containsKey("testField")); + assertEquals(10, state.get("testField")); + assertEquals(12, res); + } + + public void testReduceBasic() { + ScriptedMetricAggContexts.ReduceScript.Factory factory = scriptEngine.compile("test", + "states[0].testField + states[1].testField", ScriptedMetricAggContexts.ReduceScript.CONTEXT, Collections.emptyMap()); + + Map params = new HashMap<>(); + List states = new ArrayList<>(); + + Map state1 = new HashMap<>(), state2 = new HashMap<>(); + state1.put("testField", 1); + state2.put("testField", 2); + + states.add(state1); + states.add(state2); + + ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, states); + Object res = script.execute(); + assertEquals(3, res); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/ScriptModule.java b/server/src/main/java/org/elasticsearch/script/ScriptModule.java index 7074d3ad9fe44..f0e075eac7d93 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptModule.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptModule.java @@ -53,7 +53,11 @@ public class ScriptModule { SimilarityScript.CONTEXT, SimilarityWeightScript.CONTEXT, TemplateScript.CONTEXT, - MovingFunctionScript.CONTEXT + MovingFunctionScript.CONTEXT, + ScriptedMetricAggContexts.InitScript.CONTEXT, + ScriptedMetricAggContexts.MapScript.CONTEXT, + ScriptedMetricAggContexts.CombineScript.CONTEXT, + ScriptedMetricAggContexts.ReduceScript.CONTEXT ).collect(Collectors.toMap(c -> c.name, Function.identity())); } diff --git a/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java new file mode 100644 index 0000000000000..774dc95d39977 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/ScriptedMetricAggContexts.java @@ -0,0 +1,161 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.script; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Scorer; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.search.lookup.LeafSearchLookup; +import org.elasticsearch.search.lookup.SearchLookup; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class ScriptedMetricAggContexts { + private abstract static class ParamsAndStateBase { + private final Map params; + private final Object state; + + ParamsAndStateBase(Map params, Object state) { + this.params = params; + this.state = state; + } + + public Map getParams() { + return params; + } + + public Object getState() { + return state; + } + } + + public abstract static class InitScript extends ParamsAndStateBase { + public InitScript(Map params, Object state) { + super(params, state); + } + + public abstract void execute(); + + public interface Factory { + InitScript newInstance(Map params, Object state); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_init", Factory.class); + } + + public abstract static class MapScript extends ParamsAndStateBase { + private final LeafSearchLookup leafLookup; + private Scorer scorer; + + public MapScript(Map params, Object state, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, state); + + this.leafLookup = leafContext == null ? null : lookup.getLeafSearchLookup(leafContext); + } + + // Return the doc as a map (instead of LeafDocLookup) in order to abide by type whitelisting rules for + // Painless scripts. + public Map> getDoc() { + return leafLookup == null ? null : leafLookup.doc(); + } + + public void setDocument(int docId) { + if (leafLookup != null) { + leafLookup.setDocument(docId); + } + } + + public void setScorer(Scorer scorer) { + this.scorer = scorer; + } + + // get_score() is named this way so that it's picked up by Painless as '_score' + public double get_score() { + if (scorer == null) { + return 0.0; + } + + try { + return scorer.score(); + } catch (IOException e) { + throw new ElasticsearchException("Couldn't look up score", e); + } + } + + public abstract void execute(); + + public interface LeafFactory { + MapScript newInstance(LeafReaderContext ctx); + } + + public interface Factory { + LeafFactory newFactory(Map params, Object state, SearchLookup lookup); + } + + public static String[] PARAMETERS = new String[] {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_map", Factory.class); + } + + public abstract static class CombineScript extends ParamsAndStateBase { + public CombineScript(Map params, Object state) { + super(params, state); + } + + public abstract Object execute(); + + public interface Factory { + CombineScript newInstance(Map params, Object state); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_combine", Factory.class); + } + + public abstract static class ReduceScript { + private final Map params; + private final List states; + + public ReduceScript(Map params, List states) { + this.params = params; + this.states = states; + } + + public Map getParams() { + return params; + } + + public List getStates() { + return states; + } + + public abstract Object execute(); + + public interface Factory { + ReduceScript newInstance(Map params, List states); + } + + public static String[] PARAMETERS = {}; + public static ScriptContext CONTEXT = new ScriptContext<>("aggs_reduce", Factory.class); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java index e350ecbed5814..f4281c063ff2c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java @@ -23,7 +23,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; @@ -90,16 +90,19 @@ public InternalAggregation doReduce(List aggregations, Redu InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0)); List aggregation; if (firstAggregation.reduceScript != null && reduceContext.isFinalReduce()) { - Map vars = new HashMap<>(); - vars.put("_aggs", aggregationObjects); + Map params = new HashMap<>(); if (firstAggregation.reduceScript.getParams() != null) { - vars.putAll(firstAggregation.reduceScript.getParams()); + params.putAll(firstAggregation.reduceScript.getParams()); } - ExecutableScript.Factory factory = reduceContext.scriptService().compile( - firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT); - ExecutableScript script = factory.newInstance(vars); - Object scriptResult = script.run(); + // Add _aggs to params map for backwards compatibility (redundant with a context variable on the ReduceScript created below). + params.put("_aggs", aggregationObjects); + + ScriptedMetricAggContexts.ReduceScript.Factory factory = reduceContext.scriptService().compile( + firstAggregation.reduceScript, ScriptedMetricAggContexts.ReduceScript.CONTEXT); + ScriptedMetricAggContexts.ReduceScript script = factory.newInstance(params, aggregationObjects); + + Object scriptResult = script.execute(); CollectionUtils.ensureNoSelfReferences(scriptResult, "reduce script"); aggregation = Collections.singletonList(scriptResult); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java index 225398e51b7c0..8b6d834184d73 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java @@ -26,9 +26,8 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; @@ -202,30 +201,32 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega // Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have // access to them for the scripts it's given precompiled. - ExecutableScript.Factory executableInitScript; + ScriptedMetricAggContexts.InitScript.Factory compiledInitScript; Map initScriptParams; if (initScript != null) { - executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT); + compiledInitScript = queryShardContext.getScriptService().compile(initScript, ScriptedMetricAggContexts.InitScript.CONTEXT); initScriptParams = initScript.getParams(); } else { - executableInitScript = p -> null; + compiledInitScript = (p, a) -> null; initScriptParams = Collections.emptyMap(); } - SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT); + ScriptedMetricAggContexts.MapScript.Factory compiledMapScript = queryShardContext.getScriptService().compile(mapScript, + ScriptedMetricAggContexts.MapScript.CONTEXT); Map mapScriptParams = mapScript.getParams(); - ExecutableScript.Factory executableCombineScript; + ScriptedMetricAggContexts.CombineScript.Factory compiledCombineScript; Map combineScriptParams; if (combineScript != null) { - executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + compiledCombineScript = queryShardContext.getScriptService().compile(combineScript, + ScriptedMetricAggContexts.CombineScript.CONTEXT); combineScriptParams = combineScript.getParams(); } else { - executableCombineScript = p -> null; + compiledCombineScript = (p, a) -> null; combineScriptParams = Collections.emptyMap(); } - return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams, - executableCombineScript, combineScriptParams, reduceScript, + return new ScriptedMetricAggregatorFactory(name, compiledMapScript, mapScriptParams, compiledInitScript, + initScriptParams, compiledCombineScript, combineScriptParams, reduceScript, params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index d6e861a9a6792..ffdff44b783b6 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -20,10 +20,10 @@ package org.elasticsearch.search.aggregations.metrics.scripted; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Scorer; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; @@ -38,17 +38,17 @@ public class ScriptedMetricAggregator extends MetricsAggregator { - private final SearchScript.LeafFactory mapScript; - private final ExecutableScript combineScript; + private final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript; + private final ScriptedMetricAggContexts.CombineScript combineScript; private final Script reduceScript; - private Map params; + private Object aggState; - protected ScriptedMetricAggregator(String name, SearchScript.LeafFactory mapScript, ExecutableScript combineScript, - Script reduceScript, - Map params, SearchContext context, Aggregator parent, List pipelineAggregators, Map metaData) - throws IOException { + protected ScriptedMetricAggregator(String name, ScriptedMetricAggContexts.MapScript.LeafFactory mapScript, ScriptedMetricAggContexts.CombineScript combineScript, + Script reduceScript, Object aggState, SearchContext context, Aggregator parent, + List pipelineAggregators, Map metaData) + throws IOException { super(name, context, parent, pipelineAggregators, metaData); - this.params = params; + this.aggState = aggState; this.mapScript = mapScript; this.combineScript = combineScript; this.reduceScript = reduceScript; @@ -62,14 +62,20 @@ public boolean needsScores() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - final SearchScript leafMapScript = mapScript.newInstance(ctx); + final ScriptedMetricAggContexts.MapScript leafMapScript = mapScript.newInstance(ctx); return new LeafBucketCollectorBase(sub, leafMapScript) { + @Override + public void setScorer(Scorer scorer) throws IOException { + leafMapScript.setScorer(scorer); + } + @Override public void collect(int doc, long bucket) throws IOException { assert bucket == 0 : bucket; + leafMapScript.setDocument(doc); - leafMapScript.run(); - CollectionUtils.ensureNoSelfReferences(params, "Scripted metric aggs map script"); + leafMapScript.execute(); + CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs map script"); } }; } @@ -78,10 +84,10 @@ public void collect(int doc, long bucket) throws IOException { public InternalAggregation buildAggregation(long owningBucketOrdinal) { Object aggregation; if (combineScript != null) { - aggregation = combineScript.run(); + aggregation = combineScript.execute(); CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script"); } else { - aggregation = params.get("_agg"); + aggregation = aggState; } return new InternalScriptedMetric(name, aggregation, reduceScript, pipelineAggregators(), metaData()); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index 0deda32e79d77..9bd904a07013d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -19,10 +19,9 @@ package org.elasticsearch.search.aggregations.metrics.scripted; +import org.elasticsearch.script.ScriptedMetricAggContexts; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.script.ExecutableScript; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; @@ -39,20 +38,21 @@ public class ScriptedMetricAggregatorFactory extends AggregatorFactory { - private final SearchScript.Factory mapScript; + private final ScriptedMetricAggContexts.MapScript.Factory mapScript; private final Map mapScriptParams; - private final ExecutableScript.Factory combineScript; + private final ScriptedMetricAggContexts.CombineScript.Factory combineScript; private final Map combineScriptParams; private final Script reduceScript; private final Map aggParams; private final SearchLookup lookup; - private final ExecutableScript.Factory initScript; + private final ScriptedMetricAggContexts.InitScript.Factory initScript; private final Map initScriptParams; - public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map mapScriptParams, - ExecutableScript.Factory initScript, Map initScriptParams, - ExecutableScript.Factory combineScript, Map combineScriptParams, - Script reduceScript, Map aggParams, + public ScriptedMetricAggregatorFactory(String name, + ScriptedMetricAggContexts.MapScript.Factory mapScript, Map mapScriptParams, + ScriptedMetricAggContexts.InitScript.Factory initScript, Map initScriptParams, + ScriptedMetricAggContexts.CombineScript.Factory combineScript, + Map combineScriptParams, Script reduceScript, Map aggParams, SearchLookup lookup, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) throws IOException { super(name, context, parent, subFactories, metaData); @@ -79,21 +79,29 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu } else { aggParams = new HashMap<>(); } + + // Add _agg to params map for backwards compatibility (redundant with context variables on the scripts created below). + // When this is removed, aggState (as passed to ScriptedMetricAggregator) can be changed to Map, since + // it won't be possible to completely replace it with another type as is possible when it's an entry in params. if (aggParams.containsKey("_agg") == false) { aggParams.put("_agg", new HashMap()); } + Object aggState = aggParams.get("_agg"); - final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams)); - final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup); - final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams)); + final ScriptedMetricAggContexts.InitScript initScript = this.initScript.newInstance( + mergeParams(aggParams, initScriptParams), aggState); + final ScriptedMetricAggContexts.MapScript.LeafFactory mapScript = this.mapScript.newFactory( + mergeParams(aggParams, mapScriptParams), aggState, lookup); + final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance( + mergeParams(aggParams, combineScriptParams), aggState); final Script reduceScript = deepCopyScript(this.reduceScript, context); if (initScript != null) { - initScript.run(); - CollectionUtils.ensureNoSelfReferences(aggParams.get("_agg"), "Scripted metric aggs init script"); + initScript.execute(); + CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script"); } return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, aggParams, context, parent, + combineScript, reduceScript, aggState, context, parent, pipelineAggregators, metaData); } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 816c0464d95d9..13e1489795996 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -193,14 +193,55 @@ protected Map, Object>> pluginScripts() { return newAggregation; }); + scripts.put("state.items = new ArrayList()", vars -> + aggContextScript(vars, state -> ((HashMap) state).put("items", new ArrayList()))); + + scripts.put("state.items.add(1)", vars -> + aggContextScript(vars, state -> { + HashMap stateMap = (HashMap) state; + List items = (List) stateMap.get("items"); + items.add(1); + })); + + scripts.put("sum context state values", vars -> { + int sum = 0; + HashMap state = (HashMap) vars.get("state"); + List items = (List) state.get("items"); + + for (Object x : items) { + sum += (Integer)x; + } + + return sum; + }); + + scripts.put("sum context states", vars -> { + Integer sum = 0; + + List states = (List) vars.get("states"); + for (Object state : states) { + sum += ((Number) state).intValue(); + } + + return sum; + }); + return scripts; } - @SuppressWarnings("unchecked") static Object aggScript(Map vars, Consumer fn) { - T agg = (T) vars.get("_agg"); - fn.accept(agg); - return agg; + return aggScript(vars, fn, "_agg"); + } + + static Object aggContextScript(Map vars, Consumer fn) { + return aggScript(vars, fn, "state"); + } + + @SuppressWarnings("unchecked") + private static Object aggScript(Map vars, Consumer fn, String stateVarName) { + T aggState = (T) vars.get(stateVarName); + fn.accept(aggState); + return aggState; } } @@ -1015,4 +1056,37 @@ public void testConflictingAggAndScriptParams() { SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get); assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters")); } + + public void testAggFromContext() { + Script initScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items = new ArrayList()", Collections.emptyMap()); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "state.items.add(1)", Collections.emptyMap()); + Script combineScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context state values", Collections.emptyMap()); + Script reduceScript = + new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "sum context states", + Collections.emptyMap()); + + SearchResponse response = client() + .prepareSearch("idx") + .setQuery(matchAllQuery()) + .addAggregation( + scriptedMetric("scripted") + .initScript(initScript) + .mapScript(mapScript) + .combineScript(combineScript) + .reduceScript(reduceScript)) + .get(); + + Aggregation aggregation = response.getAggregations().get("scripted"); + assertThat(aggregation, notNullValue()); + assertThat(aggregation, instanceOf(ScriptedMetric.class)); + + ScriptedMetric scriptedMetricAggregation = (ScriptedMetric) aggregation; + assertThat(scriptedMetricAggregation.getName(), equalTo("scripted")); + assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); + + assertThat(scriptedMetricAggregation.aggregation(), instanceOf(Integer.class)); + Integer aggResult = (Integer) scriptedMetricAggregation.aggregation(); + long totalAgg = aggResult.longValue(); + assertThat(totalAgg, equalTo(numDocs)); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index 7a7c66d21aada..b2a949ceeee1a 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.script.MockScriptEngine; -import org.elasticsearch.script.ScoreAccessor; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptModule; @@ -107,7 +106,7 @@ public static void initMockScripts() { }); SCRIPTS.put("mapScriptScore", params -> { Map agg = (Map) params.get("_agg"); - ((List) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue()); + ((List) agg.get("collector")).add(((Number) params.get("_score")).doubleValue()); return agg; }); SCRIPTS.put("combineScriptScore", params -> { diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index b86cb9ff29352..e608bd13d2559 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -33,6 +33,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Function; @@ -115,6 +116,18 @@ public String execute() { } else if (context.instanceClazz.equals(ScoreScript.class)) { ScoreScript.Factory factory = new MockScoreScript(script); return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.InitScript.class)) { + ScriptedMetricAggContexts.InitScript.Factory factory = mockCompiled::createMetricAggInitScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.MapScript.class)) { + ScriptedMetricAggContexts.MapScript.Factory factory = mockCompiled::createMetricAggMapScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.CombineScript.class)) { + ScriptedMetricAggContexts.CombineScript.Factory factory = mockCompiled::createMetricAggCombineScript; + return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScriptedMetricAggContexts.ReduceScript.class)) { + ScriptedMetricAggContexts.ReduceScript.Factory factory = mockCompiled::createMetricAggReduceScript; + return context.factoryClazz.cast(factory); } throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]"); } @@ -179,6 +192,23 @@ public SimilarityWeightScript createSimilarityWeightScript() { public MovingFunctionScript createMovingFunctionScript() { return new MockMovingFunctionScript(); } + + public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map params, Object state) { + return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d); + } + + public ScriptedMetricAggContexts.MapScript.LeafFactory createMetricAggMapScript(Map params, Object state, + SearchLookup lookup) { + return new MockMetricAggMapScript(params, state, lookup, script != null ? script : ctx -> 42d); + } + + public ScriptedMetricAggContexts.CombineScript createMetricAggCombineScript(Map params, Object state) { + return new MockMetricAggCombineScript(params, state, script != null ? script : ctx -> 42d); + } + + public ScriptedMetricAggContexts.ReduceScript createMetricAggReduceScript(Map params, List states) { + return new MockMetricAggReduceScript(params, states, script != null ? script : ctx -> 42d); + } } public class MockExecutableScript implements ExecutableScript { @@ -333,6 +363,108 @@ public double execute(Query query, Field field, Term term) throws IOException { } } + public static class MockMetricAggInitScript extends ScriptedMetricAggContexts.InitScript { + private final Function, Object> script; + + MockMetricAggInitScript(Map params, Object state, + Function, Object> script) { + super(params, state); + this.script = script; + } + + public void execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("state", getState()); + script.apply(map); + } + } + + public static class MockMetricAggMapScript implements ScriptedMetricAggContexts.MapScript.LeafFactory { + private final Map params; + private final Object state; + private final SearchLookup lookup; + private final Function, Object> script; + + MockMetricAggMapScript(Map params, Object state, SearchLookup lookup, + Function, Object> script) { + this.params = params; + this.state = state; + this.lookup = lookup; + this.script = script; + } + + @Override + public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext context) { + return new ScriptedMetricAggContexts.MapScript(params, state, lookup, context) { + @Override + public void execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("state", getState()); + map.put("doc", getDoc()); + map.put("_score", get_score()); + + script.apply(map); + } + }; + } + } + + public static class MockMetricAggCombineScript extends ScriptedMetricAggContexts.CombineScript { + private final Function, Object> script; + + MockMetricAggCombineScript(Map params, Object state, + Function, Object> script) { + super(params, state); + this.script = script; + } + + public Object execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("state", getState()); + return script.apply(map); + } + } + + public static class MockMetricAggReduceScript extends ScriptedMetricAggContexts.ReduceScript { + private final Function, Object> script; + + MockMetricAggReduceScript(Map params, List states, + Function, Object> script) { + super(params, states); + this.script = script; + } + + public Object execute() { + Map map = new HashMap<>(); + + if (getParams() != null) { + map.putAll(getParams()); // TODO: remove this once scripts know to look for params under params key + map.put("params", getParams()); + } + + map.put("states", getStates()); + return script.apply(map); + } + } + public static Script mockInlineScript(final String script) { return new Script(ScriptType.INLINE, "mock", script, emptyMap()); } @@ -343,15 +475,15 @@ public double execute(Map params, double[] values) { return MovingFunctions.unweightedAvg(values); } } - + public class MockScoreScript implements ScoreScript.Factory { - + private final Function, Object> scripts; - + MockScoreScript(Function, Object> scripts) { this.scripts = scripts; } - + @Override public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup) { return new ScoreScript.LeafFactory() { @@ -359,7 +491,7 @@ public ScoreScript.LeafFactory newFactory(Map params, SearchLook public boolean needs_score() { return true; } - + @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { Scorer[] scorerHolder = new Scorer[1]; @@ -373,7 +505,7 @@ public double execute() { } return ((Number) scripts.apply(vars)).doubleValue(); } - + @Override public void setScorer(Scorer scorer) { scorerHolder[0] = scorer;