首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >利用过滤和UDF优化星图码

利用过滤和UDF优化星图码
EN

Stack Overflow用户
提问于 2022-03-16 20:04:54
回答 1查看 63关注 0票数 0

我正在使用Spark处理2000万个XML文档的数据集。我最初是在处理所有这些问题,但实际上我只需要其中的三分之一。在不同的星星之火工作流中,我创建了一个dataframe keyfilter,其中一个列是每个XML的键,第二个列是布尔值,如果应该处理对应于键的xml,则为True,否则为False

XML本身是使用Pandas处理的,我无法共享这个UDF。

我在DataBricks上的笔记本基本上是这样工作的:

代码语言:javascript
复制
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/os/s3.parquet')
keyfilter.cache()

def process_part(part, fraction=1, filter=True, return_df=False):
  try:
    df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
  # Sometimes, the file part-xxxxx doesn't exist
  except AnalysisException:
    return None
  if fraction < 1:
    df = df.sample(fraction=fraction, withReplacement=False)
  if filter:
    df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
    filtered_df = df_with_filter.filter(col('filter')).drop('filter')
    mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  else:
    mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
  if return_df:
    return mod_df


n_cores = 6
i=0
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i+n_cores)))
    i += 1

我发布这个问题的原因是,尽管Pandas应该是发生的最昂贵的操作,但是添加过滤实际上会使我的代码运行速度比我根本没有过滤时慢得多。我对Spark非常陌生,我想知道这里是否做了一些愚蠢的事情,导致与keyfilter的连接非常慢,如果是的话,是否有一种方法可以使它们快速(例如,是否有一种方法可以使keyfilter充当从键到布尔的哈希表,比如在SQL中创建索引)。我设想keyfilter的大大小在这里扮演着某种角色;它有2000万行,而process_part中的df只有这些行的一小部分(但是,df的大小要大得多,因为它包含XML )。我是不是应该把所有的部分组合成一个巨大的数据文件,而不是一次处理它们呢?

或者是否有一种方法可以告诉火花,在这两个数据文件中,密钥是唯一的?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-17 13:33:05

在合理的时间框架内实现连接的关键是使用broadcast on keyfilter来执行广播哈希连接,而不是标准的连接。我还合并了部分部件并减少了并行性(由于某种原因,过多的线程有时会导致引擎崩溃)。我的新代码如下所示:

代码语言:javascript
复制
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col, braodcast
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()

def process_parts(part_pair, fraction=1, return_df=False, filter=True):
  dfs = []
  parts_start, parts_end = part_pair
  parts = range(parts_start, parts_end)
  for part in parts:
    try:
      df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
      dfs.append(df)
    except AnalysisException:
      print("There is no part %05d!" % part)
      continue
  if len(dfs) >= 2:
    df = reduce(lambda x, y: x.union(y), dfs)
  elif len(dfs) == 1:
    df = dfs[0]
  else:
    return None
  if fraction < 1:
    df = df.sample(fraction=fraction, withReplacement=False)
  if filter:
    df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
    filtered_df = df_with_filter.filter(col('filter')).drop('filter')
    mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  else:
    mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
  if return_df:
    return mod_df


start_time = time.time()
pairs = [(i*4, i*4+4) for i in range(256)]
with ThreadPool(3) as p:
  batch_start_time = time.time()
  for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
    batch_end_time = time.time()
    batch_len = batch_end_time - batch_start_time
    cum_len = batch_end_time - start_time
    print('Processed group %d/256 %d minutes and %d seconds after previous group.' % (i+1, batch_len // 60, batch_len % 60))
    print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
    batch_start_time = time.time()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71503597

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档