A Look at the AggregateFunction of the Flink Table API

Preface

This article examines AggregateFunction in the Flink Table API, using the flink-table 1.7.1 sources.

Example

/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}

/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return 0L;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }
    
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }
    
    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

// register function
BatchTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
  • WeightedAvg extends AggregateFunction and implements the createAccumulator, getValue, accumulate, retract, merge, and resetAccumulator methods.
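
To make the arithmetic of the example above concrete, here is a minimal sketch that mirrors the same accumulator logic in plain Java, without any Flink dependency. The names WeightedAvgDemo and Acc are hypothetical and exist only for this illustration; the class does not extend Flink's AggregateFunction.

```java
import java.util.Arrays;

// Hypothetical stand-alone mirror of WeightedAvg; it does not depend on Flink.
class WeightedAvgDemo {

    // Same fields as WeightedAvgAccum in the example above.
    static class Acc {
        long sum = 0;
        int count = 0;
    }

    // Adds one weighted value to the accumulator.
    static void accumulate(Acc acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }

    // Undoes a previous accumulate call with the same arguments.
    static void retract(Acc acc, long value, int weight) {
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    // Folds a group of partial accumulators into acc.
    static void merge(Acc acc, Iterable<Acc> others) {
        for (Acc other : others) {
            acc.sum += other.sum;
            acc.count += other.count;
        }
    }

    // Integer division, exactly as in the original getValue.
    static long getValue(Acc acc) {
        return acc.count == 0 ? 0L : acc.sum / acc.count;
    }

    public static void main(String[] args) {
        Acc a = new Acc();
        accumulate(a, 10, 1);              // sum = 10,  count = 1
        accumulate(a, 20, 3);              // sum = 70,  count = 4
        System.out.println(getValue(a));   // 70 / 4 = 17 (integer division)

        retract(a, 10, 1);                 // sum = 60,  count = 3
        Acc b = new Acc();
        accumulate(b, 40, 1);              // sum = 40,  count = 1
        merge(a, Arrays.asList(b));        // sum = 100, count = 4
        System.out.println(getValue(a));   // 100 / 4 = 25
    }
}
```

Note that retract is the exact inverse of accumulate, and merge only adds the fields of the partial accumulators, so the result does not depend on how the input was split into partial groups.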

AggregateFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala

abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
  /**
    * Creates and init the Accumulator for this [[AggregateFunction]].
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC

  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false

  /**
    * Returns the TypeInformation of the AggregateFunction's result.
    *
    * @return The TypeInformation of the AggregateFunction's result or null if the result type
    *         should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null

  /**
    * Returns the TypeInformation of the AggregateFunction's accumulator.
    *
    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null
}
  • AggregateFunction extends UserDefinedFunction and has two type parameters: T for the result value and ACC for the accumulator. It declares the createAccumulator, getValue, requiresOver, getResultType, and getAccumulatorType methods; subclasses must implement createAccumulator and getValue.
  • AggregateFunction also requires an accumulate method that is not declared on the class itself; each subclass must define and implement it. It takes the accumulator (ACC) followed by the user-defined input values and returns void. In addition, there are three optional methods, retract, merge, and resetAccumulator, which subclasses define and implement only when the operation needs them.
  • For DataStream bounded OVER aggregates, the retract method must be implemented; it takes the accumulator (ACC) followed by the user-defined input values and returns void. For DataStream session window grouping aggregates and DataSet grouping aggregates, the merge method must be implemented; it takes two parameters, ACC and java.lang.Iterable<ACC>, and returns void. For DataSet grouping aggregates, the resetAccumulator method must also be implemented; it takes an ACC parameter and returns void.
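
Because accumulate, retract, merge, and resetAccumulator are not declared on AggregateFunction, the compiler cannot check them through @Override; they can only be found by name. The sketch below shows one way to check, via reflection, which of these by-name methods a class defines. It is an illustration under assumptions, not Flink's actual lookup logic: ContractCheck, definedMethods, CountOnly, and CountWithRetract are hypothetical names, and the check looks only at the method name and a void return type.

```java
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper; not part of Flink.
class ContractCheck {

    // Methods that an aggregate function defines by name rather than by override.
    static final List<String> BY_NAME =
            Arrays.asList("accumulate", "retract", "merge", "resetAccumulator");

    // Returns the by-name methods that clazz exposes as public void methods.
    static Set<String> definedMethods(Class<?> clazz) {
        Set<String> found = new LinkedHashSet<>();
        for (String name : BY_NAME) {
            for (Method m : clazz.getMethods()) {
                if (m.getName().equals(name) && m.getReturnType() == void.class) {
                    found.add(name);
                    break;
                }
            }
        }
        return found;
    }

    // Hypothetical function that defines only the mandatory accumulate.
    static class CountOnly {
        public void accumulate(long[] acc, Object value) {
            acc[0]++;
        }
    }

    // Adds retract, as a bounded OVER aggregate would need.
    static class CountWithRetract extends CountOnly {
        public void retract(long[] acc, Object value) {
            acc[0]--;
        }
    }

    public static void main(String[] args) {
        System.out.println(definedMethods(CountOnly.class));        // [accumulate]
        System.out.println(definedMethods(CountWithRetract.class)); // [accumulate, retract]
    }
}
```

A check like this makes the practical consequence visible: a function that lacks merge or resetAccumulator still compiles, and the omission only shows up when the function is used in an operation that needs the missing method.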

DataSetPreAggFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala

class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
  extends AbstractRichFunction
  with GroupCombineFunction[Row, Row]
  with MapPartitionFunction[Row, Row]
  with Compiler[GeneratedAggregations]
  with Logging {

  private var output: Row = _
  private var accumulators: Row = _

  private var function: GeneratedAggregations = _

  override def open(config: Configuration) {
    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
                s"Code:\n$genAggregations.code")
    val clazz = compile(
      getRuntimeContext.getUserCodeClassLoader,
      genAggregations.name,
      genAggregations.code)
    LOG.debug("Instantiating AggregateHelper.")
    function = clazz.newInstance()

    output = function.createOutputRow()
    accumulators = function.createAccumulators()
  }

  override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {
    // reset accumulators
    function.resetAccumulator(accumulators)

    val iterator = values.iterator()

    var record: Row = null
    while (iterator.hasNext) {
      record = iterator.next()
      // accumulate
      function.accumulate(accumulators, record)
    }

    // set group keys and accumulators to output
    function.setAggregationResults(accumulators, output)
    function.setForwardedFields(record, output)

    out.collect(output)
  }

  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
    combine(values, out)
  }
}
  • The combine method of DataSetPreAggFunction first resets the accumulators and then calls function.accumulate(accumulators, record) for each input record, where accumulators is a Row holding the accumulator objects (a WeightedAvgAccum in the example) and record is an input Row. function is an instance of a generated class that extends GeneratedAggregations. Its code is carried by genAggregations, which is produced by the AggregateUtil.createDataSetAggregateFunctions method; the generated code in turn calls the accumulate method of WeightedAvg.
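
The reset, accumulate, emit sequence of combine can be sketched in plain Java. This is a simplified illustration under assumptions, not the Flink implementation: a row is modelled as Object[], the weighted-average logic is inlined instead of being called through generated code, and the class name PreAggSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, Flink-free sketch of the combine() flow; a "row" is modelled as Object[].
class PreAggSketch {

    // Plays the role of the accumulators row: [sum, count]. Reused across groups.
    private final long[] accumulators = new long[2];

    // Each input row is {key, value, weight}; all rows of one call belong to the same
    // group, and the group is assumed to be non-empty.
    Object[] combine(Iterable<Object[]> rows) {
        // reset accumulators, because the same instance served the previous group
        accumulators[0] = 0L;
        accumulators[1] = 0L;

        Object[] record = null;
        for (Object[] row : rows) {
            record = row;
            // accumulate
            accumulators[0] += (Long) row[1] * (Integer) row[2];
            accumulators[1] += (Integer) row[2];
        }

        // forward the grouping key and emit the partial result (the accumulator values)
        return new Object[] {record[0], accumulators[0], accumulators[1]};
    }

    public static void main(String[] args) {
        PreAggSketch pre = new PreAggSketch();
        List<Object[]> group = new ArrayList<>();
        group.add(new Object[] {"alice", 10L, 1});
        group.add(new Object[] {"alice", 20L, 3});
        Object[] out = pre.combine(group);
        System.out.println(out[0] + " " + out[1] + " " + out[2]); // alice 70 4
    }
}
```

The reset at the top matters because the accumulator storage is created once and reused for every group; without it, the second group would start from the totals of the first.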

GeneratedAggregations

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala

abstract class GeneratedAggregations extends Function {

  /**
    * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for initialization work. By default, this method does nothing.
    *
    * @param ctx The runtime context.
    */
  def open(ctx: RuntimeContext)

  /**
    * Sets the results of the aggregations (partial or final) to the output row.
    * Final results are computed with the aggregation function.
    * Partial results are the accumulators themselves.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param output       output results collected in a row
    */
  def setAggregationResults(accumulators: Row, output: Row)

  /**
    * Copies forwarded fields, such as grouping keys, from input row to output row.
    *
    * @param input        input values bundled in a row
    * @param output       output results collected in a row
    */
  def setForwardedFields(input: Row, output: Row)

  /**
    * Accumulates the input values to the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def accumulate(accumulators: Row, input: Row)

  /**
    * Retracts the input values from the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def retract(accumulators: Row, input: Row)

  /**
    * Initializes the accumulators and save them to a accumulators row.
    *
    * @return a row of accumulators which contains the aggregated results
    */
  def createAccumulators(): Row

  /**
    * Creates an output row object with the correct arity.
    *
    * @return an output row object with the correct arity.
    */
  def createOutputRow(): Row

  /**
    * Merges two rows of accumulators into one row.
    *
    * @param a First row of accumulators
    * @param b The other row of accumulators
    * @return A row with the merged accumulators of both input rows.
    */
  def mergeAccumulatorsPair(a: Row, b: Row): Row

  /**
    * Resets all the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    */
  def resetAccumulator(accumulators: Row)

  /**
    * Cleanup for the accumulators.
    */
  def cleanup()

  /**
    * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for clean up work. By default, this method does nothing.
    */
  def close()
}
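  • GeneratedAggregations declares row-based methods such as accumulate(accumulators: Row, input: Row), retract(accumulators: Row, input: Row), mergeAccumulatorsPair(a: Row, b: Row), and resetAccumulator(accumulators: Row).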

AggregateUtil

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala

object AggregateUtil {

  type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
  type JavaList[T] = java.util.List[T]

  //......

  /**
    * Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
    * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
    * respective output type are generated as well.
    */
  private[flink] def createDataSetAggregateFunctions(
      generator: AggregationCodeGenerator,
      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
      inputType: RelDataType,
      inputFieldTypeInfo: Seq[TypeInformation[_]],
      outputType: RelDataType,
      groupings: Array[Int],
      tableConfig: TableConfig): (
        Option[DataSetPreAggFunction],
        Option[TypeInformation[Row]],
        Either[DataSetAggFunction, DataSetFinalAggFunction]) = {

    val needRetract = false
    val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
      namedAggregates.map(_.getKey),
      inputType,
      needRetract,
      tableConfig)

    val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
      namedAggregates,
      groupings,
      inputType,
      outputType
    )

    val aggOutFields = aggOutMapping.map(_._1)

    if (doAllSupportPartialMerge(aggregates)) {

      // compute preaggregation type
      val preAggFieldTypes = gkeyOutMapping.map(_._2)
        .map(inputType.getFieldList.get(_).getType)
        .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
      val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)

      val genPreAggFunction = generator.generateAggregations(
        "DataSetAggregatePrepareMapHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggregates.indices.map(_ + groupings.length).toArray,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = true,
        groupings,
        None,
        groupings.length + aggregates.length,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      // compute mapping of forwarded grouping keys
      val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
        val gkeyOutFields = gkeyOutMapping.map(_._1)
        val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
        gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
        mapping
      } else {
        new Array[Int](0)
      }

      val genFinalAggFunction = generator.generateAggregations(
        "DataSetAggregateFinalHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        gkeyMapping,
        Some(aggregates.indices.map(_ + groupings.length).toArray),
        outputType.getFieldCount,
        needRetract,
        needMerge = true,
        needReset = true,
        None
      )

      (
        Some(new DataSetPreAggFunction(genPreAggFunction)),
        Some(preAggRowType),
        Right(new DataSetFinalAggFunction(genFinalAggFunction))
      )
    }
    else {
      val genFunction = generator.generateAggregations(
        "DataSetAggregateHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        groupings,
        None,
        outputType.getFieldCount,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      (
        None,
        None,
        Left(new DataSetAggFunction(genFunction))
      )
    }

  }

  //......
}
  • The createDataSetAggregateFunctions method of AggregateUtil mainly generates GeneratedAggregationsFunction instances and wraps them: when all aggregates support partial merging, it creates a DataSetPreAggFunction together with a DataSetFinalAggFunction; otherwise it creates a single DataSetAggFunction. The code is generated dynamically because the parameters of user-defined methods such as accumulate differ from function to function, whereas the Flink runtime always calls the fixed accumulate(accumulators: Row, input: Row) method declared by GeneratedAggregations, so the generated code acts as an adapter between the two. Inside accumulate(accumulators: Row, input: Row), the rows are unpacked into the arguments that the user-defined accumulate method expects, and then that method is called.
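
The adapter role of the generated code can be illustrated by writing one such adapter by hand. The sketch below is an assumption-laden simplification, not the code Flink generates: rows are modelled as Object[], and AdapterSketch, RowAggregations, UserWeightedAvg, and WeightedAvgAdapter are hypothetical names.

```java
// Hypothetical, Flink-free sketch; rows are modelled as Object[].
class AdapterSketch {

    // User-defined side: the function author chooses the typed parameters.
    static class Acc {
        long sum = 0;
        int count = 0;
    }

    static class UserWeightedAvg {
        public Acc createAccumulator() {
            return new Acc();
        }

        public void accumulate(Acc acc, long value, int weight) {
            acc.sum += value * weight;
            acc.count += weight;
        }
    }

    // Runtime side: one fixed, row-based signature, in the spirit of GeneratedAggregations.
    interface RowAggregations {
        Object[] createAccumulators();

        void accumulate(Object[] accumulators, Object[] input);
    }

    // Hand-written stand-in for generated code, for one function and one field layout.
    static class WeightedAvgAdapter implements RowAggregations {
        private final UserWeightedAvg function = new UserWeightedAvg();
        private final int valueField;   // index of the value column in the input row
        private final int weightField;  // index of the weight column in the input row

        WeightedAvgAdapter(int valueField, int weightField) {
            this.valueField = valueField;
            this.weightField = weightField;
        }

        @Override
        public Object[] createAccumulators() {
            return new Object[] {function.createAccumulator()};
        }

        @Override
        public void accumulate(Object[] accumulators, Object[] input) {
            // unpack the rows into the typed arguments the user method expects
            Acc acc = (Acc) accumulators[0];
            long value = (Long) input[valueField];
            int weight = (Integer) input[weightField];
            // call the user-defined method
            function.accumulate(acc, value, weight);
        }
    }

    public static void main(String[] args) {
        RowAggregations agg = new WeightedAvgAdapter(1, 2);
        Object[] accs = agg.createAccumulators();
        agg.accumulate(accs, new Object[] {"alice", 10L, 1});
        agg.accumulate(accs, new Object[] {"alice", 20L, 3});
        Acc acc = (Acc) accs[0];
        System.out.println(acc.sum + " " + acc.count); // 70 4
    }
}
```

The casts and field indices in the adapter are exactly the parts that depend on the user's method signature and on the query's input schema, which is why a single hand-written class cannot serve every function and the real adapter has to be generated per query.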

Summary

  • AggregateFunction extends UserDefinedFunction and has two type parameters: T for the result value and ACC for the accumulator. It declares createAccumulator, getValue, requiresOver, getResultType, and getAccumulatorType; subclasses must implement createAccumulator and getValue. An accumulate method is also mandatory but is not declared on the class, so each subclass defines it itself; it takes the accumulator (ACC) followed by the user-defined input values and returns void. Three further methods, retract, merge, and resetAccumulator, are optional and are defined only when the operation requires them: retract (ACC plus the user-defined input values, returns void) for DataStream bounded OVER aggregates; merge (ACC and a java.lang.Iterable<ACC>, returns void) for DataStream session window grouping aggregates and DataSet grouping aggregates; resetAccumulator (ACC, returns void) for DataSet grouping aggregates.
  • The combine method of DataSetPreAggFunction calls function.accumulate(accumulators, record), where accumulators is a Row holding the accumulator objects (a WeightedAvgAccum in the example) and record is an input Row. function is an instance of a generated class that extends GeneratedAggregations; its code is carried by genAggregations, which is produced by the AggregateUtil.createDataSetAggregateFunctions method, and the generated code calls the accumulate method of WeightedAvg. GeneratedAggregations declares row-based methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row).
  • The createDataSetAggregateFunctions method of AggregateUtil mainly generates GeneratedAggregationsFunction instances and then creates either a DataSetPreAggFunction with a DataSetFinalAggFunction, or a single DataSetAggFunction. The code is generated dynamically because the parameters of user-defined methods such as accumulate differ from function to function, whereas the Flink runtime always calls the fixed accumulate(accumulators: Row, input: Row) method declared by GeneratedAggregations, so the generated code acts as an adapter. Inside accumulate(accumulators: Row, input: Row), the rows are unpacked into the arguments that the user-defined accumulate method expects, and then that method is called.
