Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
{
private static final String VARIANCE_TYPE_NAME = "variance";
public static final String VARIANCE_TYPE_NAME = "variance";
public static final ColumnType TYPE = ColumnType.ofComplex(VARIANCE_TYPE_NAME);

protected final String fieldName;
Expand Down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a lot of boilerplate code here can be removed if you add a class called DruidSqlAvgAggFunction and in that class, the constructor passes the Any operand type. Then all the variance and std dev aggregation functions can extend that class and simply pass the kind value

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the approach of #14249 which is adding the SqlAggFunction equivalent of OperatorConversions.OperatorBuilder is probably the better way to do this

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That replaces the constructor with a fluent API, but you will still need to reuse the shared pieces.

Copy link
Copy Markdown
Member

@clintropolis clintropolis Jun 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess I was thinking that with that pattern we can just have a function that makes the SqlAggFunction based on the parts that change, instead of a bunch of separate classes. The other PR is removing many of the custom SqlAggFunction classes and replacing them with the builder, so a more complicated setup here with a base class seems like it would conflict even more with the goals there.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this to have a function that makes the SqlAggFunction and uses the aggregator builder.

Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
Expand All @@ -42,15 +45,33 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;

import javax.annotation.Nullable;
import java.util.List;

public abstract class BaseVarianceSqlAggregator implements SqlAggregator
{
// SQL names of the VARIANCE and STDDEV aggregates. The POP/SAMP variants below take their names
// from the corresponding SqlKind constants instead.
private static final String VARIANCE_NAME = "VARIANCE";
private static final String STDDEV_NAME = "STDDEV";

// One shared SqlAggFunction per supported SQL aggregate, each produced by buildSqlAvgAggFunction
// and differing only in the function name passed to it.
private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(VARIANCE_NAME);
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(STDDEV_NAME);
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());

@Nullable
@Override
public Aggregation toDruidAggregation(
Expand Down Expand Up @@ -104,12 +125,13 @@ public Aggregation toDruidAggregation(

if (inputType.isNumeric()) {
inputTypeName = StringUtils.toLowerCase(inputType.getType().name());
} else if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME;
} else {
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
}


if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
if (func.getName().equals(SqlKind.VAR_POP.name()) || func.getName().equals(SqlKind.STDDEV_POP.name())) {
estimator = "population";
} else {
estimator = "sample";
Expand All @@ -122,9 +144,9 @@ public Aggregation toDruidAggregation(
inputTypeName
);

if (func == SqlStdOperatorTable.STDDEV_POP
|| func == SqlStdOperatorTable.STDDEV_SAMP
|| func == SqlStdOperatorTable.STDDEV) {
if (func.getName().equals(STDDEV_NAME)
|| func.getName().equals(SqlKind.STDDEV_POP.name())
|| func.getName().equals(SqlKind.STDDEV_SAMP.name())) {
postAggregator = new StandardDeviationPostAggregator(
name,
aggregatorFactory.getName(),
Expand All @@ -137,21 +159,40 @@ public Aggregation toDruidAggregation(
);
}

/**
 * Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
 * but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
 *
 * <p>The function is assembled through {@code OperatorConversions.aggregatorBuilder(name)} with
 * {@link ReturnTypes#AVG_AGG_FUNCTION} as its return type inference and
 * {@link SqlFunctionCategory#NUMERIC} as its category.
 *
 * @param name the SQL name the function is registered under, e.g. {@code "VARIANCE"} or
 *             {@code SqlKind.VAR_POP.name()}
 * @return the aggregate function built for {@code name}
 */
private static SqlAggFunction buildSqlAvgAggFunction(String name)
{
return OperatorConversions
.aggregatorBuilder(name)
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
.operandTypeChecker(
// Accept either a numeric operand or a complex operand of the variance aggregator type.
OperandTypes.or(
OperandTypes.NUMERIC,
RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)
)
)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
}

public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_POP;
return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
}
}

public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_SAMP;
return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}

Expand All @@ -160,7 +201,7 @@ public static class VarianceSqlAggregator extends BaseVarianceSqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VARIANCE;
return VARIANCE_SQL_AGG_FUNC_INSTANCE;
}
}

Expand All @@ -169,7 +210,7 @@ public static class StdDevPopSqlAggregator extends BaseVarianceSqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_POP;
return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
}
}

Expand All @@ -178,7 +219,7 @@ public static class StdDevSampSqlAggregator extends BaseVarianceSqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_SAMP;
return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}

