首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用聚合器的Spark3.0中的通用联合发展新议程

使用聚合器的Spark3.0中的通用联合发展新议程
EN

Stack Overflow用户
提问于 2020-08-10 12:53:46
回答 2查看 554关注 0票数 1

Spark3.0已经不推荐UserDefinedAggregateFunction了,我正试图用Aggregator重写我的新议程。Aggregator的基本用法很简单,但是,我很难使用更通用的函数版本。

我将尝试用这个例子来解释我的问题,一个collect_set的实现。这不是我的实际案例,但更容易解释这个问题:

代码语言:javascript
复制
class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero = Set.empty
  override def reduce(b: Set[Int], a: Row) = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]) = b1 ++ b2
  override def finish(reduction: Set[Int]) = reduction
  override def bufferEncoder = Encoders.kryo[Set[Int]]
  override def outputEncoder = ExpressionEncoder()
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()

我更喜欢.toColumn.udf.register,但这不是重点。

问题:我不能制作这个聚合器的通用版本,它只能与整数一起工作。

我试过:

代码语言:javascript
复制
class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]] 

它因错误而崩溃:

代码语言:javascript
复制
No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
    at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)

我不能和CollectSetDemo[T]一起去,如果我不能正确的outputEncoder。另外,当我使用联非新议程时,我只能使用星火数据类型、列等。

EN

回答 2

Stack Overflow用户

发布于 2020-09-16 14:04:46

还没有找到一个很好的方法来解决这个问题,但我能够在某种程度上解决它。代码部分借用自RowEncoder

代码语言:javascript
复制
class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero = Set.empty
  override def reduce(b: Set[Any], a: Row) = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]) = b1 ++ b2
  override def finish(reduction: Set[Any]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[Any]]

  // now
  override def outputEncoder = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)

    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}

我必须添加的一件事是聚合器中的结果数据类型参数。然后将用法更改为:

代码语言:javascript
复制
df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()

我真的不喜欢结果如何,但效果很好。我也欢迎任何关于如何改进它的建议。

票数 1
EN

Stack Overflow用户

发布于 2020-11-03 20:56:50

用泛型修改@Ramunas答案:

代码语言:javascript
复制
class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero = Set.empty
  override def reduce(b: Set[T], a: Row) = b + a.getAs[T](a.fieldIndex(name))
  override def merge(b1: Set[T], b2: Set[T]) = b1 ++ b2
  override def finish(reduction: Set[T]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[T]]
  
  override def outputEncoder = {
    val tt = typeTag[Seq[T]]
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = serializerForType(tpe)
    val deserializer = deserializerForType(tpe)

    new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
  }
}
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63340626

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档