Basically, I want to apply the function countSimilarColumns to every row of a DataFrame and put the result in a new column.
My code is as follows:
def main(args: Array[String]) = {
  val customerID = "customer-1" //args(0)
  val rawData = readFromResource("json", "/spark-test-data-copy.json")
  val flattenData = rawData.select(flattenSchema(rawData.schema): _*)
  val referenceCustomerRow = flattenData.transform(getCustomer(customerID)).first
}
def getCustomer(customerID: String)(dataFrame: DataFrame) = {
  dataFrame.filter($"customer" === customerID)
}
def countSimilarColumns(first: Row, second: Row): Int = {
  if (!(first.getAs[String]("customer").equals(second.getAs[String]("customer"))))
    first.toSeq.zip(second.toSeq).count { case (x, y) => x == y }
  else
    -1
}
What I want to do is the following, but I don't know how to do it.
flattenData
  .withColumn(
    "similarity_score",
    flattenData.map(row => countSimilarColumns(row, referenceCustomerRow))
  )
  .show()
Sample data (flattened):
{"customer":"customer-1","att-a":"7","att-b":"3","att-c":"10","att-d":"10"}
{"customer":"customer-2","att-a":"9","att-b":"7","att-c":"12","att-d":"4"}
{"customer":"customer-3","att-a":"7","att-b":"3","att-c":"1","att-d":"10"}
{"customer":"customer-4","att-a":"9","att-b":"14","att-c":"10","att-d":"4"}
Expected output:
+----------+----------------+
|customer  |similarity_score|
+----------+----------------+
|customer-1|-1              |
|customer-2|0               |
|customer-3|3               |
|customer-4|1               |
+----------+----------------+
Is a UDF the only way? If so, I would like to keep the function countSimilarColumns unchanged so that it stays testable. How can this be done? I am new to Spark/Scala.
Posted on 2020-05-16 12:39:21
flattenData is a DataFrame, and applying map to it returns a Dataset.
You are passing the result of flattenData.map(row => countSimilarColumns(row, referenceCustomerRow)) to withColumn, but withColumn only accepts a value of type org.apache.spark.sql.Column.
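On the UDF part of the question: one way to obtain a Column while leaving countSimilarColumns untouched is to wrap it in a UDF that receives the whole row as a struct. The following is only a minimal sketch under stated assumptions: the object name SimilarityUdfSketch, the helper withSimilarityScore, the local SparkSession, and the inline sample DataFrame are all hypothetical stand-ins for the asker's readFromResource / flattenSchema pipeline.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, struct, udf}

object SimilarityUdfSketch {

  // Same logic as in the question, so it can still be tested on plain Rows.
  def countSimilarColumns(first: Row, second: Row): Int =
    if (!first.getAs[String]("customer").equals(second.getAs[String]("customer")))
      first.toSeq.zip(second.toSeq).count { case (x, y) => x == y }
    else
      -1

  // Hypothetical helper: adds similarity_score relative to the given customer.
  // Assumes the customer exists; `first` throws on an empty result.
  def withSimilarityScore(data: DataFrame, customerID: String): DataFrame = {
    val referenceCustomerRow = data.filter(col("customer") === customerID).first
    // The UDF receives the struct built below as a Row (with field names).
    val similarity = udf((row: Row) => countSimilarColumns(row, referenceCustomerRow))
    // Pack every column, in the DataFrame's own order, into one struct column.
    data.withColumn("similarity_score", similarity(struct(data.columns.map(col): _*)))
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session and inline data instead of the JSON resource.
    val spark = SparkSession.builder().master("local[*]").appName("similarity-udf").getOrCreate()
    import spark.implicits._

    val flattenData = Seq(
      ("customer-1", "7", "3", "10", "10"),
      ("customer-2", "9", "7", "12", "4"),
      ("customer-3", "7", "3", "1", "10"),
      ("customer-4", "9", "14", "10", "4")
    ).toDF("customer", "att-a", "att-b", "att-c", "att-d")

    withSimilarityScore(flattenData, "customer-1")
      .select("customer", "similarity_score")
      .show()

    spark.stop()
  }
}
```

Because the struct is built from the same columns, in the same order, as the reference row, the positional zip inside countSimilarColumns lines up attribute by attribute.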
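If you want to add the above result as a column without a UDF, you have to use collect and then pass the result to lit. Note that lit creates a constant (literal) column, so this does not produce one score per row, and collect brings all the data to the driver.
Please check the code below.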
flattenData
  .withColumn("similarity_score", lit(
    flattenData
      .map(row => countSimilarColumns(row, referenceCustomerRow))
      .collect
      .map(_.toInt)
    )
  )
Based on the sample data, I have added the logic below.
scala> df.show(false)
+-----+-----+-----+-----+----------+
|att-a|att-b|att-c|att-d|customer |
+-----+-----+-----+-----+----------+
|7 |3 |10 |10 |customer-1|
|9 |7 |12 |4 |customer-2|
|7 |3 |1 |10 |customer-3|
|9 |14 |10 |4 |customer-4|
+-----+-----+-----+-----+----------+
scala> val conditions = df.columns.filterNot(_ == "customer").map(c =>
     |   when(count(col(c)).over(Window.partitionBy(col(c)).orderBy(col(c).asc)) =!= 1, lit(1)).otherwise(0)
     | ).reduce(_ + _) // adds 1 for each attribute whose value also appears in at least one other row, else 0
conditions: org.apache.spark.sql.Column = (((CASE WHEN (NOT (count(att-a) OVER (PARTITION BY att-a ORDER BY att-a ASC NULLS FIRST unspecifiedframe$()) = 1)) THEN 1 ELSE 0 END + CASE WHEN (NOT (count(att-b) OVER (PARTITION BY att-b ORDER BY att-b ASC NULLS FIRST unspecifiedframe$()) = 1)) THEN 1 ELSE 0 END) + CASE WHEN (NOT (count(att-c) OVER (PARTITION BY att-c ORDER BY att-c ASC NULLS FIRST unspecifiedframe$()) = 1)) THEN 1 ELSE 0 END) + CASE WHEN (NOT (count(att-d) OVER (PARTITION BY att-d ORDER BY att-d ASC NULLS FIRST unspecifiedframe$()) = 1)) THEN 1 ELSE 0 END)
Final result:
scala> df.withColumn("similarity_score",conditions).show(false)
+-----+-----+-----+-----+----------+----------------+
|att-a|att-b|att-c|att-d|customer |similarity_score|
+-----+-----+-----+-----+----------+----------------+
|9 |7 |12 |4 |customer-2|2 |
|7 |3 |1 |10 |customer-3|3 |
|7 |3 |10 |10 |customer-1|4 |
|9 |14 |10 |4 |customer-4|3 |
+-----+-----+-----+-----+----------+----------------+
https://stackoverflow.com/questions/61830946
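The window-based conditions above count, for each row, how many attribute values also occur in some other row, which is why customer-1 scores 4 rather than the -1 the question expects. If the goal is to compare each row against one reference customer without a UDF, a possible sketch is to compare every attribute column with the reference row's value and add up the matches. The object name SimilarityColumnSketch and the helper withSimilarityScore are hypothetical, and the sketch assumes non-null string attributes plus at least one attribute column besides customer.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit, when}

object SimilarityColumnSketch {

  // Hypothetical helper: builds the score from built-in column expressions only.
  def withSimilarityScore(df: DataFrame, customerID: String): DataFrame = {
    // Fetch the reference customer once on the driver (assumes it exists).
    val reference = df.filter(col("customer") === customerID).first

    // One 0/1 expression per attribute column, summed into a single Column.
    val matches = df.columns
      .filterNot(_ == "customer")
      .map(c => when(col(c) === lit(reference.getAs[String](c)), 1).otherwise(0))
      .reduce(_ + _)

    // The reference customer itself gets -1, as in countSimilarColumns.
    df.withColumn(
      "similarity_score",
      when(col("customer") === customerID, -1).otherwise(matches)
    )
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session and inline data mirroring the sample rows.
    val spark = SparkSession.builder().master("local[*]").appName("similarity-columns").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("customer-1", "7", "3", "10", "10"),
      ("customer-2", "9", "7", "12", "4"),
      ("customer-3", "7", "3", "1", "10"),
      ("customer-4", "9", "14", "10", "4")
    ).toDF("customer", "att-a", "att-b", "att-c", "att-d")

    withSimilarityScore(df, "customer-1")
      .select("customer", "similarity_score")
      .show(false)

    spark.stop()
  }
}
```

On the sample data this gives -1, 0, 3 and 1 for customer-1 to customer-4, matching the expected output in the question. Unlike the `x == y` comparison in countSimilarColumns, a SQL equality involving null is never counted as a match.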