Expand All @@ -187,7 +228,7 @@ public static class StdDevSqlAggregator extends BaseVarianceSqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV;
return STDDEV_SQL_AGG_FUNC_INSTANCE;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceSerde;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
Expand All @@ -51,6 +52,7 @@
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
Expand Down Expand Up @@ -82,8 +84,10 @@ public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker(
final Injector injector
) throws IOException
{
ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde());

final QueryableIndex index =
IndexBuilder.create()
IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new DruidStatsModule().getJacksonModules()))
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
Expand All @@ -100,7 +104,8 @@ public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker(
)
.withMetrics(
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1")
new DoubleSumAggregatorFactory("m1", "m1"),
new VarianceAggregatorFactory("var1", "m1", null, null)
)
.withRollup(false)
.build()
Expand Down Expand Up @@ -624,6 +629,55 @@ public void testGroupByAggregatorDefaultValues()
);
}

/**
 * Checks that VARIANCE, VAR_POP, VAR_SAMP, STDDEV, STDDEV_POP and STDDEV_SAMP can all be applied
 * to the {@code var1} column, and that each call plans to a {@link VarianceAggregatorFactory}
 * whose input type is {@code "variance"}. The three standard-deviation calls additionally expect
 * a {@link StandardDeviationPostAggregator} over their aggregator.
 */
@Test
public void testVarianceAggAsInput()
{
  // Column order matches the SELECT list: VARIANCE, VAR_POP, VAR_SAMP, STDDEV, STDDEV_POP, STDDEV_SAMP.
  final List<Object[]> expectedRows = ImmutableList.of(
      new Object[]{
          "3.5",
          "2.9166666666666665",
          "3.5",
          "1.8708286933869707",
          "1.707825127659933",
          "1.8708286933869707"
      }
  );

  final String sql = "SELECT\n"
                     + "VARIANCE(var1),\n"
                     + "VAR_POP(var1),\n"
                     + "VAR_SAMP(var1),\n"
                     + "STDDEV(var1),\n"
                     + "STDDEV_POP(var1),\n"
                     + "STDDEV_SAMP(var1)\n"
                     + "FROM numfoo";

  testQuery(
      sql,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
                .granularity(Granularities.ALL)
                .aggregators(
                    ImmutableList.of(
                        new VarianceAggregatorFactory("a0:agg", "var1", "sample", "variance"),
                        new VarianceAggregatorFactory("a1:agg", "var1", "population", "variance"),
                        new VarianceAggregatorFactory("a2:agg", "var1", "sample", "variance"),
                        new VarianceAggregatorFactory("a3:agg", "var1", "sample", "variance"),
                        new VarianceAggregatorFactory("a4:agg", "var1", "population", "variance"),
                        new VarianceAggregatorFactory("a5:agg", "var1", "sample", "variance")
                    )
                )
                // Only the STDDEV* outputs (a3..a5) are derived through a post-aggregator.
                .postAggregators(
                    new StandardDeviationPostAggregator("a3", "a3:agg", "sample"),
                    new StandardDeviationPostAggregator("a4", "a4:agg", "population"),
                    new StandardDeviationPostAggregator("a5", "a5:agg", "sample")
                )
                .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      expectedRows
  );
}

