I need to generate some data in PySpark, and I'm currently doing it with pyspark.pandas. I've found that when I try to scale up my data-generation process with .repeat(), it becomes extremely slow (tens of minutes).
Are there alternative approaches for generating a data file like the one below?
import pyspark.pandas as ps
# params
start_time = '2022-04-01'
end_time = '2022-07-01'
IDs = [1, 2, 3, 4, 5, 6, 7, 8, ...]
dStates = ['A', 'B', 'C', 'D', ....]
# delta time
delta_time = (ps.to_datetime(end_time).month - ps.to_datetime(start_time).month)
# create DF
timeSet = ps.date_range(start=start_time, end=end_time, freq='MS').repeat( len(dStates) * len(IDs) )
stateSet = ps.Series( dStates * ( delta_time + 1 ) * len(IDs) )
nodeSet = ps.Series(IDs).repeat( len(dStates) * ( delta_time + 1 ) ).reset_index(drop=True)
# combine
tseries = ps.DataFrame({'monthlyTrend': timeSet.astype(str),
                        'FromState': stateSet,
                        'ID': nodeSet})

Posted on 2022-07-27 09:11:12
In general, numpy functions are better optimized, so you could try numpy.repeat(). I've adjusted the code below to generate dates daily over a range, and to align timeList with the lengths of IDs and dStates:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
# params
start_time = '2022-04-01'
end_time = '2022-07-01'
IDs = [1, 2, 3, 4, 5, 6, 7, 8]
dStates = ['A', 'B', 'C', 'D']
# Generate data based on params
timeList = np.arange(datetime(2022, 4, 1), datetime(2022, 7, 1), timedelta(days=1)).astype(datetime)
stateList = np.repeat(dStates, len(timeList)//len(dStates))
stateList = np.append(stateList, dStates[:len(timeList)%len(dStates)])  # top up so the lengths match timeList
nodeList = np.repeat(IDs, len(timeList)//len(IDs))
nodeList = np.append(nodeList, IDs[:len(timeList)%len(IDs)])  # same top-up for the IDs
# combine
tseries = pd.DataFrame({
    'monthlyTrend': timeList.astype(str),
    'FromState': stateList,
    'ID': nodeList
})
df = spark.createDataFrame(tseries)

Update
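One caveat about the pairing above: np.repeat plus the slice top-up cycles the states and IDs against the dates, rather than producing every combination. If the original goal is the full month × state × ID cross product (as in the question's pyspark.pandas snippet), combining np.repeat with np.tile builds that grid directly. A sketch, assuming the same parameter lists as above:

```python
import numpy as np
import pandas as pd

# params (as above)
months = pd.date_range('2022-04-01', '2022-07-01', freq='MS')  # 4 month starts
IDs = [1, 2, 3, 4, 5, 6, 7, 8]
dStates = ['A', 'B', 'C', 'D']

n_m, n_s, n_i = len(months), len(dStates), len(IDs)

grid = pd.DataFrame({
    # each month repeated once per (state, ID) combination
    'monthlyTrend': np.repeat(months, n_s * n_i).astype(str),
    # the whole state block repeats for every month; each state repeats per ID
    'FromState': np.tile(np.repeat(dStates, n_i), n_m),
    # IDs cycle fastest, once per (month, state) pair
    'ID': np.tile(IDs, n_m * n_s),
})
print(len(grid))  # 4 months * 4 states * 8 IDs = 128 rows
```

The resulting pandas DataFrame can then be handed to spark.createDataFrame as in the answer above.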
Here is another way to achieve the above using only pyspark functions: explode() with array_repeat(). We first create a DataFrame as long as the longest parameter list (IDs in this example), then expand it with pyspark functions.
from pyspark.sql import functions as F
import pyspark.pandas as ps
# params
start_time = '2022-04-01'
end_time = '2022-07-01'
delta_time = (ps.to_datetime(end_time).month - ps.to_datetime(start_time).month)
timeSet = ps.date_range(start=start_time, end=end_time, freq='MS').tolist()
IDs = [1, 2, 3, 4, 5, 6, 7, 8]
dStates = ['A', 'B', 'C', 'D']
# create a minimum length DF aligned to the longest list of params
longest_list = IDs
timeSet = ps.concat([ps.Series(timeSet * (len(longest_list)//len(timeSet))), ps.Series(timeSet[:len(longest_list)%len(timeSet)])], ignore_index=True)
stateSet = ps.concat([ps.Series(dStates * (len(longest_list)//len(dStates))), ps.Series(dStates[:len(longest_list)%len(dStates)])], ignore_index=True)
nodeSet = ps.Series(IDs)
# combine
df_tseries = ps.DataFrame({
    'monthlyTrend': timeSet,
    'FromState': stateSet,
    'ID': nodeSet}).to_spark()
# expand the df with explode and array_repeat
no_of_repeats = 10
df_tseries = df_tseries.withColumn("ID", F.explode(F.array_repeat("ID", no_of_repeats)))

https://stackoverflow.com/questions/73130979
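The padding in the snippet above (repeating each short list a whole number of times, then topping it up with a slice) is just cycling each list to the length of the longest one. A minimal pure-Python sketch of the same alignment, using a hypothetical pad_to helper and the parameter lists from above:

```python
from itertools import cycle, islice

def pad_to(values, n):
    """Cycle through `values` until exactly n elements are produced."""
    return list(islice(cycle(values), n))

IDs = [1, 2, 3, 4, 5, 6, 7, 8]
dStates = ['A', 'B', 'C', 'D']
timeSet = ['2022-04-01', '2022-05-01', '2022-06-01', '2022-07-01']

n = len(IDs)  # the longest parameter list
print(pad_to(dStates, n))  # ['A', 'B', 'C', 'D', 'A', 'B', 'C', 'D']
print(pad_to(timeSet, n))

# equivalent to: lst * (n // len(lst)) + lst[:n % len(lst)]
```

This makes it easy to check that every column ends up with exactly len(IDs) rows before the DataFrame is built and expanded on the Spark side.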