Can anyone tell me the equivalent of collect_set in Spark 1.5?
Is there any workaround to get a result similar to collect_set(col(name))?
Is the following approach correct?
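```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}

// UDAF that collects the distinct non-null values of one column into an array.
class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    new StructType().add("inputCol", colType)
  // The intermediate buffer holds the values seen so far as an array.
  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))
  def dataType: DataType = ArrayType(colType)
  def deterministic: Boolean = true
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Seq.empty[T])
  }
  // Append a non-null input value unless it is already in the buffer.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val list = buffer.getSeq[T](0)
      val value = input.getAs[T](0)
      if (!list.contains(value)) buffer.update(0, list :+ value)
    }
  }
  // Union two partial buffers; store a Seq, since the buffer column is an ArrayType.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, (buffer1.getSeq[T](0) ++ buffer2.getSeq[T](0)).distinct)
  }
  def evaluate(buffer: Row): Any = {
    buffer.getSeq[T](0)
  }
}
```

Posted on 2016-10-12 06:22:34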
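The code looks correct. I also tested it in local mode on 1.6.2 and got the same results as the built-in collect_set (see below). I don't know of a simpler alternative using the DataFrame API. With RDDs it is quite straightforward, and in 1.5 it may be better to take a detour through RDDs, since DataFrames were not yet fully featured there.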
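Why can a UDAF of this shape return distinct values even when `update` simply appends? Assuming Spark's usual partial-then-final aggregation plan, `update` runs within each partition, and the final stage `merge`s every partial buffer into a freshly initialized one, so a deduplicating `merge` sees all values. The sketch below models that lifecycle in plain Scala (no Spark) next to the group-then-deduplicate semantics of the RDD route. `UdafLifecycleSketch` and all its methods are hypothetical stand-ins, not Spark's API, and `update` is deliberately modeled as a plain append.

```scala
// Plain-Scala model (no Spark) of a collect-set UDAF's lifecycle.
// All names here are hypothetical stand-ins, not Spark's API.
object UdafLifecycleSketch {
  type Buffer = Seq[Int]

  def initialize: Buffer = Seq.empty

  // Modeled as a plain append: duplicates may pile up within a partition.
  def update(buf: Buffer, v: Int): Buffer = buf :+ v

  // Set union of two buffers, stored back as a Seq (like an ArrayType column).
  def merge(b1: Buffer, b2: Buffer): Buffer = (b1.toSet ++ b2.toSet).toSeq

  def evaluate(buf: Buffer): Seq[Int] = buf

  // Each inner Seq is one partition of (key, value) rows.
  def aggregate(partitions: Seq[Seq[(Int, Int)]]): Map[Int, Seq[Int]] = {
    // Partial phase: update() builds one buffer per key in each partition.
    val partials: Seq[Map[Int, Buffer]] = partitions.map { rows =>
      rows.foldLeft(Map.empty[Int, Buffer]) { case (acc, (k, v)) =>
        acc.updated(k, update(acc.getOrElse(k, initialize), v))
      }
    }
    // Final phase: merge() each partial buffer into a fresh buffer per key.
    val merged = partials.foldLeft(Map.empty[Int, Buffer]) { (acc, part) =>
      part.foldLeft(acc) { case (a, (k, buf)) =>
        a.updated(k, merge(a.getOrElse(k, initialize), buf))
      }
    }
    merged.map { case (k, buf) => (k, evaluate(buf)) }
  }

  // The RDD route: group every value by key, then deduplicate each group.
  def groupByKeyToSet(rows: Seq[(Int, Int)]): Map[Int, Set[Int]] =
    rows.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2).toSet) }
}
```

For example, with the partitions `Seq((0, 5), (1, 1), (0, 5))` and `Seq((0, 10), (1, 6), (1, 1))`, `aggregate` yields the values 5 and 10 for key 0 and 1 and 6 for key 1 (in no guaranteed order), matching `groupByKeyToSet`. The duplicate 5 survives the partial phase and is only removed by the final `merge`, which is why the result hinges on that two-phase assumption. The RDD snippet that follows does the same grouping and deduplication with a single `groupByKey`.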
```scala
scala> val rdd = sc.parallelize(1 to 10).map(x => (x % 5, x))
scala> rdd.groupByKey.mapValues(_.toSet.toList).toDF("k", "set").show
+---+-------+
|  k|    set|
+---+-------+
|  0|[5, 10]|
|  1| [1, 6]|
|  2| [2, 7]|
|  3| [3, 8]|
|  4| [4, 9]|
+---+-------+
```

If you want to factor this out into a function, a first version (which you could embed in your own code) might look like this:
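```scala
import org.apache.spark.sql.{Column, DataFrame}
// Distinct values of v per key k via the RDD API (assumes both columns are Int).
def collectSet(df: DataFrame, k: Column, v: Column): DataFrame = {
  val sqlContext = df.sqlContext
  import sqlContext.implicits._ // provides .toDF on an RDD of tuples
  df.select(k.as("k"), v.as("v"))
    .rdd
    .map(r => (r.getInt(0), r.getInt(1)))
    .groupByKey()
    .mapValues(_.toSet.toList) // deduplicate each group
    .toDF("k", "v")
}
```

However, if you also need other aggregations, you cannot avoid a join.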
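The shape of that join can be sketched in plain Scala (no Spark): the distinct values per key come from one grouping pass, ordinary aggregates such as count and sum from another, and the two per-key results are joined on the key. `JoinAggregatesSketch` and its methods are hypothetical stand-ins for `collectSet` and a `groupBy(...).agg(...)`; in Spark the last step would correspond to a DataFrame `join` on the `k` column.

```scala
// Plain-Scala model (no Spark) of "collect the set, aggregate the rest, join".
// All names here are hypothetical stand-ins, not Spark's API.
object JoinAggregatesSketch {
  // Stand-in for the RDD-based collectSet: distinct values per key.
  def collectSetByKey(rows: Seq[(Int, Int)]): Map[Int, Set[Int]] =
    rows.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2).toSet) }

  // Stand-in for ordinary aggregates: (count, sum) of the values per key.
  def countAndSumByKey(rows: Seq[(Int, Int)]): Map[Int, (Int, Int)] =
    rows.groupBy(_._1).map { case (k, kvs) =>
      val vs = kvs.map(_._2)
      (k, (vs.size, vs.sum))
    }

  // Inner join of both per-key results on the key, sorted by key.
  def joinOnKey(rows: Seq[(Int, Int)]): Seq[(Int, Set[Int], Int, Int)] = {
    val sets = collectSetByKey(rows)
    val aggs = countAndSumByKey(rows)
    sets.keys.toSeq.sorted.flatMap { k =>
      aggs.get(k).map { case (n, total) => (k, sets(k), n, total) }
    }
  }
}
```

With the rows `(1 to 10).map(x => (x % 5, x))` plus an extra duplicate `(0, 5)`, key 0 comes out as `(0, Set(5, 10), 3, 20)`: the set drops the duplicate while count and sum still see it, so the two kinds of aggregation are computed separately and only combined by the join.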
Test 1:

```scala
scala> val df = sc.parallelize(1 to 10).toDF("v").withColumn("k", pmod('v, lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]
scala> val csudaf = new CollectSetFunction[Int](IntegerType)
scala> df.groupBy('k).agg(collect_set('v), csudaf('v)).show
+---+--------------+---------------------+
|  k|collect_set(v)|CollectSetFunction(v)|
+---+--------------+---------------------+
|  0|       [5, 10]|              [5, 10]|
|  1|        [1, 6]|               [1, 6]|
|  2|        [2, 7]|               [2, 7]|
|  3|        [3, 8]|               [3, 8]|
|  4|        [4, 9]|               [4, 9]|
+---+--------------+---------------------+
```

Test 2:
```scala
scala> val df = sc.parallelize(1 to 100000).toDF("v").withColumn("k", floor(rand() * 10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]
scala> df.groupBy('k).agg(collect_set('v).as("a"), csudaf('v).as("b")).groupBy('a === 'b).count.show
+-------+-----+
|(a = b)|count|
+-------+-----+
|   true|   10|
+-------+-----+
```

https://stackoverflow.com/questions/39990867