在Spark SQL中,如何注册和使用通用UDF?

时间:2016-04-28 12:44:13

标签: scala apache-spark udf

在我的项目中,我想实现ADD(+)功能,但我的参数可能是LongTypeDoubleTypeIntType。我使用sqlContext.udf.register("add",XXX),但我不知道如何编写XXX,这是为了制作泛型函数。

2 个答案:

答案 0 :(得分:5)

您可以通过创建UDF StructType来创建通用struct($"col1", $"col2")UDF保存您的值并让UDF解决此问题。它会作为Row对象传递到您的val multiAdd = udf[Double,Row](r => { var n = 0.0 r.toSeq.foreach(n1 => n = n + (n1 match { case l: Long => l.toDouble case i: Int => i.toDouble case d: Double => d case f: Float => f.toDouble })) n }) val df = Seq((1.0,2),(3.0,4)).toDF("c1","c2") df.withColumn("add", multiAdd(struct($"c1", $"c2"))).show +---+---+---+ | c1| c2|add| +---+---+---+ |1.0| 2|3.0| |3.0| 4|7.0| +---+---+---+ ,因此您可以执行以下操作:

UDF

你甚至可以做一些有趣的事情,例如将可变数量的列作为输入。实际上,我们上面定义的val df = Seq((1, 2L, 3.0f,4.0),(5, 6L, 7.0f,8.0)).toDF("int","long","float","double") df.printSchema root |-- int: integer (nullable = false) |-- long: long (nullable = false) |-- float: float (nullable = false) |-- double: double (nullable = false) df.withColumn("add", multiAdd(struct($"int", $"long", $"float", $"double"))).show +---+----+-----+------+----+ |int|long|float|double| add| +---+----+-----+------+----+ | 1| 2| 3.0| 4.0|10.0| | 5| 6| 7.0| 8.0|26.0| +---+----+-----+------+----+ 已经这样做了:

df.withColumn("add", multiAdd(struct(lit(100), $"int", $"long"))).show
+---+----+-----+------+-----+
|int|long|float|double|  add|
+---+----+-----+------+-----+
|  1|   2|  3.0|   4.0|103.0|
|  5|   6|  7.0|   8.0|111.0|
+---+----+-----+------+-----+

您甚至可以在混音中添加一个硬编码的数字:

UDF

如果要在SQL语法中使用sqlContext.udf.register("multiAdd", (r: Row) => { var n = 0.0 r.toSeq.foreach(n1 => n = n + (n1 match { case l: Long => l.toDouble case i: Int => i.toDouble case d: Double => d case f: Float => f.toDouble })) n }) df.registerTempTable("df") // Note that 'int' and 'long' are column names sqlContext.sql("SELECT *, multiAdd(struct(int, long)) as add from df").show +---+----+-----+------+----+ |int|long|float|double| add| +---+----+-----+------+----+ | 1| 2| 3.0| 4.0| 3.0| | 5| 6| 7.0| 8.0|11.0| +---+----+-----+------+----+ ,可以执行以下操作:

sqlContext.sql("SELECT *, multiAdd(struct(*)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
|  1|   2|  3.0|   4.0|10.0|
|  5|   6|  7.0|   8.0|26.0|
+---+----+-----+------+----+

这也有效:

SELECT

答案 1 :(得分:2)

我认为您无法注册通用UDF。

如果我们查看register方法的signature (实际上,它只是22个register重载中的一个,用于具有一个参数的UDF,其他的是等效的):

def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction

我们可以看到它使用A1: TypeTag类型进行了参数化 - TypeTag意味着在注册时,我们必须拥有UDF实际类型的证据 #39;的论点。所以 - 在没有明确输入的情况下传递泛型函数func无法编译。

对于您的情况,您可以利用Spark自动投射数字类型的能力 - 仅为Double编写UDF,您也可以将其应用于{{1} } s(输出结果为Int):

Double
相关问题