Hive UDAF和UDTF实现group by后获取top值

先自定义一个UDAF。由于UDAF是多行输入、一行输出的聚合函数,所以把每组的结果拼接成一个字符串输出,代码如下:

public class Top4GroupBy extends UDAF {

//Aggregation buffer: holds the data accumulated for one group
    public static class State {
        // Occurrence count per distinct value; stays null until increment() first creates it
        private Map<Text, IntWritable> counts;
        // Maximum number of entries terminate() writes to the output string
        private int limit;

}

/**
     * 累加数据,判断map的key中是否存在该字符串,如果存在累加,不存在放入map中
     * @param s
     * @param o
     * @param i
     */
    private static void increment(State s, Text o, int i) {
        if (s.counts == null) {
            s.counts = new HashMap<Text, IntWritable>();
        }
        IntWritable count = s.counts.get(o);
        if (count == null) {
            Text key = new Text();
            key.set(o);
            s.counts.put(key, new IntWritable(i));
        } else {
            count.set(count.get() + i);
        }

}

public static class Top4GroupByEvaluator implements UDAFEvaluator {

private final State state;

        /** Creates an evaluator that owns a fresh, empty aggregation state. */
        public Top4GroupByEvaluator() {
            this.state = new State();
        }

@Override
        public void init() {
            if (state.counts != null) {
                state.counts.clear();
            }
            if (state.limit == 0) {
                state.limit = 100;
            }
        }

public boolean iterate(Text value, IntWritable limits) {
            if (value == null || limits == null) {
                return false;
            } else {
                state.limit = limits.get();
                increment(state, value, 1);
            }
            return true;
        }

public State terminatePartial() {
            return state;
        }

public boolean merge(State other) {
            if (state == null || other == null) {
                return false;
            }
            state.limit = other.limit;
            for (Map.Entry<Text, IntWritable> e : other.counts.entrySet()) {
                increment(state, e.getKey(), e.getValue().get());
            }
            return true;
        }

public Text terminate() {
            if (state == null || state.counts.size() == 0) {
                return null;
            }
            Map<Text, IntWritable> it = sortByValue(state.counts, true);
            StringBuffer str = new StringBuffer();
            int i = 0;
            for (Map.Entry<Text, IntWritable> e : it.entrySet()) {
                ++i;
                if (i > state.limit) {//只输出传入条数的结果,并拼成字符串
                    break;
                }
                str.append(e.getKey().toString()).append("$@").append(e.getValue().get()).append("$*");
            }
            return new Text(str.toString());
        }

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:http://www.heiqu.com/832ac4b595ddda530092bbf816c87ebf.html