I used CountVectorizerModel to create features from the text in order to train an LDA model:
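The question does not show how `cvModel` and `model` were built. Below is a minimal sketch of one way to get a comparable `features` column and LDA model. It assumes a `Tokenizer`, which lowercases and splits on whitespace and so matches the `words` column in the table below. It also assumes `k = 10`, to match the ten topics returned by `describeTopics`. The object name `LdaSetupSketch` and all parameter values are assumptions, not the asker's code.

```scala
import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, Tokenizer}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical setup that produces label/sentence/words/features columns like the ones shown.
object LdaSetupSketch {
  def build(spark: SparkSession): (CountVectorizerModel, LDAModel, DataFrame) = {
    import spark.implicits._

    val sentences = Seq(
      (0.0, "Hi I heard about Spark"),
      (0.0, "I wish Java could use case classes"),
      (1.0, "Logistic regression models are neat"),
      (1.0, "They are cats"),
      (0.0, "cat is only one cat in a group of cats"),
      (1.0, "cat is meowingful all day long."),
      (1.0, "cat"),
      (1.0, "spark"),
      (1.0, "spark cat")
    ).toDF("label", "sentence")

    // Tokenizer lowercases and splits on whitespace, so "long." keeps its period.
    val words = new Tokenizer().setInputCol("sentence").setOutputCol("words").transform(sentences)

    // CountVectorizer learns the vocabulary; cvModel.vocabulary(i) is the term for index i.
    val cvModel = new CountVectorizer().setInputCol("words").setOutputCol("features").fit(words)
    val featurized = cvModel.transform(words)

    // LDA reads the "features" column by default; k = 10 is an assumption.
    val model = new LDA().setK(10).setMaxIter(10).fit(featurized)
    (cvModel, model, featurized)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lda-setup").getOrCreate()
    val (cvModel, model, featurized) = build(spark)
    featurized.show(false)
    println(cvModel.vocabulary.mkString(", "))
    model.describeTopics(3).show(false)
    spark.stop()
  }
}
```

LDA is randomly initialised, so topic weights (and term order among ties in the vocabulary) can differ from run to run.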
+-----+--------------------------------------+-------------------------------------------------+-------------------------------------------------------------------+
|label|sentence |words |features |
+-----+--------------------------------------+-------------------------------------------------+-------------------------------------------------------------------+
|0.0 |Hi I heard about Spark |[hi, i, heard, about, spark] |(30,[1,5,6,7,16],[1.0,1.0,1.0,1.0,1.0]) |
|0.0 |I wish Java could use case classes |[i, wish, java, could, use, case, classes] |(30,[5,9,11,13,24,26,29],[1.0,1.0,1.0,1.0,1.0,1.0,1.0]) |
|1.0 |Logistic regression models are neat |[logistic, regression, models, are, neat] |(30,[4,14,18,21,22],[1.0,1.0,1.0,1.0,1.0]) |
|1.0 |They are cats |[they, are, cats] |(30,[3,4,17],[1.0,1.0,1.0]) |
|0.0 |cat is only one cat in a group of cats|[cat, is, only, one, cat, in, a, group, of, cats]|(30,[0,2,3,8,10,20,23,27,28],[2.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|
|1.0 |cat is meowingful all day long. |[cat, is, meowingful, all, day, long.] |(30,[0,2,12,15,19,25],[1.0,1.0,1.0,1.0,1.0,1.0]) |
|1.0 |cat |[cat] |(30,[0],[1.0]) |
|1.0 |spark |[spark] |(30,[1],[1.0]) |
|1.0 |spark cat |[spark, cat] |(30,[0,1],[1.0,1.0]) |
+-----+--------------------------------------+-------------------------------------------------+-------------------------------------------------------------------+
Here are the topics:
val topics = model.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)
+-----+------------+-----------------------------------------------------------------+
|topic|termIndices |termWeights |
+-----+------------+-----------------------------------------------------------------+
|0 |[2, 5, 7] |[0.03954771670945735, 0.03941180947330347, 0.03888945410782809] |
|1 |[3, 23, 20] |[0.038638315281474093, 0.037879704408459995, 0.03774139169021561]|
|2 |[9, 28, 21] |[0.04232988497943897, 0.04007287769364308, 0.039937267948921336] |
|3 |[18, 5, 15] |[0.03705824484750299, 0.036890803795663674, 0.036716976690456406]|
|4 |[15, 2, 19] |[0.051298533195568756, 0.049034272085125466, 0.04766027890074748]|
|5 |[8, 15, 28] |[0.039784800740184825, 0.03919450578763458, 0.03747537818514296] |
|6 |[26, 7, 10] |[0.03914211167490289, 0.038519959566040284, 0.03777486155909476] |
|7 |[3, 2, 25] |[0.03824521540169412, 0.03809586773398763, 0.03744203244313033] |
|8 |[8, 28, 1] |[0.04141091418342947, 0.040997706216988956, 0.03925572055141317] |
|9 |[16, 24, 23]|[0.04106798576100414, 0.03947867647938766, 0.036999875515655097] |
+-----+------------+-----------------------------------------------------------------+
and their schema:
root
|-- topic: integer (nullable = false)
|-- termIndices: array (nullable = true)
| |-- element: integer (containsNull = false)
|-- termWeights: array (nullable = true)
| |-- element: double (containsNull = false)
I want to create another column (named "term") that shows the term strings instead of the indices.
So I created a function:
val lookup2 = ((a:Array[Int]) => {
a.map(x => cvModel.vocabulary(x))
})
When I test it on a single case, lookup2 works fine:
lookup2(Array(1,2,3))
res194: Array[String] = Array(spark, is, cats)
I tried converting the function into a UDF and applying it to the whole column:
val lookupudf = udf(lookup2)
topics.withColumn("term", lookupudf($"termIndices")).show()
It does not work:
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (array<int>) => array<string>)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1058)
at org.apache.spark.sql.catalyst.expressions.UnaryExpression.eval(Expression.scala:359)
at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:139)
at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:48)
at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:30)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.AbstractTraversable.map(Traversable.scala:104)
at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$$anonfun$apply$23.applyOrElse(Optimizer.scala:1191)
at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$$anonfun$apply$23.applyOrElse(Optimizer.scala:1186)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$.apply(Optimizer.scala:1186)
at org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation$.apply(Optimizer.scala:1185)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:87)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:84)
at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
at scala.collection.mutable.WrappedArray.foldLeft(WrappedArray.scala:35)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:84)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:76)
at scala.collection.immutable.List.foreach(List.scala:381)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:76)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:66)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:66)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan$lzycompute(QueryExecution.scala:72)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan(QueryExecution.scala:68)
at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:77)
at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:77)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3248)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2484)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2698)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:254)
at org.apache.spark.sql.Dataset.show(Dataset.scala:723)
at org.apache.spark.sql.Dataset.show(Dataset.scala:682)
at org.apache.spark.sql.Dataset.show(Dataset.scala:691)
... 52 elided
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [I
at $anonfun$1.apply(<console>:58)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:102)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:101)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1055)
... 105 more
What do I need to fix?
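The `Caused by` line is the key clue. `[I` is the JVM name for `Array[Int]`, and Spark hands an `array<int>` column to a Scala UDF as a `Seq` (here a `WrappedArray$ofRef`), not as an `Array`. The plain-Scala sketch below involves no Spark. It reproduces the same kind of failure by forcing a `Seq` into a function whose parameter is `Array[Int]`, the same shape as `lookup2`. `CastDemo`, `lookupArray` and `callWith` are illustrative names, and the four-word vocabulary is made up.

```scala
object CastDemo {
  // Same shape as lookup2: the parameter type is Array[Int], i.e. JVM type [I.
  val vocabulary: Array[String] = Array("cat", "spark", "is", "cats")
  val lookupArray: Array[Int] => Array[String] = a => a.map(i => vocabulary(i))

  // Simulates a generic caller (as Spark's UDF machinery is): it holds the
  // argument as Any and casts it to the declared parameter type.
  def callWith(arg: Any): Either[String, Seq[String]] =
    try {
      val ints = arg.asInstanceOf[Array[Int]]
      Right(lookupArray(ints).toSeq)
    } catch {
      case e: ClassCastException => Left(e.getClass.getSimpleName)
    }

  def main(args: Array[String]): Unit = {
    println(callWith(Array(1, 2, 3))) // a real Array[Int]: the lookup succeeds
    println(callWith(Seq(1, 2, 3)))   // a Seq, like Spark passes: ClassCastException
  }
}
```

A direct call such as `lookup2(Array(1,2,3))` passes a real `Array[Int]`, which is why the function works on its own but fails inside the UDF.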
Answered 2018-12-17 15:45:07
An array column in a DataFrame reaches a Scala UDF as a WrappedArray, not as an Array[Int]. So I should define my UDF like this:
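import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray

// Spark passes an array<int> column to the UDF as a WrappedArray, not an Array[Int]
val lookup3 = ((a: WrappedArray[Int]) => {
  a.toArray.map(x => cvModel.vocabulary(x))
})
val lookupudf3 = udf(lookup3)
Then use the UDF to create the new column: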
topics.withColumn("term", lookupudf3($"termIndices")).show()
It should work:
+-----+------------+--------------------+--------------------+
|topic| termIndices| termWeights| term|
+-----+------------+--------------------+--------------------+
| 0| [2, 5, 7]|[0.03954762152543...| [cats, are, hi]|
| 1| [3, 23, 20]|[0.03863839536342...| [is, long., use]|
| 2| [9, 28, 21]|[0.04232988718372...|[could, they, cla...|
| 3| [18, 5, 15]|[0.03705824666867...| [of, are, one]|
| 4| [18, 3, 15]|[0.04114420013742...| [of, is, one]|
| 5| [8, 15, 28]|[0.03978480361117...| [a, one, they]|
| 6| [26, 7, 10]|[0.03914211373502...| [logistic, hi, in]|
| 7| [3, 25, 23]|[0.05067447986285...| [is, day, long.]|
| 8| [8, 28, 1]|[0.04141091392312...| [a, they, spark]|
| 9|[16, 24, 23]|[0.04106809235206...|[meowingful, java...|
+-----+------------+--------------------+--------------------+
https://stackoverflow.com/questions/53809332
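Here is an alternative sketch, assuming Spark 2.x or later. Declaring the UDF parameter as `Seq[Int]` avoids naming `WrappedArray` at all, because `WrappedArray` is a `Seq`. Copying `cvModel.vocabulary` into a plain `Array[String]` means the UDF closure captures only the vocabulary rather than the whole model. `TermLookupSketch`, `termsUdf`, `withTerms`, and the small hand-built vocabulary and DataFrame are hypothetical stand-ins for `cvModel.vocabulary` and `topics`.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}

object TermLookupSketch {
  // Builds a UDF mapping term indices to vocabulary strings.
  // Only the Array[String] is captured by the closure, not the CountVectorizerModel.
  def termsUdf(vocab: Array[String]): UserDefinedFunction =
    udf((indices: Seq[Int]) => indices.map(i => vocab(i)))

  // Adds a "term" column next to "termIndices".
  def withTerms(topics: DataFrame, vocab: Array[String]): DataFrame =
    topics.withColumn("term", termsUdf(vocab)(col("termIndices")))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("term-lookup").getOrCreate()
    import spark.implicits._

    // Hypothetical stand-ins for cvModel.vocabulary and model.describeTopics(3).
    val vocab = Array("cat", "spark", "is", "cats", "are")
    val topics = Seq((0, Seq(2, 4, 1)), (1, Seq(3, 0, 2))).toDF("topic", "termIndices")

    withTerms(topics, vocab).show(false)
    spark.stop()
  }
}
```

In the question's setting, the call would be something like `TermLookupSketch.withTerms(topics, cvModel.vocabulary)`.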