To compute a weighted average, the accumulator needs to store the weighted sum and the weighted count of all data accumulated so far. In the example below, a WeightedAvgAccum class is defined as the accumulator. Although retract(), merge(), and resetAccumulator() are not required for many types of aggregations, they are shown here as well.
import java.util.Iterator;

import org.apache.flink.table.functions.AggregateFunction;

/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}

/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }

    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }

    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

// register function
StreamTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");

4. Best practices for UDFs

4.1 The Table API and SQL code generators internally try to work with primitive values as much as possible. A user-defined function can introduce significant overhead through object creation, casting, and (un)boxing. It is therefore strongly recommended to declare parameter and result types as primitive types rather than their boxing classes. Types.DATE and Types.TIME can be represented as int, and Types.TIMESTAMP can be represented as long.
It is also recommended to implement user-defined functions in Java rather than Scala, because Scala types may not be compatible with Flink's type extractor.
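To illustrate this recommendation, here is a minimal sketch (the class name PrimitiveDiff and its eval() signature are invented for this example, not taken from the official docs): the parameters and the result are declared as primitive long values, which matches the suggested long representation of Types.TIMESTAMP and lets the generated code avoid boxing.

import org.apache.flink.table.functions.ScalarFunction;

// Hypothetical scalar function: difference between two timestamps given in
// milliseconds. Declaring the parameters and the result as primitive long
// instead of Long or java.sql.Timestamp avoids per-record object creation
// and (un)boxing in the generated code.
public class PrimitiveDiff extends ScalarFunction {

    public long eval(long endMillis, long startMillis) {
        return endMillis - startMillis;
    }
}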
4.2 Integrating UDFs with the runtime

Sometimes a UDF needs to obtain global runtime information, or needs to do some setup and cleanup work around its actual logic, for example opening and closing a database connection. UDFs provide open() and close() methods that can be overridden; they serve a purpose similar to the corresponding methods of RichFunction in the DataSet and DataStream APIs.
The open() method is called once before the evaluation method is invoked for the first time; close() is called after the last call to the evaluation method. open() provides a FunctionContext that describes the context in which the UDF is executed, such as the metric group, distributed cache files, and global job parameters.
The following information can be obtained by calling the corresponding methods of FunctionContext:
getMetricGroup(): the metric group for this parallel subtask;
getCachedFile(name): the local temporary copy of a distributed cache file (see the sketch after this list);
getJobParameter(name, defaultValue): the global job parameter value associated with the given key.
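As a small sketch of getCachedFile(name) (the cache name "id-mapping", the file format, and the IdToName class are assumptions made up for this example; the file would have to be registered on the environment beforehand, e.g. via env.registerCachedFile(path, "id-mapping")), the UDF loads its local copy of the file once per subtask in open():

import java.io.File;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

// Hypothetical example: the job registers a lookup file under the name
// "id-mapping", and this scalar function loads it once per subtask in open().
public class IdToName extends ScalarFunction {

    private final Map<String, String> mapping = new HashMap<>();

    @Override
    public void open(FunctionContext context) throws Exception {
        super.open(context);
        // local temporary copy of the distributed cache file
        File file = context.getCachedFile("id-mapping");
        List<String> lines = Files.readAllLines(file.toPath());
        for (String line : lines) {
            String[] parts = line.split(",", 2);
            mapping.put(parts[0], parts[1]);
        }
    }

    public String eval(String id) {
        return mapping.getOrDefault(id, "unknown");
    }
}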
The following example uses FunctionContext in a scalar function to obtain the global job parameters. It reads the Redis configuration from the job parameters, establishes a Redis connection in open(), and then interacts with Redis during evaluation.
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

import redis.clients.jedis.Jedis;

public class HashCode extends ScalarFunction {

    private int factor = 12;
    Jedis jedis = null;

    public HashCode() {
        super();
    }

    public HashCode(int factor) {
        this.factor = factor;
    }

    @Override
    public void open(FunctionContext context) throws Exception {
        super.open(context);
        // read the Redis connection settings from the global job parameters
        String redisHost = context.getJobParameter("redis.host", "localhost");
        int redisPort = Integer.valueOf(context.getJobParameter("redis.port", "6379"));
        jedis = new Jedis(redisHost, redisPort);
    }

    @Override
    public void close() throws Exception {
        super.close();
        jedis.close();
    }

    public int eval(int s) {
        s = s % 3;
        if (s == 2) {
            return Integer.valueOf(jedis.get(String.valueOf(s)));
        } else {
            return 0;
        }
    }
}

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);

// set the global job parameters that open() reads via getJobParameter()
Map<String, String> hashmap = new HashMap<>();
hashmap.put("redis.host", "localhost");
hashmap.put("redis.port", "6379");
ParameterTool parameter = ParameterTool.fromMap(hashmap);
env.getConfig().setGlobalJobParameters(parameter);

// register the function
tableEnv.registerFunction("hashCode", new HashCode());

// use the function in Java Table API
myTable.select("string, string.hashCode(), hashCode(string)");

// use the function in SQL
tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable");
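As a follow-up sketch (the class HashCodeWithMetrics and the counter name redisLookups are not part of the original example), the same open() method could also use getMetricGroup() to register a counter, for instance to track how many Redis lookups the function performs:

import org.apache.flink.metrics.Counter;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;

import redis.clients.jedis.Jedis;

// Hypothetical variant of HashCode that also registers a metric in open().
public class HashCodeWithMetrics extends ScalarFunction {

    private transient Jedis jedis;
    private transient Counter redisLookups;

    @Override
    public void open(FunctionContext context) throws Exception {
        super.open(context);
        jedis = new Jedis(
                context.getJobParameter("redis.host", "localhost"),
                Integer.valueOf(context.getJobParameter("redis.port", "6379")));
        // counter exposed in the metric group of this parallel subtask
        redisLookups = context.getMetricGroup().counter("redisLookups");
    }

    @Override
    public void close() throws Exception {
        super.close();
        jedis.close();
    }

    public int eval(int s) {
        s = s % 3;
        if (s == 2) {
            redisLookups.inc();
            return Integer.valueOf(jedis.get(String.valueOf(s)));
        }
        return 0;
    }
}

Because open() and close() run once per parallel subtask, this pattern keeps one Redis connection per subtask instead of opening one per record.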