我希望执行一些条件分支,以避免计算不必要的节点,但我注意到,如果条件语句中的源列是UDF,那么否则将解析,不管如何:
@pandas_udf("double", PandasUDFType.SCALAR)
def udf_that_throws_exception(*cols):
raise Exception('Error')
@pandas_udf("int", PandasUDFType.SCALAR)
def simple_mul_udf(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
return result
df = spark.range(0,5)
df = df.withColumn('A', lit(1))
df = df.withColumn('B', lit(2))
df = df.withColumn('udf', simple_mul('A','B'))
df = df.withColumn('sql', expr('A*B'))
df = df.withColumn('res', when(df.sql < 100, lit(1)).otherwise(udf_that_throws(lit(0))))上面的代码如预期的那样工作,本例中的语句始终为true,因此我抛出异常的UDF永远不会被调用。
但是,如果我将条件改为使用df.udf,则会突然调用否则的UDF,即使条件结果没有改变,也会得到异常。
我想我可以通过从条件中删除UDF来混淆它,但是无论发生什么,都会发生相同的结果:
df = df.withColumn('cond', when(df.udf < 100, lit(1)).otherwise(lit(0)))
df = df.withColumn('res', when(df.cond == lit(1), lit(1)).otherwise(udf_that_throws_exception(lit(0))))我认为这与星火优化的方式有关,这很好,但我正在寻找任何方法来做到这一点,而不承担成本。有什么想法吗?
编辑每一个请求获得更多信息。我们正在编写一个可以接受任意模型并由代码生成图形的处理引擎。在此过程中,我们在运行时根据值的状态进行决策。我们大量使用熊猫。因此,假设图中有多条路径,根据运行时的某些条件,我们希望遵循其中的一条路径,而所有其他路径都不受影响。
我想将这个逻辑编码到图中,这样就没有必要在代码中收集和分支了。
我提供的示例代码仅用于演示。我面临的问题是,如果if语句中使用的列是UDF,或者,如果它是从UDF派生出来的,那么即使它从未实际使用,否则的条件也总是被执行。如果If / such是廉价的操作,比如文字,我不会介意,但是如果列UDF (可能是两边的)导致了一个大的聚合或其他一些实际上被丢弃的长度进程,该怎么办?
发布于 2019-10-19 08:57:05
在PySpark中,UDF是预先计算出来的,因此您得到了这个次优的bahaviour。您还可以从查询计划中看到它:
== Physical Plan ==
*(2) Project [id#753L, 1 AS A#755, 2 AS B#758, pythonUDF1#776 AS udf#763, CASE WHEN (pythonUDF1#776 < 100) THEN 1.0 ELSE pythonUDF2#777 END AS res#769]
+- ArrowEvalPython [simple_mul_udf(1, 2), simple_mul_udf(1, 2), udf_that_throws_exception(0)], [id#753L, pythonUDF0#775, pythonUDF1#776, pythonUDF2#777]
+- *(1) Range (0, 5, step=1, splits=8)ArrowEvalPython操作符负责计算UDF,然后在Project操作符中计算条件。
在您的条件下调用df.sql (最佳行为)时,您会得到不同的行为,原因是这是一种特殊情况,在这种情况下,此列中的值是常数( A和B列都是常数),而火花优化器可以事先对其进行评估(在查询计划处理期间,在执行集群上的实际作业之前,在驱动程序中),因此它知道永远不必计算条件的otherwise分支。如果此sql列中的值是动态的(例如,在id列中),则行为将再次处于次优状态,因为火花事先不知道不应该发生otherwise部件。
如果您想避免这种次优行为(即使不需要在otherwise中调用udf ),一个可能的解决方案是在您的udf中评估这个条件,例如:
@pandas_udf("int", PandasUDFType.SCALAR)
def udf_with_cond(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
if((result < 100).any()):
return result
else:
raise Exception('Error')
df = df.withColumn('res', udf_with_cond('A', 'B'))https://stackoverflow.com/questions/58447674
复制相似问题