首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PySpark余弦相似变压器

PySpark余弦相似变压器
EN

Stack Overflow用户
提问于 2019-08-07 08:40:05
回答 1查看 155关注 0票数 0

我有一个包含两个列的DataFrame,每个列都包含向量。

代码语言:javascript
复制
+-------------+------------+
|     v1      |     v2     |
+-------------+------------+
| [1,1.2,0.4] | [2,0.4,5]  |
| [1,.2,0.6]  | [2,.2,5]   |
| .           | .          |
| .           | .          |
| .           | .          |
| [0,1.2,.6]  | [2,.2,0.4] |
+-------------+------------+

我想在这个DataFrame中添加另一列,它包含每个行中两个向量之间的宇宙相似性。

  • 这个有变压器吗?
  • 变压器是完成这项任务的正确方法吗?
  • 如果这是正确的方法,而且没有这样的变压器,你能给我一个如何写自己的指针吗?
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-08 21:13:44

不知道有任何变换可以直接计算这里的consine相似性。您可以为这样的功能编写自己的udf

代码语言:javascript
复制
from pyspark.ml.linalg import Vectors, DenseVector
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import *

v = [(DenseVector([1,1.2,0.4]), DenseVector([2,0.4,5])),
    (DenseVector([1,2,0.6]), DenseVector([2,0.2,5])),
    (DenseVector([0,1.2,0.6]), DenseVector([2,0.2,0.4]))]

dfv1 = spark.createDataFrame(v, ['v1', 'v2'])
dfv1 = dfv1.withColumn('v1v2', F.struct([F.col('v1'), F.col('v2')]))
dfv1.show(truncate=False)

下面是带有组合向量的DataFrame:

代码语言:javascript
复制
+-------------+-------------+------------------------------+
|v1           |v2           |v1v2                          |
+-------------+-------------+------------------------------+
|[1.0,1.2,0.4]|[2.0,0.4,5.0]|[[1.0,1.2,0.4], [2.0,0.4,5.0]]|
|[1.0,2.0,0.6]|[2.0,0.2,5.0]|[[1.0,2.0,0.6], [2.0,0.2,5.0]]|
|[0.0,1.2,0.6]|[2.0,0.2,0.4]|[[0.0,1.2,0.6], [2.0,0.2,0.4]]|
+-------------+-------------+------------------------------+

现在,我们可以定义余弦相似性的udf

代码语言:javascript
复制
dot_prod_udf = F.udf(lambda v: float(v[0].dot(v[1])/v[0].norm(None)/v[1].norm(None)), FloatType())
dfv1 = dfv1.withColumn('cosine_similarity', dot_prod_udf(dfv1['v1v2']))
dfv1.show(truncate=False)

最后一列显示余弦相似性:

代码语言:javascript
复制
+-------------+-------------+------------------------------+-----------------+
|v1           |v2           |v1v2                          |cosine_similarity|
+-------------+-------------+------------------------------+-----------------+
|[1.0,1.2,0.4]|[2.0,0.4,5.0]|[[1.0,1.2,0.4], [2.0,0.4,5.0]]|0.51451445       |
|[1.0,2.0,0.6]|[2.0,0.2,5.0]|[[1.0,2.0,0.6], [2.0,0.2,5.0]]|0.4328257        |
|[0.0,1.2,0.6]|[2.0,0.2,0.4]|[[0.0,1.2,0.6], [2.0,0.2,0.4]]|0.17457432       |
+-------------+-------------+------------------------------+-----------------+
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57390273

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档