@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,18 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeComparability;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.AbstractSqlType;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.ordering.StringComparator;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.column.ColumnHolder;
Expand Down Expand Up @@ -79,7 +86,9 @@ public static StringComparator getNaturalStringComparator(
{
Preconditions.checkNotNull(simpleExtraction, "simpleExtraction");
if (simpleExtraction.getExtractionFn() != null
|| rowSignature.getColumnType(simpleExtraction.getColumn()).map(type -> type.is(ValueType.STRING)).orElse(false)) {
|| rowSignature.getColumnType(simpleExtraction.getColumn())
.map(type -> type.is(ValueType.STRING))
.orElse(false)) {
return StringComparators.LEXICOGRAPHIC;
} else {
return StringComparators.NUMERIC;
Expand Down Expand Up @@ -164,7 +173,7 @@ public static RelDataType toRelDataType(
* Creates a {@link ComplexSqlType} using the supplied {@link RelDataTypeFactory} to ensure that the
* {@link ComplexSqlType} is interned. This is important because Calcite checks that the references are equal
* instead of the objects being equivalent.
*
* <p>
* This method uses {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean)}, which ensures that if the
* type factory is a {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
Expand All @@ -179,15 +188,15 @@ public static RelDataType makeComplexType(RelDataTypeFactory typeFactory, Column

/**
* Calcite {@link RelDataType} for Druid complex columns, to preserve complex type information.
*
* <p>
* If using with other operations of a {@link RelDataTypeFactory}, consider wrapping the creation of this type in
* {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean)} to ensure that if the type factory is a
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
*
* <p>
* If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a {@link RelDataTypeFactory} is available,
* consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType, boolean)}.
*
* <p>
* This type does not work well with {@link org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which
* will create new {@link RelDataType} using {@link SqlTypeName} during return type inference, so implementors of
* {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion} should implement the
Expand Down Expand Up @@ -235,4 +244,67 @@ public String asTypeString()
return columnType.asTypeString();
}
}

/**
 * Builds an operand type checker for a single Druid complex type.
 *
 * @param complexType the Druid complex column type the operand must have
 * @return a checker wrapping a {@link ComplexSqlType} created from {@link SqlTypeName#OTHER},
 *         {@code complexType} and the flag {@code true}
 */
public static ComplexSqlSingleOperandTypeChecker complexTypeChecker(ColumnType complexType)
{
  final ComplexSqlType expectedType = new ComplexSqlType(SqlTypeName.OTHER, complexType, true);
  return new ComplexSqlSingleOperandTypeChecker(expectedType);
}

/**
 * {@link SqlSingleOperandTypeChecker} that accepts exactly one operand, whose validator-derived type must be
 * {@link Object#equals equal} to the {@link ComplexSqlType} supplied at construction.
 *
 * <p>Both check methods honor {@code throwOnFailure}: when it is {@code true} and the check fails, a validation
 * signature error obtained from {@link SqlCallBinding#newValidationSignatureError()} is thrown instead of
 * returning {@code false}.
 */
public static final class ComplexSqlSingleOperandTypeChecker implements SqlSingleOperandTypeChecker
{
  private final ComplexSqlType type;

  /**
   * @param type the complex SQL type the single operand must match
   */
  public ComplexSqlSingleOperandTypeChecker(
      ComplexSqlType type
  )
  {
    this.type = type;
  }

  /**
   * Checks that the type derived for {@code operand} equals the expected complex type.
   *
   * @param callBinding    binding of the call being validated
   * @param operand        the operand to check
   * @param iFormalOperand index of the formal operand; not used by this checker
   * @param throwOnFailure whether to throw instead of returning {@code false} when the type does not match
   * @return {@code true} if the operand type matches, {@code false} if it does not and
   *         {@code throwOnFailure} is {@code false}
   */
  @Override
  public boolean checkSingleOperandType(
      SqlCallBinding callBinding,
      SqlNode operand,
      int iFormalOperand,
      boolean throwOnFailure
  )
  {
    final boolean matches = type.equals(callBinding.getValidator().deriveType(callBinding.getScope(), operand));
    if (!matches && throwOnFailure) {
      // Callers that ask for an exception must not receive a silent 'false'.
      throw callBinding.newValidationSignatureError();
    }
    return matches;
  }

  /**
   * Checks that the call has exactly one operand and that this operand has the expected complex type.
   *
   * @param callBinding    binding of the call being validated
   * @param throwOnFailure whether to throw instead of returning {@code false} when the check fails
   * @return {@code true} if the call is valid for this checker, {@code false} if it is not and
   *         {@code throwOnFailure} is {@code false}
   */
  @Override
  public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
  {
    if (callBinding.getOperandCount() != 1) {
      if (throwOnFailure) {
        throw callBinding.newValidationSignatureError();
      }
      return false;
    }
    return checkSingleOperandType(callBinding, callBinding.operand(0), 0, throwOnFailure);
  }

  /**
   * @return a range that allows exactly one operand
   */
  @Override
  public SqlOperandCountRange getOperandCountRange()
  {
    return SqlOperandCountRanges.of(1);
  }

  /**
   * @return the signature text {@code '<opName>'(<type>)} for the expected complex type
   */
  @Override
  public String getAllowedSignatures(SqlOperator op, String opName)
  {
    return StringUtils.format("'%s'(%s)", opName, type);
  }

  @Override
  public Consistency getConsistency()
  {
    return Consistency.NONE;
  }

  /**
   * @return always {@code false}; the single operand is never optional
   */
  @Override
  public boolean isOptional(int i)
  {
    return false;
  }
}
}