UDF & UDAF

UDF

UDF (User Defined Function): a user-defined function. Both the input and the output are a single record, similar to the ordinary math or string functions in Spark SQL. From an implementation point of view, it is just an ordinary Scala function.

A UDF written in Scala is almost identical to an ordinary Scala function; the only extra step is registering it with the SparkSession (via spark.udf.register).

def len(bookTitle: String): Int = bookTitle.length

spark.udf.register("len", len _)

val booksWithLongTitle = spark.sql("select title, author from books where len(title) > 10")

A registered UDF can be used in the field list of a SQL statement, or as part of a where, group by, or having clause, as sketched below.
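
For example, the registered len UDF can also drive grouping and group filtering. A minimal sketch against the same books table (the aliases titleLength and cnt, and the variable titlesByLength, are only illustrative):

val titlesByLength = spark.sql(
  "select len(title) as titleLength, count(*) as cnt " +
  "from books group by len(title) having len(title) > 5")
titlesByLength.show()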

A UDF can also be invoked with a constant instead of a column name. Slightly modify the previous function so that the length threshold 10 is passed in as a parameter:

def lengthLongerThan(bookTitle: String, length: Int): Boolean = bookTitle.length > length

spark.udf.register("longLength", lengthLongerThan _)

val booksWithLongTitle = spark.sql("select title, author from books where longLength(title, 10)")

When using the DataFrame API, pass the UDF in as a string expression:

val booksWithLongTitle = dataFrame.filter("longLength(title, 10)")

The DataFrame API can also take Column objects. A string wrapped with the $ symbol becomes a Column; $ is an implicit conversion defined in the SQLImplicits object. In that case the UDF is defined differently: instead of registering a plain Scala function, wrap the function with the udf method defined in org.apache.spark.sql.functions. A UDF defined this way needs no registration:

import org.apache.spark.sql.functions._

val longLength = udf((bookTitle: String, length: Int) => bookTitle.length > length)

import spark.implicits._

val booksWithLongTitle = dataFrame.filter(longLength($"title", lit(10)))

Complete example:

package cn.lagou.sparksql

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{Row, SparkSession}

object UDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getCanonicalName)
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    import spark.implicits._
    import spark.sql
    val data = List(("scala", "author1"), ("spark", "author2"),
      ("hadoop", "author3"), ("hive", "author4"),
      ("storm", "author5"), ("kafka", "author6"))
    val df = data.toDF("title", "author")
    df.createTempView("books")

    // Define a Scala function and register it as a UDF
    def len1(str: String): Int = str.length
    spark.udf.register("len1", len1 _)

    // Use the UDF in select and where clauses
    // sql("select title, author, len1(title) as titleLength from books").show
    // sql("select title, author from books where len1(title)>5").show

    // DSL
    // df.filter("len1(title)>5").show

    // To use $"..." (a Column) in the DSL, wrap the function with the udf method.
    // A UDF defined this way does not need to be registered.
    import org.apache.spark.sql.functions._
    val len2: UserDefinedFunction = udf(len1 _)
    df.select($"title", $"author", len2($"title")).show
    df.filter(len2($"title") > 5).show

    // Without a UDF
    df.map { case Row(title: String, author: String) => (title, author, title.length) }.show

    spark.stop()
  }
}

UDAF

UDAF (User Defined Aggregate Function): a user-defined aggregate function. It operates on a collection of rows and allows custom logic on top of the aggregation (many rows in, one row out), similar to sum or avg used after group by.

Sample data:

id, name, sales, discount, state, saleDate
1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-01"
2, "Acme Widgets", 2000.00, 500.00, "CA", "2019-02-01"
3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"
4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"
5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"

The desired final result is:
(2020 total - 2019 total) / 2019 total = (6000 - 3000) / 3000 = 1

Run the following SQL to obtain the final result:
select userFunc(sales, saleDate) from table1;
That is, the calculation logic is implemented inside userFunc.

An ordinary UDF cannot perform aggregation. For example, computing the year-over-year sales change requires summing the sales of the current year and of the previous year separately before applying the formula. This calls for a UDAF. Spark defines the parent class UserDefinedAggregateFunction for all such (untyped) UDAFs; extending it requires implementing the following abstract members:


  • inputSchema defines the schema of the input columns (the DataFrame columns fed into the UDAF)

  • bufferSchema defines the schema of the intermediate results produced during aggregation

  • dataType declares the return type of the UDAF

  • deterministic is a boolean flag indicating whether the UDAF always produces the same result for a given set of inputs

  • initialize initializes the intermediate aggregation buffer

  • update is where the core computation of the UDAF happens; buffer fields are addressed by index (starting at 0, matching the two fields of bufferSchema), and the second parameter, input: Row, is not a row of the DataFrame itself but the row projected by inputSchema

  • merge combines two aggregation buffers and stores the result back into the MutableAggregationBuffer

  • evaluate computes the final result from the aggregation buffer

UDAF - type-unsafe (untyped)

package cn.lagou.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

class TypeUnsafeUDAF extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = new StructType().add("sales", DoubleType).add("saleDate", StringType)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = new StructType().add("year2019", DoubleType).add("year2020", DoubleType)

  // Type of the final result
  override def dataType: DataType = DoubleType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the aggregation buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0.0)
  }

  // Merge input rows into the buffer within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // sales amount and sale year
    val sales = input.getAs[Double](0)
    val saleYear = input.getAs[String](1).take(4)

    saleYear match {
      case "2019" => buffer(0) = buffer.getAs[Double](0) + sales
      case "2020" => buffer(1) = buffer.getAs[Double](1) + sales
      case _ => println("Error!")
    }
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Double = {
    // (2020 total - 2019 total) / 2019 total
    if (math.abs(buffer.getAs[Double](0)) < 0.000000001) 0.0
    else (buffer.getAs[Double](1) - buffer.getAs[Double](0)) / buffer.getAs[Double](0)
  }
}

object TypeUnsafeUDAFTest {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession.builder()
      .appName(s"${this.getClass.getCanonicalName}")
      .master("local[*]")
      .getOrCreate()

    val sales = Seq(
      (1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-02"),
      (2, "Acme Widgets", 2000.00, 500.00, "CA", "2019-02-01"),
      (3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"),
      (4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"),
      (5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"))
    val salesDF = spark.createDataFrame(sales).toDF("id", "name", "sales", "discount", "state", "saleDate")
    salesDF.createTempView("sales")
    val userFunc = new TypeUnsafeUDAF
    spark.udf.register("userFunc", userFunc)
    spark.sql("select userFunc(sales, saleDate) as rate from sales").show()

    spark.stop()
  }
}

UDAF - type-safe

The type-safe version extends Aggregator[Sales, SalesBuffer, Double], supplies Encoders for the buffer and output types, and is applied to a Dataset as a TypedColumn obtained via toColumn:


package cn.lagou.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, TypedColumn}

case class Sales(id: Int, name1: String, sales: Double, discount: Double, name2: String, stime: String)
case class SalesBuffer(var sales2019: Double, var sales2020: Double)

class TypeSafeUDAF extends Aggregator[Sales, SalesBuffer, Double] {
  // Initial value of the buffer
  override def zero: SalesBuffer = SalesBuffer(0.0, 0.0)

  // Merge input rows into the buffer within a partition
  override def reduce(buffer: SalesBuffer, input: Sales): SalesBuffer = {
    val sales: Double = input.sales
    val year = input.stime.take(4)
    year match {
      case "2019" => buffer.sales2019 += sales
      case "2020" => buffer.sales2020 += sales
      case _ => println("ERROR")
    }
    buffer
  }

  // Merge buffers across partitions
  override def merge(b1: SalesBuffer, b2: SalesBuffer): SalesBuffer = {
    SalesBuffer(b1.sales2019 + b2.sales2019, b1.sales2020 + b2.sales2020)
  }

  // Compute the final result
  override def finish(reduction: SalesBuffer): Double = {
    if (math.abs(reduction.sales2019) < 0.0000000001) 0.0
    else (reduction.sales2020 - reduction.sales2019) / reduction.sales2019
  }

  // Encoders for the buffer and output types
  override def bufferEncoder: Encoder[SalesBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object TypeSafeUDAFTest {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession.builder()
      .appName(s"${this.getClass.getCanonicalName}")
      .master("local[*]")
      .getOrCreate()

    val sales = Seq(
      Sales(1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-02"),
      Sales(2, "Acme Widgets", 2000.00, 500.00, "CA", "2019-02-01"),
      Sales(3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"),
      Sales(4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"),
      Sales(5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"))

    import spark.implicits._
    val ds = spark.createDataset(sales)
    ds.show

    val rate: TypedColumn[Sales, Double] = new TypeSafeUDAF().toColumn.name("rate")
    ds.select(rate).show

    spark.stop()
  }
}

Accessing Hive

Add the dependency to the pom file:

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.12</artifactId>
    <version>${spark.version}</version>
</dependency>

Add a hive-site.xml file under resources with the following content:

<configuration>
    <property>
        <name>hive.metastore.uris</name>
        <value>thrift://linux123:9083</value>
    </property>
</configuration>

Note: it is best to connect to Hive through the metastore service. When connecting directly to the metastore database, the Spark SQL program may modify Hive's version information.

By default, Spark is compiled against Hive 1.2.1 and bundles the corresponding serdes, UDFs, UDAFs, and so on.
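
If the cluster runs a different Hive version, Spark can be pointed at a matching metastore client through the spark.sql.hive.metastore.version and spark.sql.hive.metastore.jars settings. A minimal sketch, assuming the cluster's metastore is version 2.3.3 (the exact version string and the app name are only placeholders and must match your environment):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("HiveMetastoreVersionDemo")
  .master("local[*]")
  .enableHiveSupport()
  // which Hive metastore version Spark should talk to (assumed 2.3.3 here)
  .config("spark.sql.hive.metastore.version", "2.3.3")
  // download the matching metastore client jars from Maven
  .config("spark.sql.hive.metastore.jars", "maven")
  .getOrCreate()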


package cn.lagou.sparksql

import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object AccessHive {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Demo1")
      .master("local[*]")
      .enableHiveSupport()
      // Write parquet data using the same conventions as Hive
      .config("spark.sql.parquet.writeLegacyFormat", "true")
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("warn")

    // spark.sql("show databases").show
    // spark.sql("select * from ods.ods_trade_product_info").show
    // val df: DataFrame = spark.table("ods.ods_trade_product_info")
    // df.show()
    // df.write.mode(SaveMode.Append).saveAsTable("ods.ods_trade_product_info_backup")
    spark.table("ods.ods_trade_product_info_backup").show

    spark.close()
  }
}