首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >collect_set等效火花1.5UDAF方法验证

collect_set等效火花1.5UDAF方法验证
EN

Stack Overflow用户
提问于 2016-10-12 05:05:33
回答 1查看 825关注 0票数 0

有人能告诉我火花1.5中collect_set的等效函数吗?

是否有任何工作来获得类似的结果,如collect_set(col(name))?

这种做法是否正确:

代码语言:javascript
复制
class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    new StructType().add("inputCol", colType)

  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))

  def dataType: DataType = ArrayType(colType)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val list = buffer.getSeq[T](0)
    if (!input.isNullAt(0)) {
      val sales = input.getAs[T](0)
      buffer.update(0, list:+sales)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[T](0).toSet ++ buffer2.getSeq[T](0).toSet)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getSeq[T](0)
  }
}
EN

回答 1

Stack Overflow用户

发布于 2016-10-12 06:22:34

它的代码看起来是正确的。此外,我在1.6.2中进行了本地模式的测试,并得到了相同的结果(见下文)。我不知道有什么更简单的选择使用DataFrame API。使用RDD是非常简单的,也许最好在1.5中绕道到RDD,因为数据帧还没有完全实现。

代码语言:javascript
复制
scala> val rdd = sc.parallelize((1 to 10)).map(x => (x%5,x))
scala> rdd.groupByKey.mapValues(_.toSet.toList)).toDF("k","set").show
+---+-------+
|  k|    set|
+---+-------+
|  0|[5, 10]|
|  1| [1, 6]|
|  2| [2, 7]|
|  3| [3, 8]|
|  4| [4, 9]|
+---+-------+

如果您想将其分解,初始版本(可以嵌入)可以如下所示

代码语言:javascript
复制
def collectSet(df: DataFrame, k: Column, v: Column) = df
    .select(k.as("k"),v.as("v"))
    .map( r => (r.getInt(0),r.getInt(1)))
    .groupByKey()
    .mapValues(_.toSet.toList)
    .toDF("k","v")

但是,如果要进行其他聚合,则无法避免联接。

代码语言:javascript
复制
scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]

scala> val csudaf = new CollectSetFunction[Int](IntegerType)

scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
+---+--------------+---------------------+
|  k|collect_set(v)|CollectSetFunction(v)|
+---+--------------+---------------------+
|  0|       [5, 10]|              [5, 10]|
|  1|        [1, 6]|               [1, 6]|
|  2|        [2, 7]|               [2, 7]|
|  3|        [3, 8]|               [3, 8]|
|  4|        [4, 9]|               [4, 9]|
+---+--------------+---------------------+

试验2:

代码语言:javascript
复制
scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]

scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b"))
         .groupBy('a==='b).count.show
+-------+-----+                                                                 
|(a = b)|count|
+-------+-----+
|   true|   10|
+-------+-----+
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39990867

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档