Skip to content

Remove generics and target value type from MultiVSAB #51647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,20 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder.LeafOnly<Numeric, WeightedAvgAggregationBuilder> {
public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder.LeafOnly<WeightedAvgAggregationBuilder> {
public static final String NAME = "weighted_avg";
public static final ParseField VALUE_FIELD = new ParseField("value");
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
Expand All @@ -61,7 +62,7 @@ public static AggregationBuilder parse(String aggregationName, XContentParser pa
}

public WeightedAvgAggregationBuilder(String name) {
super(name, ValueType.NUMERIC);
super(name);
}

public WeightedAvgAggregationBuilder(WeightedAvgAggregationBuilder clone, Builder factoriesBuilder, Map<String, Object> metaData) {
Expand All @@ -84,25 +85,30 @@ public WeightedAvgAggregationBuilder weight(MultiValuesSourceFieldConfig weightC
* Read from a stream.
*/
public WeightedAvgAggregationBuilder(StreamInput in) throws IOException {
super(in, ValueType.NUMERIC);
super(in);
}

@Override
protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map<String, Object> metaData) {
return new WeightedAvgAggregationBuilder(this, factoriesBuilder, metaData);
}

@Override
protected ValuesSourceType defaultValueSourceType() {
return CoreValuesSourceType.NUMERIC;
}

@Override
protected void innerWriteTo(StreamOutput out) {
// Do nothing, no extra state to write to stream
}

@Override
protected MultiValuesSourceAggregatorFactory<Numeric> innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig> configs,
DocValueFormat format,
AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
protected MultiValuesSourceAggregatorFactory innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig> configs,
DocValueFormat format,
AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
return new WeightedAvgAggregatorFactory(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metaData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.List;
import java.util.Map;

class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory<Numeric> {
class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory {

WeightedAvgAggregatorFactory(String name, Map<String, ValuesSourceConfig> configs,
DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@
*
* A limitation of this class is that all the ValuesSource's being refereenced must be of the same type.
*/
public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
public abstract class MultiValuesSourceAggregationBuilder<AB extends MultiValuesSourceAggregationBuilder<AB>>
extends AbstractAggregationBuilder<AB> {


public abstract static class LeafOnly<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
extends MultiValuesSourceAggregationBuilder<VS, AB> {
public abstract static class LeafOnly<AB extends MultiValuesSourceAggregationBuilder<AB>>
extends MultiValuesSourceAggregationBuilder<AB> {

protected LeafOnly(String name, ValueType targetValueType) {
super(name, targetValueType);
protected LeafOnly(String name) {
super(name);
}

protected LeafOnly(LeafOnly<VS, AB> clone, Builder factoriesBuilder, Map<String, Object> metaData) {
protected LeafOnly(LeafOnly<AB> clone, Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
if (factoriesBuilder.count() > 0) {
throw new AggregationInitializationException("Aggregator [" + name + "] of type ["
Expand All @@ -62,8 +62,8 @@ protected LeafOnly(LeafOnly<VS, AB> clone, Builder factoriesBuilder, Map<String,
/**
* Read from a stream that does not serialize its targetValueType. This should be used by most subclasses.
*/
protected LeafOnly(StreamInput in, ValueType targetValueType) throws IOException {
super(in, targetValueType);
protected LeafOnly(StreamInput in) throws IOException {
super(in);
}

@Override
Expand All @@ -76,30 +76,28 @@ public AB subAggregations(Builder subFactories) {


private Map<String, MultiValuesSourceFieldConfig> fields = new HashMap<>();
private final ValueType targetValueType;
private ValueType valueType = null;
private ValueType userValueTypeHint = null;
private String format = null;

protected MultiValuesSourceAggregationBuilder(String name, ValueType targetValueType) {
protected MultiValuesSourceAggregationBuilder(String name) {
super(name);
this.targetValueType = targetValueType;
}

protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder<VS, AB> clone,
protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder<AB> clone,
Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);

this.fields = new HashMap<>(clone.fields);
this.targetValueType = clone.targetValueType;
this.valueType = clone.valueType;
this.userValueTypeHint = clone.userValueTypeHint;
this.format = clone.format;
}

protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetValueType)
/**
* Read from a stream.
*/
protected MultiValuesSourceAggregationBuilder(StreamInput in)
throws IOException {
super(in);
assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType";
this.targetValueType = targetValueType;
read(in);
}

Expand All @@ -109,17 +107,14 @@ protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetVa
@SuppressWarnings("unchecked")
private void read(StreamInput in) throws IOException {
fields = in.readMap(StreamInput::readString, MultiValuesSourceFieldConfig::new);
valueType = in.readOptionalWriteable(ValueType::readFromStream);
userValueTypeHint = in.readOptionalWriteable(ValueType::readFromStream);
format = in.readOptionalString();
}

@Override
protected final void doWriteTo(StreamOutput out) throws IOException {
if (serializeTargetValueType()) {
out.writeOptionalWriteable(targetValueType);
}
out.writeMap(fields, StreamOutput::writeString, (o, value) -> value.writeTo(o));
out.writeOptionalWriteable(valueType);
out.writeOptionalWriteable(userValueTypeHint);
out.writeOptionalString(format);
innerWriteTo(out);
}
Expand All @@ -142,11 +137,11 @@ protected AB field(String propertyName, MultiValuesSourceFieldConfig config) {
* Sets the {@link ValueType} for the value produced by this aggregation
*/
@SuppressWarnings("unchecked")
public AB valueType(ValueType valueType) {
public AB userValueTypeHint(ValueType valueType) {
if (valueType == null) {
throw new IllegalArgumentException("[valueType] must not be null: [" + name + "]");
throw new IllegalArgumentException("[userValueTypeHint] must not be null: [" + name + "]");
}
this.valueType = valueType;
this.userValueTypeHint = valueType;
return (AB) this;
}

Expand All @@ -162,25 +157,34 @@ public AB format(String format) {
return (AB) this;
}

@Override
protected final MultiValuesSourceAggregatorFactory<VS> doBuild(QueryShardContext queryShardContext, AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType;
/**
* Aggregations should use this method to define a {@link ValuesSourceType} of last resort. This will only be used when the resolver
* can't find a field and the user hasn't provided a value type hint.
*
* @return The CoreValuesSourceType we expect this script to yield.
*/
protected abstract ValuesSourceType defaultValueSourceType();

@Override
protected final MultiValuesSourceAggregatorFactory doBuild(QueryShardContext queryShardContext, AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
Map<String, ValuesSourceConfig> configs = new HashMap<>(fields.size());
fields.forEach((key, value) -> {
ValuesSourceConfig config = ValuesSourceConfig.resolve(queryShardContext, finalValueType,
value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format, getType());
ValuesSourceConfig config = ValuesSourceConfig.resolve(queryShardContext, userValueTypeHint,
value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format, defaultValueSourceType(),
getType());
configs.put(key, config);
});
DocValueFormat docValueFormat = resolveFormat(format, finalValueType);
DocValueFormat docValueFormat = resolveFormat(format, userValueTypeHint, defaultValueSourceType());
return innerBuild(queryShardContext, configs, docValueFormat, parent, subFactoriesBuilder);
}


private static DocValueFormat resolveFormat(@Nullable String format, @Nullable ValueType valueType) {
private static DocValueFormat resolveFormat(@Nullable String format, @Nullable ValueType valueType,
ValuesSourceType defaultValuesSourceType) {
if (valueType == null) {
return DocValueFormat.RAW; // we can't figure it out
// If the user didn't send a hint, all we can do is fall back to the default
return defaultValuesSourceType.getFormatter(format, null);
}
DocValueFormat valueFormat = valueType.defaultFormat;
if (valueFormat instanceof DocValueFormat.Decimal && format != null) {
Expand All @@ -189,19 +193,11 @@ private static DocValueFormat resolveFormat(@Nullable String format, @Nullable V
return valueFormat;
}

protected abstract MultiValuesSourceAggregatorFactory<VS> innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig> configs,
DocValueFormat format, AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException;

protected abstract MultiValuesSourceAggregatorFactory innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig> configs,
DocValueFormat format, AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException;

/**
* Should this builder serialize its targetValueType? Defaults to false. All subclasses that override this to true
* should use the three argument read constructor rather than the four argument version.
*/
protected boolean serializeTargetValueType() {
return false;
}

@Override
public final XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
Expand All @@ -214,8 +210,8 @@ public final XContentBuilder internalXContent(XContentBuilder builder, Params pa
if (format != null) {
builder.field(CommonFields.FORMAT.getPreferredName(), format);
}
if (valueType != null) {
builder.field(CommonFields.VALUE_TYPE.getPreferredName(), valueType.getPreferredName());
if (userValueTypeHint != null) {
builder.field(CommonFields.VALUE_TYPE.getPreferredName(), userValueTypeHint.getPreferredName());
}
doXContentBody(builder, params);
builder.endObject();
Expand All @@ -226,7 +222,7 @@ public final XContentBuilder internalXContent(XContentBuilder builder, Params pa

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), fields, format, targetValueType, valueType);
return Objects.hash(super.hashCode(), fields, format, userValueTypeHint);
}


Expand All @@ -239,6 +235,6 @@ public boolean equals(Object obj) {
MultiValuesSourceAggregationBuilder other = (MultiValuesSourceAggregationBuilder) obj;
return Objects.equals(this.fields, other.fields)
&& Objects.equals(this.format, other.format)
&& Objects.equals(this.valueType, other.valueType);
&& Objects.equals(this.userValueTypeHint, other.userValueTypeHint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import java.util.List;
import java.util.Map;

public abstract class MultiValuesSourceAggregatorFactory<VS extends ValuesSource>
extends AggregatorFactory {
public abstract class MultiValuesSourceAggregatorFactory extends AggregatorFactory {

protected final Map<String, ValuesSourceConfig> configs;
protected final DocValueFormat format;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
public final class MultiValuesSourceParseHelper {

public static <VS extends ValuesSource, T> void declareCommon(
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser, boolean formattable,
ValueType expectedValueType) {
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<?>, T> objectParser, boolean formattable,
ValueType expectedValueType) {

objectParser.declareField(MultiValuesSourceAggregationBuilder::valueType, p -> {
objectParser.declareField(MultiValuesSourceAggregationBuilder::userValueTypeHint, p -> {
ValueType valueType = ValueType.resolveForScript(p.text());
if (expectedValueType != null && valueType.isNotA(expectedValueType)) {
throw new ParsingException(p.getTokenLocation(),
Expand All @@ -49,7 +49,7 @@ public static <VS extends ValuesSource, T> void declareCommon(
}

public static <VS extends ValuesSource, T> void declareField(String fieldName,
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser,
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<?>, T> objectParser,
boolean scriptable, boolean timezoneAware) {

objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()),
Expand Down
Loading