Spark编写UDAF自定义函数(JAVA)


maven:

<!-- spark -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.10</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.10</artifactId>
<version>1.6.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.10</artifactId>
<version>1.6.0</version>
</dependency>
<!-- google工具类 -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>18.0</version>
</dependency>

public class StringCount extends UserDefinedAggregateFunction {
/**
 * Schema of the function's input: a single nullable string column named "str".
 * @return input schema
 */
@Override
public StructType inputSchema() {
List<StructField> fields = Lists.newArrayList();
fields.add(DataTypes.createStructField("str", DataTypes.StringType,true));
return DataTypes.createStructType(fields);
}

/**
 * Schema of the intermediate aggregation buffer: a single int counter.
 * @return buffer schema
 */
@Override
public StructType bufferSchema() {
List<StructField> fields = Lists.newArrayList();
fields.add(DataTypes.createStructField("count", DataTypes.IntegerType,true));
return DataTypes.createStructType(fields);
}

/**
 * Type of the final aggregated value.
 * @return IntegerType
 */
@Override
public DataType dataType() {
return DataTypes.IntegerType;
}

/**
 * Consistency flag: when true, identical input always produces an
 * identical result, which lets Spark cache/reuse partial results.
 * @return true — counting is deterministic
 */
@Override
public boolean deterministic() {
return true;
}

/**
 * Sets the buffer's "zero value". The zero value must be a neutral
 * element for merge(): merging two freshly initialized buffers must
 * yield the zero value again (here 0 + 0 == 0), otherwise partial
 * aggregation would skew the result.
 * @param buffer buffer to initialize
 */
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0,0);
}

/**
 * Folds one input row into the buffer (similar to combineByKey).
 * Fix vs. original: the input column is declared nullable, so null
 * values are skipped instead of being counted; the counter is read
 * with getInt(0) instead of an NPE-prone toString()/parse round trip.
 * @param buffer running partial count
 * @param input  one input row
 */
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
buffer.update(0, buffer.getInt(0) + 1);
}
}

/**
 * Merges two partial buffers when combining per-partition results
 * (similar to reduceByKey). The method has no return value: buffer2
 * is folded into buffer1 in place.
 * @param buffer1 target buffer, updated in place
 * @param buffer2 partial result to fold in
 */
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
}

/**
 * Computes and returns the final aggregated result.
 * @param buffer final merged buffer
 * @return the accumulated count as an Integer
 */
@Override
public Object evaluate(Row buffer) {
return buffer.getInt(0);
}
}
public class UDAF {
/**
 * Demo driver: builds a one-column DataFrame of names, registers the
 * StringCount UDAF under the name "countString", and prints one count
 * per distinct name.
 */
public static void main(String[] args) {
SparkConf sparkConf = new SparkConf().setAppName("UDAF").setMaster("local");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
SQLContext sqlCtx = new SQLContext(ctx);

List<String> names = Arrays.asList("xiaoming","xiaoming", "feifei","feifei","feifei", "katong");
// Parallelize into 3 partitions and wrap each name in a Row in one chain.
JavaRDD<Row> rowRDD = ctx.parallelize(names, 3).map(new Function<String, Row>() {
public Row call(String value) throws Exception {
return RowFactory.create(value);
}
});

// Schema: a single nullable string column called "name".
List<StructField> schemaFields = Lists.newArrayList();
schemaFields.add(DataTypes.createStructField("name", DataTypes.StringType,true));
StructType schema = DataTypes.createStructType(schemaFields);

DataFrame df = sqlCtx.createDataFrame(rowRDD, schema);
// Expose the DataFrame to SQL as temp table "names".
df.registerTempTable("names");
// Register the custom aggregate, then run it grouped by name.
sqlCtx.udf().register("countString",new StringCount());
List<Row> resultRows = sqlCtx
.sql("select name,countString(name) from names group by name")
.javaRDD()
.collect();
for (Row resultRow : resultRows) {
System.out.println(resultRow);
}
ctx.close();
}
}
执行结果:

[feifei,3]
[xiaoming,2]
[katong,1]

智能推荐

注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
© 2014-2019 ITdaan.com 粤ICP备14056181号  

赞助商广告