Skip to content

Commit

Permalink
insilico
Browse files Browse the repository at this point in the history
  • Loading branch information
talevy committed Mar 6, 2020
1 parent 9e87913 commit bbddbcd
Show file tree
Hide file tree
Showing 7 changed files with 650 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.elasticsearch.search.aggregations.support;

import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.index.fielddata.MultiGeoPointValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.index.query.QueryShardContext;

Expand Down Expand Up @@ -51,6 +52,41 @@ public SortedNumericDoubleValues getField(String fieldName, LeafReaderContext ct
}
}

public static class AnyMultiValuesSource extends MultiValuesSource<ValuesSource> {
public AnyMultiValuesSource(Map<String, ValuesSourceConfig<ValuesSource>> valuesSourceConfigs,
QueryShardContext context) {
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}

private ValuesSource getField(String fieldName) {
ValuesSource valuesSource = values.get(fieldName);
if (valuesSource == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return valuesSource;
}

public SortedNumericDoubleValues getNumericField(String fieldName, LeafReaderContext ctx) throws IOException {
ValuesSource valuesSource = getField(fieldName);
if (valuesSource instanceof ValuesSource.Numeric) {
return ((ValuesSource.Numeric) valuesSource).doubleValues(ctx);
}
throw new IllegalArgumentException("field [" + fieldName + "] is not a numeric type");
}

public MultiGeoPointValues getGeoPointField(String fieldName, LeafReaderContext ctx) {
ValuesSource valuesSource = getField(fieldName);
if (valuesSource instanceof ValuesSource.GeoPoint) {
return ((ValuesSource.GeoPoint) valuesSource).geoPointValues(ctx);
}
throw new IllegalArgumentException("field [" + fieldName + "] is not a geo_point type");
}

}

public boolean needsScores() {
return values.values().stream().anyMatch(ValuesSource::needsScores);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.spatial.search.aggregations;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class GeoLineAggregationBuilder
extends MultiValuesSourceAggregationBuilder<ValuesSource, GeoLineAggregationBuilder> {

static final ParseField GEO_POINT_FIELD = new ParseField("geo_point");
static final ParseField SORT_FIELD = new ParseField("sort");

static final String NAME = "geo_line";

private static final ObjectParser<GeoLineAggregationBuilder, Void> PARSER;
static {
PARSER = new ObjectParser<>(NAME);
MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
MultiValuesSourceParseHelper.declareField(GEO_POINT_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(SORT_FIELD.getPreferredName(), PARSER, true, false);
}

GeoLineAggregationBuilder(String name) {
super(name, null);
}

private GeoLineAggregationBuilder(GeoLineAggregationBuilder clone,
AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
}

/**
* Read from a stream.
*/
GeoLineAggregationBuilder(StreamInput in) throws IOException {
super(in, null);
}

static AggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException {
return PARSER.parse(parser, new GeoLineAggregationBuilder(aggregationName), null);
}

@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metaData) {
return new GeoLineAggregationBuilder(this, factoriesBuilder, metaData);
}

@Override
protected void innerWriteTo(StreamOutput out) {
// Do nothing, no extra state to write to stream
}

@Override
protected MultiValuesSourceAggregatorFactory<ValuesSource> innerBuild(QueryShardContext queryShardContext, Map<String,
ValuesSourceConfig<ValuesSource>> configs, DocValueFormat format, AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
return new GeoLineAggregatorFactory(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metaData);
}

public GeoLineAggregationBuilder value(MultiValuesSourceFieldConfig valueConfig) {
valueConfig = Objects.requireNonNull(valueConfig, "Configuration for field [" + GEO_POINT_FIELD + "] cannot be null");
field(GEO_POINT_FIELD.getPreferredName(), valueConfig);
return this;
}

public GeoLineAggregationBuilder sort(MultiValuesSourceFieldConfig sortConfig) {
sortConfig = Objects.requireNonNull(sortConfig, "Configuration for field [" + SORT_FIELD + "] cannot be null");
field(SORT_FIELD.getPreferredName(), sortConfig);
return this;
}

@Override
public XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) {
return builder;
}

@Override
public String getType() {
return NAME;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.spatial.search.aggregations;

import org.apache.lucene.geo.GeoEncodingUtils;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.IntArray;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.index.fielddata.MultiGeoPointValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.metrics.MetricsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.spatial.search.aggregations.GeoLineAggregationBuilder.GEO_POINT_FIELD;
import static org.elasticsearch.xpack.spatial.search.aggregations.GeoLineAggregationBuilder.SORT_FIELD;

/**
* Metric Aggregation for computing the pearson product correlation coefficient between multiple fields
**/
final class GeoLineAggregator extends MetricsAggregator {
/** Multiple ValuesSource with field names */
private final MultiValuesSource.AnyMultiValuesSource valuesSources;

private ObjectArray<long[]> paths;
private ObjectArray<double[]> sortValues;
private IntArray idxs;

GeoLineAggregator(String name, MultiValuesSource.AnyMultiValuesSource valuesSources, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators,
Map<String,Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
this.valuesSources = valuesSources;
if (valuesSources != null) {
paths = context.bigArrays().newObjectArray(1);
sortValues = context.bigArrays().newObjectArray(1);
idxs = context.bigArrays().newIntArray(1);
}
}

@Override
public ScoreMode scoreMode() {
if (valuesSources != null && valuesSources.needsScores()) {
return ScoreMode.COMPLETE;
}
return super.scoreMode();
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final LeafBucketCollector sub) throws IOException {
if (valuesSources == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
MultiGeoPointValues docGeoPointValues = valuesSources.getGeoPointField(GEO_POINT_FIELD.getPreferredName(), ctx);
SortedNumericDoubleValues docSortValues = valuesSources.getNumericField(SORT_FIELD.getPreferredName(), ctx);

return new LeafBucketCollectorBase(sub, docGeoPointValues) {
@Override
public void collect(int doc, long bucket) throws IOException {
paths = bigArrays.grow(paths, bucket + 1);
if (docGeoPointValues.advanceExact(doc) && docSortValues.advanceExact(doc)) {
if (docSortValues.docValueCount() > 1) {
throw new AggregationExecutionException("Encountered more than one sort value for a " +
"single document. Use a script to combine multiple sort-values-per-doc into a single value.");
}
if (docGeoPointValues.docValueCount() > 1) {
throw new AggregationExecutionException("Encountered more than one geo_point value for a " +
"single document. Use a script to combine multiple geo_point-values-per-doc into a single value.");
}

// There should always be one weight if advanceExact lands us here, either
// a real weight or a `missing` weight
assert docSortValues.docValueCount() == 1;
assert docGeoPointValues.docValueCount() == 1;
final double sort = docSortValues.nextValue();
final GeoPoint point = docGeoPointValues.nextValue();

int idx = idxs.get(bucket);
long[] bucketLine = paths.get(bucket);
double[] sortVals = sortValues.get(bucket);
if (bucketLine == null) {
bucketLine = new long[10];
} else {
bucketLine = ArrayUtil.grow(bucketLine, idx + 1);
}


if (sortVals == null) {
sortVals = new double[10];
} else {
sortVals = ArrayUtil.grow(sortVals, idx + 1);
}

int encodedLat = GeoEncodingUtils.encodeLatitude(point.lat());
int encodedLon = GeoEncodingUtils.encodeLongitude(point.lon());
long lonLat = (((long) encodedLon) << 32) | (encodedLat & 0xffffffffL);

sortVals[idx] = sort;
bucketLine[idx] = lonLat;

paths.set(bucket, bucketLine);
sortValues.set(bucket, sortVals);
idxs.set(bucket, idx + 1);
}
}
};
}

@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSources == null) {
return buildEmptyAggregation();
}
long[] bucketLine = paths.get(bucket);
double[] sortVals = sortValues.get(bucket);
int length = idxs.get(bucket);
new PathArraySorter(bucketLine, sortVals, length).sort();
return new InternalGeoLine(name, bucketLine, sortVals, length, pipelineAggregators(), metaData());
}

@Override
public InternalAggregation buildEmptyAggregation() {
return new InternalGeoLine(name, null, null, 0, pipelineAggregators(), metaData());
}

@Override
public void doClose() {
Releasables.close(paths, idxs, sortValues);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.spatial.search.aggregations;

import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.List;
import java.util.Map;

final class GeoLineAggregatorFactory extends MultiValuesSourceAggregatorFactory<ValuesSource> {

GeoLineAggregatorFactory(String name,
Map<String, ValuesSourceConfig<ValuesSource>> configs,
DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metaData);
}

@Override
protected Aggregator createUnmapped(SearchContext searchContext, Aggregator parent, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new GeoLineAggregator(name, null, searchContext, parent, pipelineAggregators, metaData);
}

@Override
protected Aggregator doCreateInternal(SearchContext searchContext, Map<String, ValuesSourceConfig<ValuesSource>> configs,
DocValueFormat format, Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
MultiValuesSource.AnyMultiValuesSource valuesSources = new MultiValuesSource
.AnyMultiValuesSource(configs, searchContext.getQueryShardContext());
return new GeoLineAggregator(name, valuesSources, searchContext, parent, pipelineAggregators, metaData);
}
}
Loading

0 comments on commit bbddbcd

Please sign in to comment.