我在管道上工作,并在将列值传递给CountVectorizer之前尝试拆分它。
为此,我制作了一个自定义转换器。
class FlatMapTransformer(override val uid: String)
extends Transformer {
/**
* Param for input column name.
* @group param
*/
final val inputCol = new Param[String](this, "inputCol", "The input column")
final def getInputCol: String = $(inputCol)
/**
* Param for output column name.
* @group param
*/
final val outputCol = new Param[String](this, "outputCol", "The output column")
final def getOutputCol: String = $(outputCol)
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
def this() = this(Identifiable.randomUID("FlatMapTransformer"))
private val flatMap: String => Seq[String] = { input: String =>
input.split(",")
}
override def copy(extra: ParamMap): SplitString = defaultCopy(extra)
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
override def transformSchema(schema: StructType): StructType = {
val dataType = schema($(inputCol)).dataType
require(
dataType.isInstanceOf[StringType],
s"Input column must be of type StringType but got ${dataType}")
val inputFields = schema.fields
require(
!inputFields.exists(_.name == $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
DataTypes.createStructType(
Array(
DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
}
}代码似乎是合法的,但是当我试图用其他操作链接它时,就会出现问题。这是我的管道:
val train = reader.readTrainingData()
val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
filter { feature: String => !cat_ohe_features.contains(feature) }
val catIndexer = cat_features_string_index.map {
feature =>
new StringIndexer()
.setInputCol(feature)
.setOutputCol(feature + "_index")
.setHandleInvalid("keep")
}
val flatMapper = cat_ohe_features.map {
feature =>
new FlatMapTransformer()
.setInputCol(feature)
.setOutputCol(feature + "_transformed")
}
val countVectorizer = cat_ohe_features.map {
feature =>
new CountVectorizer()
.setInputCol(feature + "_transformed")
.setOutputCol(feature + "_vectorized")
.setVocabSize(10)
}
// val countVectorizer = cat_ohe_features.map {
// feature =>
//
// val flatMapper = new FlatMapTransformer()
// .setInputCol(feature)
// .setOutputCol(feature + "_transformed")
//
// new CountVectorizer()
// .setInputCol(flatMapper.getOutputCol)
// .setOutputCol(feature + "_vectorized")
// .setVocabSize(10)
// }
val cat_features_index = cat_features_string_index.map {
(feature: String) => feature + "_index"
}
val count_vectorized_index = cat_ohe_features.map {
(feature: String) => feature + "_vectorized"
}
val catFeatureAssembler = new VectorAssembler()
.setInputCols(cat_features_index)
.setOutputCol("cat_features")
val oheFeatureAssembler = new VectorAssembler()
.setInputCols(count_vectorized_index)
.setOutputCol("cat_ohe_features")
val numFeatureAssembler = new VectorAssembler()
.setInputCols(num_features)
.setOutputCol("num_features")
val featureAssembler = new VectorAssembler()
.setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
.setOutputCol("features")
val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
Array(
catFeatureAssembler,
oheFeatureAssembler,
numFeatureAssembler,
featureAssembler)
val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)运行此代码,我将收到一个错误:java.lang.IllegalArgumentException: Field "my_ohe_field_trasformed" does not exist.
[info] java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
[info] at scala.collection.AbstractMap.getOrElse(Map.scala:59)
[info] at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
[info] at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)
[info] at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)
[info] at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)
[info] at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)当我取消对stringSplitter和countVectorizer的注释时,会在我的转换器中引发错误
java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist. at val dataType = schema($(inputCol)).dataType
调用pipeline.getStages的结果
strIdx_3c2630a738f0
strIdx_0d76d55d4200
FlatMapTransformer_fd8595c2969c
FlatMapTransformer_2e9a7af0b0fa
cntVec_c2ef31f00181
cntVec_68a78eca06c9
vecAssembler_a81dd9f43d56
vecAssembler_b647d348f0a0
vecAssembler_b5065a22d5c8
vecAssembler_d9176b8bb593我可能走错路了。如有任何意见,敬请见谅。
发布于 2020-05-26 09:54:10
您的FlatMapTransformer #transform是不正确的,当您只在outputCol上选择时,您会删除/忽略所有其他列。
请将您的方法修改为-
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}另外,修改您的transformSchema,在检查其数据类型之前先检查输入列-
override def transformSchema(schema: StructType): StructType = {
require(schema.names.contains($(inputCol)), "inputCOl is not there in the input dataframe")
//... rest as it is
}根据评论更新-1
copy方法(尽管它不是您面临异常的原因)-override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)CountVectorizer接受具有ArrayType(StringType, true/false)类型列的列,并且由于FlatMapTransformer输出列成为CountVectorizer的输入,因此需要确保FlatMapTransformer的输出列必须为ArrayType(StringType, true/false)。我想,不是这样的,你今天的代码如下- override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}explode函数将array<string>转换为string,因此转换器的输出变成StringType。你可能想把这个代码修改成-
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
}transformSchema方法以输出ArrayType(StringType) override def transformSchema(schema: StructType): StructType = {
val dataType = schema($(inputCol)).dataType
require(
dataType.isInstanceOf[StringType],
s"Input column must be of type StringType but got ${dataType}")
val inputFields = schema.fields
require(
!inputFields.exists(_.name == $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
schema.add($(outputCol), ArrayType(StringType))
}val featureAssembler = new VectorAssembler()
.setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
.setOutputCol("features")我试着用虚拟数据文件执行你的管道,效果很好。请参考这个要旨获得完整代码。
https://stackoverflow.com/questions/62018875
复制相似问题