Java Spark: RDD creation, transformations, and actions

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
import java.text.SimpleDateFormat;
import java.util.*;
public class test02 {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("RDDtest").setMaster("local[1]");
JavaSparkContext sc = new JavaSparkContext(conf);
// 1 Creating RDDs
// 1.1 parallelize: call parallelize() on the SparkContext to turn an existing in-memory collection into an RDD; handy when learning Spark or writing quick tests
List<String> strs = Arrays.asList("hello world", "hello spark", "hello java");
JavaRDD<String> parallelize1 = sc.parallelize(strs);
List<String> collect = parallelize1.collect();
for (String str : collect) {
System.out.println(str);
}
// 1.2 textFile: call SparkContext.textFile() to create an RDD from data in external storage
JavaRDD<String> rdd = sc.textFile("D:\\02Code\\0901\\sd_demo\\src\\data\\testdata.txt");
JavaRDD<Integer> rdd2 = rdd.map(s -> s.length());
Integer total = rdd2.reduce((a, b) -> a + b);
System.out.println(total);
JavaRDD<Integer> rdd3 = rdd.map(s -> len(s));
Integer total3 = rdd3.reduce((a, b) -> a + b);
System.out.println(total3);
// 2 Basic transformations
// 2.1 filter(func): evaluates func as a predicate on every element of the source RDD and produces a new RDD containing the elements that pass
JavaRDD<String> filterrdd = rdd.filter(new Function<String, Boolean>() {
@Override
public Boolean call(String s) throws Exception {
return s.contains("java");
}
});
System.out.println("filterrdd: " + filterrdd.collect());
JavaRDD<String> filterrdd02 = rdd.filter(item -> item.contains("java"));
System.out.println("filterrdd02: " + filterrdd02.collect());
// 2.2 map(func): applies func to every element of the source RDD and produces a new RDD
// The map operator maps input partitions to output partitions one to one.
JavaRDD<String> maprdd = rdd.map(new Function<String, String>() {
@Override
public String call(String s) throws Exception {
String tmp = "";
if (s.contains("hello")) {
tmp = s.replace("hello", "bye");
} else {
tmp = s;
}
return tmp;
}
});
JavaRDD<String> maprdd02 = rdd.map(item -> {
String tmp = "";
if (item.contains("hello")) {
tmp = item.replace("hello", "bye");
} else {
tmp = item;
}
return tmp;
});
System.out.println("maprdd02: " + maprdd02.collect());
System.out.println("maprdd: " + maprdd.collect());
// 2.3 flatMap(func): flattens each element, i.e. a single input element can produce multiple output elements
// The flatMap function is applied to every element and returns an iterator over the elements generated for it.
JavaRDD<String> flatmaprdd = rdd.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterator<String> call(String s) throws Exception {
String[] items = s.split(" ");
return Arrays.asList(items).iterator();
}
});
JavaRDD<String> flatmaprdd2 = rdd.flatMap(line -> {
String[] items = line.split(" ");
return Arrays.asList(items).iterator();
});
System.out.println("flatmaprdd: " + flatmaprdd.collect());
System.out.println("flatmaprdd2: " + flatmaprdd2.collect());
// 2.4 distinct: removes duplicates; this involves a shuffle, so it is expensive.
JavaRDD<String> distinctrdd = flatmaprdd.distinct();
System.out.println("distinctrdd:" + distinctrdd.collect());
// 3 Set-style operations
// 3.1 union: concatenates two RDDs (duplicates are kept)
JavaRDD<String> unionrdd = distinctrdd.union(flatmaprdd);
System.out.println("unionrdd: " + unionrdd.collect());
// 3.2 intersection: returns the intersection of two RDDs, e.g. strRdd.intersection(strRdd2)
// 3.3 subtract: RDD1.subtract(RDD2) returns the elements that appear in RDD1 but not in RDD2, without deduplication
// 3.4 cartesian: RDD1.cartesian(RDD2) returns the Cartesian product of RDD1 and RDD2; this is very expensive, use with care
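// A minimal sketch of the three operations above, reusing RDDs defined earlier in this example:
JavaRDD<String> intersectionrdd = distinctrdd.intersection(flatmaprdd);
System.out.println("intersectionrdd: " + intersectionrdd.collect());
JavaRDD<String> subtractrdd = flatmaprdd.subtract(sc.parallelize(Arrays.asList("hello")));
System.out.println("subtractrdd: " + subtractrdd.collect());
JavaPairRDD<String, String> cartesianrdd = parallelize1.cartesian(distinctrdd);
System.out.println("cartesianrdd: " + cartesianrdd.collect());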
// 4 Creating key-value pairs
// 4.1 mapToPair: builds a pair RDD; here each word from flatmaprdd becomes the key and 1 the value
JavaPairRDD<String, Integer> mappairrdd = flatmaprdd.mapToPair(new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) throws Exception {
return new Tuple2<>(s, 1);
}
});
System.out.println("mappairrdd: " + mappairrdd.collect());
// 4.2 flatMapToPair: mapToPair is one-to-one (one element in, one pair out), while flatMapToPair can emit several pairs per element; it is equivalent to a flatMap followed by a mapToPair. A minimal sketch:
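// Split each line of the text file into words and emit one (word, 1) pair per word.
JavaPairRDD<String, Integer> flatmappairrdd = rdd.flatMapToPair(line -> {
    List<Tuple2<String, Integer>> pairs = new ArrayList<>();
    for (String word : line.split(" ")) {
        pairs.add(new Tuple2<>(word, 1));
    }
    return pairs.iterator();
});
System.out.println("flatmappairrdd: " + flatmappairrdd.collect());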
// 5 Aggregations on key-value pairs
// 5.1 combineByKey: the most general per-key aggregation; it takes a createCombiner, a mergeValue and a mergeCombiners function. A minimal sketch:
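// Build a (sum, count) pair per key, from which a per-key average could be derived.
JavaPairRDD<String, Tuple2<Integer, Integer>> combinerdd = mappairrdd.combineByKey(
        v -> new Tuple2<>(v, 1),                                   // createCombiner: first value seen for a key
        (acc, v) -> new Tuple2<>(acc._1() + v, acc._2() + 1),      // mergeValue: fold another value into the combiner
        (a, b) -> new Tuple2<>(a._1() + b._1(), a._2() + b._2())); // mergeCombiners: merge partial results across partitions
System.out.println("combinerdd: " + combinerdd.collect());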
// 5.2 reduceByKey: takes a function and reduces the values of each key with it, similar to Scala's reduce
JavaPairRDD<String, Integer> stringIntegerJavaPairRDD = mappairrdd.reduceByKey(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception {
return integer + integer2;
}
});
JavaPairRDD<String, Integer> stringIntegerJavaPairRDD02 = mappairrdd.reduceByKey((a, b) -> (a + b));
Map<String, Integer> stringIntegerMap = stringIntegerJavaPairRDD.collectAsMap();
System.out.println("stringIntegerMap: " + stringIntegerMap);
System.out.println("stringIntegerJavaPairRDD02: " + stringIntegerJavaPairRDD02.collectAsMap());
// 5.3 foldByKey: like reduceByKey, but seeded with an initial (zero) value, for example:
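// Word count again, but with an explicit zero value of 0.
JavaPairRDD<String, Integer> foldrdd = mappairrdd.foldByKey(0, (a, b) -> a + b);
System.out.println("foldrdd: " + foldrdd.collectAsMap());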
// 6 Sorting
// 6.1 sortByKey sorts a pair RDD by key; its first argument can be true (ascending) or false (descending) and defaults to true.
JavaPairRDD<String, Integer> sortbykeyrdd = stringIntegerJavaPairRDD.sortByKey();
System.out.println("sortbykeyrdd: " + sortbykeyrdd.collect());
// 7 Grouping key-value pairs
// 7.1 groupByKey groups an RDD of (key, value) pairs by key, producing an RDD of (key, Iterable[value])
JavaRDD<Tuple2<String, Integer>> scoreRDD = sc.parallelize(Arrays.asList(
new Tuple2<String, Integer>("zhangsan", 69),
new Tuple2<>("zhangsan", 89),
new Tuple2<>("lisi", 78),
new Tuple2<>("lisi", 90),
new Tuple2<>("wangwu", 66)
));
JavaPairRDD<String, Integer> scoreMapRDD = JavaPairRDD.fromJavaRDD(scoreRDD);
JavaPairRDD<String, Iterable<Integer>> scoreByKeyRdd = scoreMapRDD.groupByKey();
List<Tuple2<String, Iterable<Integer>>> collectList = scoreByKeyRdd.collect();
System.out.println("scoreByKeyRdd: ");
for (Tuple2<String, Iterable<Integer>> tp : collectList) {
System.out.println(tp);
}
// 8 Joins and related pair-RDD operations
// 8.1 subtractByKey: like subtract, but removes the pairs whose key also appears in the other RDD, for example:
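// A small made-up pair RDD (excludeRDD) is used here just to illustrate subtractByKey.
JavaPairRDD<String, Integer> excludeRDD = JavaPairRDD.fromJavaRDD(sc.parallelize(Arrays.asList(
        new Tuple2<String, Integer>("wangwu", 0))));
System.out.println("subtractByKey: " + scoreMapRDD.subtractByKey(excludeRDD).collect());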
// 8.2 join: RDD1.join(RDD2) joins the two RDDs on equal keys, like a SQL inner join
JavaPairRDD<String, Integer> scoreMapRDD02 = JavaPairRDD.fromJavaRDD(sc.parallelize(Arrays.asList(
new Tuple2<String, Integer>("zhangsan", 69),
new Tuple2<>("zhangsan", 89),
new Tuple2<>("lisi", 78),
new Tuple2<>("lisi", 90),
new Tuple2<>("wangwu", 66)
)));
JavaPairRDD<String, Tuple2<Integer, Integer>> joinRDD = scoreMapRDD02.join(scoreMapRDD);
System.out.println("joinRDD: " + joinRDD.collect());
// 8.3 leftOuterJoin / rightOuterJoin / fullOuterJoin
List<String> list = Arrays.asList("a", "b", "c", "d", "e");
List<String> list2 = Arrays.asList("a", "b", "c", "f", "h");
JavaRDD<String> parallelize = sc.parallelize(list, 2);
JavaRDD<String> parallelize2 = sc.parallelize(list2, 2);
JavaPairRDD<String, Integer> javaPairRDD = parallelize.mapToPair(new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) throws Exception {
return new Tuple2<>(s, 1);
}
});
JavaPairRDD<String, Integer> javaPairRDD1 = parallelize2.mapToPair(new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) throws Exception {
return new Tuple2<>(s, 2);
}
});
System.out.println("join: " + javaPairRDD.join(javaPairRDD1).collect());
System.out.println("leftOuterJoin: " + javaPairRDD.leftOuterJoin(javaPairRDD1).collect());
System.out.println("rightOuterJoin: " + javaPairRDD.rightOuterJoin(javaPairRDD1).collect());
System.out.println("fullOuterJoin: " + javaPairRDD.fullOuterJoin(javaPairRDD1).collect());
// 9 Basic actions
// 9.1 rdd.first() returns the first element
System.out.println(rdd.first());
// 9.2 rdd.take(n) returns the first n elements
System.out.println(rdd.take(2));
// 9.3 rdd.collect() returns all elements of the RDD, for example:
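System.out.println("rdd.collect: " + rdd.collect());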
// 9.4 rdd.count() returns the number of elements in the RDD
System.out.println("rdd.count: " + rdd.count());
// 9.5 countByValue returns how often each element occurs in the RDD, as a map {(value1, count1), (value2, count2), ...}
System.out.println("countByValue: " + flatmaprdd.countByValue());
// 9.6 reduce aggregates all elements of the RDD in parallel, similar to reduce on a Scala collection
JavaRDD<Integer> arrrdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
System.out.println("reduce: " + arrrdd.reduce((x, y) -> x + y));
// 9.7 top returns the first n elements in descending order (or by a supplied ordering)
System.out.println(arrrdd.top(3));
// 9.8 takeOrdered sorts the RDD in ascending order and returns the first n elements; a custom comparator can also be supplied (not covered here); it is effectively the opposite of top
System.out.println(arrrdd.takeOrdered(3));
// 9.9 foreach applies the given function to every element of the RDD
arrrdd.foreach(item -> System.out.println(item));
// 10 Actions on pair RDDs
// 10.1 countByKey returns the number of elements for each key
List<Tuple2<Integer, String>> listTuple = Arrays.asList(new Tuple2<>(1, "hello"), new Tuple2<>(2, "cn"), new Tuple2<>(1, "spark"), new Tuple2<>(2, "kgc"), new Tuple2<>(3, "Trump"));
JavaRDD<Tuple2<Integer, String>> listTuplerdd = sc.parallelize(listTuple);
JavaPairRDD<Integer, String> pairRDD = JavaPairRDD.fromJavaRDD(listTuplerdd);
Map<Integer, Long> result = pairRDD.countByKey();
System.out.println(result);
// 10.2 collectAsMap turns a pair (key-value) RDD into a java.util.Map, for example:
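System.out.println("collectAsMap: " + pairRDD.collectAsMap());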
// 11 Save actions
// 11.1 saveAsTextFile stores the RDD as text files in a file system; the target directory must not already exist
rdd.saveAsTextFile("D:\\02Code\\0901\\sd_demo\\src\\data\\save");
sc.stop();
}
public static int len(String s) {
int str_length;
str_length = s.length();
return str_length;
}
}
SparkSQL: implementing UDFs and UDAFs in Scala

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType}
object UDAFDemo02 {
/**
 * Custom aggregate function (UDAF)
 * Purpose: compute the average salary
 */
class MyUDAF extends UserDefinedAggregateFunction {
/**
 * Defines the input data type
 *
 * @return
 */
override def inputSchema: StructType = StructType(
StructField("salary", DoubleType, true) :: Nil
)
/**
 * Defines the aggregation buffer fields
 * Field 1: total, the running sum of salaries
 * Field 2: count, the number of rows seen
 *
 * @return
 */
override def bufferSchema: StructType = StructType(
StructField("total", DoubleType, true) ::
StructField("count", IntegerType, true) :: Nil
)
/**
 * Defines the output data type
 *
 * @return
 */
// override def dataType: DataType = StructType(
// StructField("avg_salary", DoubleType, true) :: Nil
// )
override def dataType: DataType = {
DoubleType
}
// deterministic: the same input always produces the same result
override def deterministic: Boolean = true
/**
 * Initializes the buffer fields
 *
 * @param buffer
 */
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0)
}
/**
 * Updates the buffer with one input row; this is the local, per-partition step
 *
 * @param buffer
 * @param input
 */
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// read the values accumulated so far
val lastTotal = buffer.getDouble(0)
val lastCount = buffer.getInt(1)
// add the current row's salary and increment the count
buffer.update(0, lastTotal + input.getDouble(0))
buffer.update(1, lastCount + 1)
}
/**
 * Global step: merges the buffers of two partitions
 * (e.g. partition1 and partition2 need to be combined)
 *
 * @param buffer1
 * @param buffer2
 */
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// values from the current partition
val total1 = buffer1.getDouble(0)
val count1 = buffer1.getInt(1)
// values from the other partition
val total2 = buffer2.getDouble(0)
val count2 = buffer2.getInt(1)
buffer1.update(0, total1 + total2)
buffer1.update(1, count1 + count2)
}
/**
 * Computes the final result from the buffer
 *
 * @param buffer
 * @return
 */
override def evaluate(buffer: Row): Any = {
val total = buffer.getDouble(0)
val count = buffer.getInt(1)
total / count
}
}
def main(args: Array[String]): Unit = {
// 1. Create the SparkSession
val spark: SparkSession = SparkSession.builder().appName("SparkSQL").master("local[*]").getOrCreate()
val sc: SparkContext = spark.sparkContext
sc.setLogLevel("ERROR")
// 2. Read the input file
val employeeDF: DataFrame = spark.read.json("D:\\02Code\\0901\\sd_demo\\src\\data\\udaf.json")
// 3. Register a temporary view
employeeDF.createOrReplaceTempView("t_employee")
// 4. Register the UDAF
spark.udf.register("myavg", new MyUDAF)
// 5. Use the custom UDAF to compute the average salary
spark.sql("select myavg(salary) from t_employee").show()
// 6. Use the built-in avg function to cross-check the result
spark.sql("select avg(salary) from t_employee").show()
// Stop the SparkSession
spark.stop()
}
}