首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >星火自定义估计器,包括持久性

星火自定义估计器,包括持久性
EN

Stack Overflow用户
提问于 2016-11-26 10:38:11
回答 2查看 2.6K关注 0票数 6

我想要开发一个自定义的火花估计器,它也处理伟大的管道API的持久性。但正如如何在PySpark mllib中滚动自定义估计器所言,目前还没有太多的文档。

我有一些数据清理代码编写的火花,并希望包装在一个自定义估计器。包括一些na替换、列删除、过滤和基本特征生成(例如生日到年龄)。

  • transformSchema将使用数据集ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]的case类。
  • 适合只适合,例如,平均年龄为na。代用品

我还不清楚的是:

  • transform在定制管道模型中将用于转换新数据上的“拟合”估计值。这是正确的吗?如果是,我应该如何将拟合值(如平均年龄)从上面转换到模型中?
  • 如何处理持久力?我在私有spark组件中找到了一些通用的loadImpl方法,但不确定如何将自己的参数(例如,平均年龄)传递到用于序列化的MLReader / MLWriter中。

如果你能帮我做一个自定义评估器--尤其是持久化部分,那就太好了。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2016-11-26 12:00:03

首先,我认为你混合了两种不同的东西:

  • Estimators --它代表的阶段可以是fit-ted。Estimator fit方法采用Dataset并返回Transformer (模型)。
  • Transformers --它表示可以transform数据的阶段。

当您fit Pipeline it fits all Estimators并返回PipelineModel时。PipelineModel可以在模型中的所有Transformers上依次调用transform

我应该如何转换合适的值

这个问题没有一个单一的答案。一般来说,您有两种选择:

  • 将拟合模型的参数作为Transformer的参数传递。
  • 制作适合的Params Transformer模型参数。

第一种方法通常由内置的Transformer使用,但是第二种方法在一些简单的情况下应该有效。

如何处理持久化

  • 如果Transformer仅由其Params定义,则可以扩展DefaultParamsReadable
  • 如果使用更复杂的参数,则应该扩展MLWritable并实现对数据有意义的MLWriter。Spark源代码中有多个示例,说明了如何实现数据和元数据的读写。

如果您正在寻找一个易于理解的示例,请查看CountVectorizer(Model),其中:

票数 4
EN

Stack Overflow用户

发布于 2017-05-31 03:22:46

下面使用了Scala ,但是如果您真的想要的话,可以轻松地将它重构为.

首先要做的是:

  • 估计器:实现返回转换器的.fit()
  • 转换器:实现.transform()并操作DataFrame
  • Serialization/Deserialization:尽力使用内置参数,并利用简单的DefaultParamsWritable特性+伙伴对象扩展DefaultParamsReadable[T]。A远离MLReader / MLWriter,保持代码简单。
  • 传递参数:使用扩展Params的共同特征,并在您的估计器和模型之间共享它(a.k.a )。变压器)

骨架代码:

代码语言:javascript
复制
// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
    def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    def getInputCols: Array[String] = $(inputCols)

  final val meanValues: DoubleArrayParam = 
    new DoubleArrayParam(this, "meanValues", "doc...")
    // more setters and getters
}

// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): Estimator[MeanValueFillerModel] = defaultCopy(extra) // deafult
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
   this.setMeanValues(meanValues)
   copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]

// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
      // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
      // you have access to both inputCols and meanValues here!
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

使用上面的代码,您可以序列化/反序列化包含MyMeanValueStuff级的管道。

想看看估值器的一些真正的简单实现吗?MinMaxScaler!(我的例子其实更简单.)

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40817433

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档