
IF statement in Pyspark

Stack Overflow user
Asked on 2017-12-01 05:36:54
2 answers · 30.9K views · 0 followers · Score 6

My data looks like this:

+----------+-------------+-------+--------------------+--------------+---+
|purch_date|  purch_class|tot_amt|       serv-provider|purch_location| id|
+----------+-------------+-------+--------------------+--------------+---+
|03/11/2017|Uncategorized| -17.53|             HOVER  |              |  0|
|02/11/2017|    Groceries| -70.05|1774 MAC'S CONVEN...|     BRAMPTON |  1|
|31/10/2017|Gasoline/Fuel|    -20|              ESSO  |              |  2|
|31/10/2017|       Travel|     -9|TORONTO PARKING A...|      TORONTO |  3|
|30/10/2017|    Groceries|  -1.84|         LONGO'S # 2|              |  4|

I'm trying to create a binary column whose value is defined by the tot_amt column, and to add it to the data above. If tot_amt < (-50) I want it to return 0, and if tot_amt > (-50) I want it to return 1 in the new column.

What I have tried so far is:

from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf

def y(row):
    if row['tot_amt'] < (-50):
        val = 1
    else:
        val = 0
        return val

y_udf = udf(y, IntegerType())
df_7 = df_4.withColumn('Y',y_udf(df_4['tot_amt'], (df_4['purch_class'], 
(df_4['purch_date'], (df_4['serv-provider'], (df_4['purch_location'])))
display(df_7)

I get the error message:

SparkException: Job aborted due to stage failure: Task 0 in stage 67.0 failed 
1 times, most recent failure: Lost task 0.0 in stage 67.0 (TID 85, localhost, 
executor driver): org.apache.spark.api.python.PythonException: Traceback (most 
recent call last):
File "/databricks/spark/python/pyspark/worker.py", line 177, in main
process()
File "/databricks/spark/python/pyspark/worker.py", line 172, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/databricks/spark/python/pyspark/worker.py", line 104, in <lambda>
func = lambda _, it: map(mapper, it)
File "<string>", line 1, in <lambda>
File "/databricks/spark/python/pyspark/worker.py", line 71, in <lambda>
return lambda *a: f(*a)
TypeError: y() takes exactly 1 argument (2 given)

2 Answers

Stack Overflow user

Accepted answer

Answered on 2017-12-01 07:11:00

How to make it work (via struct):

from pyspark.sql.functions import struct

df_4.withColumn("y", y_udf(
    # Include columns you want
    struct(df_4['tot_amt'], df_4['purch_class'])
))

What would make more sense?

y_udf = udf(lambda y: 1 if y < -50 else 0, IntegerType())

df_4.withColumn("y", y_udf('tot_amt'))
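For a quick sanity check, the threshold rule the lambda wraps can be exercised without Spark at all (a plain-Python sketch using the sample values from the question):

```python
# The same rule the udf wraps: 1 when the amount is below -50, else 0
classify = lambda y: 1 if y < -50 else 0

sample = [-17.53, -70.05, -20.0, -9.0, -1.84]
print([classify(v) for v in sample])  # → [0, 1, 0, 0, 0]
```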

How it should be done:

from pyspark.sql.functions import when

df_4.withColumn("y", when(df_4['tot_amt'] < -50, 1).otherwise(0))
Score 7

Stack Overflow user

Answered on 2017-12-01 07:16:58

You shouldn't need a UDF for this - use the built-in function `when` instead. Here is an example with toy data similar to your tot_amt column:

spark.version
# u'2.2.0'

from pyspark.sql import Row
from pyspark.sql.functions import col, when

df = spark.createDataFrame([Row(-17.53),
                              Row(-70.05),
                              Row(-20.),
                              Row(-9.),
                              Row(-1.84)
                             ],
                              ["tot_amt"])

df.show()
# +-------+
# |tot_amt|
# +-------+
# | -17.53| 
# | -70.05|
# |  -20.0|
# |   -9.0|
# |  -1.84|
# +-------+

df.withColumn('Y', when(col('tot_amt') < -50., 1).otherwise(0)).show()
# +-------+---+ 
# |tot_amt|  Y|
# +-------+---+
# | -17.53|  0|
# | -70.05|  1|
# |  -20.0|  0|
# |   -9.0|  0| 
# |  -1.84|  0|
# +-------+---+
Score 8
Original page content provided by Stack Overflow. Translation supported by Tencent Cloud Xiaowei's IT-domain engine.
Original link: https://stackoverflow.com/questions/47583007