How to design a custom window function to select column values within time windows of a PySpark DataFrame

Stack Overflow user
Asked on 2020-02-11 01:16:08
1 answer · 42 views · 0 followers · 0 votes

This question is related to my previous one: pyspark dataframe aggregate a column by sliding time window

However, I'd like to open a separate post to clarify some points that I left out of the previous question.

The original data:

client_id    value1    name1    a_date
 dhd         561       ecdu     2019-10-8
 dhd         561       tygp     2019-10-8  
 dhd         561       rdsr     2019-10-8
 dhd         561       rgvd     2019-8-12
 dhd         561       bhnd     2019-8-12
 dhd         561       prti     2019-8-12
 dhd         561       teuq     2019-5-7
 dhd         561       wnva     2019-5-7
 dhd         561       pqhn     2019-5-7

I need to find the values of "name1" for each "client_id" and each "value1" within some given sliding time windows.

I defined a window function:

# Window (capital W) from pyspark.sql.window; lowercase window() is Spark's time-bucketing function
w = Window().partitionBy("client_id", "value1").orderBy("a_date")

But I don't know how to select the values of "name1" for window sizes of 1, 2, 6, 9, and 12.

Here, the window size means the number of months counted back from the month of the current "a_date".

For example:

client_id   value1   names1_within_window_size_1   names1_within_window_size_2
dhd         561      [ecdu,tygp,rdsr]              [ecdu,tygp,rdsr]

names1_within_window_size_6
[ecdu,tygp,rdsr,rgvd,bhnd,prti,teuq,wnva,pqhn]

names1_within_window_size_1 : the month window 2019-10
names1_within_window_size_2 : the month windows 2019-10 and 2019-9 (no data in 2019-9, so only the 2019-10 data remains)
names1_within_window_size_6 : the month windows 2019-10 back to 2019-5 (2019-9 is empty, but 2019-8 and 2019-5 contain data)
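The intended month-window semantics above can be sketched in plain Python (illustrative only, not the Spark solution; `month_id` and `names_within_window` are names I made up):

```python
# Illustrative sketch: reproduce the expected output with plain Python.
# month_id makes months linearly comparable across year boundaries.
def month_id(year, month):
    return year * 12 + month

# one bucket of name1 values per month, taken from the sample data above
buckets = [
    (month_id(2019, 10), ["ecdu", "tygp", "rdsr"]),
    (month_id(2019, 8), ["rgvd", "bhnd", "prti"]),
    (month_id(2019, 5), ["teuq", "wnva", "pqhn"]),
]

def names_within_window(size, current):
    # size months counted back from (and including) the current month
    return [n for m, names in buckets if current - size < m <= current
            for n in names]

oct_2019 = month_id(2019, 10)
assert names_within_window(1, oct_2019) == ["ecdu", "tygp", "rdsr"]
assert names_within_window(2, oct_2019) == ["ecdu", "tygp", "rdsr"]  # 2019-9 is empty
assert len(names_within_window(6, oct_2019)) == 9  # reaches back to 2019-5
```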

Thanks

============================================ Update

from pyspark.sql import functions as F
from pyspark.sql.window import Window

data=  [['dhd',589,'ecdu','2020-1-5'],
    ['dhd',575,'tygp','2020-1-5'],  
    ['dhd',821,'rdsr','2020-1-5'],
    ['dhd',872,'rgvd','2019-12-1'],
    ['dhd',619,'bhnd','2019-12-15'],
    ['dhd',781,'prti','2019-12-18'],
    ['dhd',781,'prti1','2019-12-18'],
    ['dhd',781,'prti2','2019-11-18'],
    ['dhd',781,'prti3','2019-10-31'],
    ['dhd',781,'prti4','2019-09-30'],
    ['dhd',781,'prt1','2019-07-31'],
    ['dhd',781,'pr4','2019-06-30'],
    ['dhd',781,'pr2','2019-08-31'],
    ['dhd',781,'prt4','2019-01-31'],
    ['dhd',781,'prti6','2019-02-28'],
    ['dhd',781,'prti7','2019-02-02'],
    ['dhd',781,'prti8','2019-03-29'],
    ['dhd',781,'prti9','2019-04-29'],
    ['dhd',781,'prti10','2019-05-04'],
    ['dhd',781,'prti11','2019-03-01'],
    ['dhd',781,'prti12','2018-12-17'],
    ['dhd',781,'prti15','2018-11-21'],
    ['dhd',781,'prti17','2018-10-31'],
    ['dhd',781,'prti19','2018-09-5']

   ]
columns= ['client_id','value1','name1','a_date']

df= spark.createDataFrame(data,columns)

df2 = df.withColumn("year_val", F.year("a_date"))\
    .withColumn("month_val", F.month("a_date"))\
    .withColumn("year_month", F.year(F.col("a_date")) * 100 + 
    F.month(F.col("a_date")))\
    .groupBy("client_id", "value1", "year_month")\
    .agg(F.concat_ws(", ", F.collect_list("name1")).alias("init_list"))

df2.sort(F.col("value1").desc(), F.col("year_month").desc()).show()

w = Window().partitionBy("client_id", "value1")\
    .orderBy("year_month")
df4 = df2.withColumn("a_rank", F.dense_rank().over(w))
df4.sort(F.col("value1"), F.col("year_month")).show()

month_range = 3
w = Window().partitionBy("client_id", "value1")\
    .orderBy("a_rank")\
    .rangeBetween(-month_range, 0)

df5 = df4.withColumn("last_" + str(month_range) + "_month", F.collect_list(F.col("init_list")).over(w))\
    .orderBy("value1", "a_rank")

df6 = df5.sort(F.col("value1").desc(), F.col("year_month").desc())
df6.show(100, False)

1 Answer

Stack Overflow user

Accepted answer

Answered on 2020-02-11 07:37:21

I stole the data from your previous question, because I was too lazy to type it up myself and a good soul over there had already compiled the input data into a list.

Since the window slides over a number of records rather than a number of months, I merged all records of a given month (grouped by client_id and value1, of course) into a single record; that is the .groupBy("client_id", "value1", "year_val", "month_val") step in the computation of df2.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

data=  [['dhd',589,'ecdu','2020-1-5'],
        ['dhd',575,'tygp','2020-1-5'],  
        ['dhd',821,'rdsr','2020-1-5'],
        ['dhd',872,'rgvd','2019-12-1'],
        ['dhd',619,'bhnd','2019-12-15'],
        ['dhd',781,'prti','2019-12-18'],
        ['dhd',781,'prti1','2019-12-18'],
        ['dhd',781,'prti2','2019-11-18'],
        ['dhd',781,'prti3','2019-10-31'],
        ['dhd',781,'prti4','2019-09-30'],
        ['dhd',781,'prt1','2019-07-31'],
        ['dhd',781,'pr4','2019-06-30'],
        ['dhd',781,'pr2','2019-08-31'],
        ['dhd',781,'prt4','2019-01-31'],
        ['dhd',781,'prti6','2019-02-28'],
        ['dhd',781,'prti7','2019-02-02'],
        ['dhd',781,'prti8','2019-03-29'],
        ['dhd',781,'prti9','2019-04-29'],
        ['dhd',781,'prti10','2019-05-04'],
        ['dhd',781,'prti11','2019-03-01']]
columns= ['client_id','value1','name1','a_date']

df= spark.createDataFrame(data,columns)

df2 = df.withColumn("year_val", F.year("a_date"))\
        .withColumn("month_val", F.month("a_date"))\
        .groupBy("client_id", "value1", "year_val", "month_val")\
        .agg(F.concat_ws(", ", F.collect_list("name1")).alias("init_list"))

df2.show()

Here we get init_list as:

+---------+------+--------+---------+-------------+
|client_id|value1|year_val|month_val|    init_list|
+---------+------+--------+---------+-------------+
|      dhd|   781|    2019|       12|  prti, prti1|
|      dhd|   589|    2020|        1|         ecdu|
|      dhd|   781|    2019|        8|          pr2|
|      dhd|   781|    2019|        3|prti8, prti11|
|      dhd|   575|    2020|        1|         tygp|
|      dhd|   781|    2019|        5|       prti10|
|      dhd|   781|    2019|        9|        prti4|
|      dhd|   781|    2019|       11|        prti2|
|      dhd|   781|    2019|       10|        prti3|
|      dhd|   821|    2020|        1|         rdsr|
|      dhd|   781|    2019|        6|          pr4|
|      dhd|   619|    2019|       12|         bhnd|
|      dhd|   781|    2019|        7|         prt1|
|      dhd|   781|    2019|        4|        prti9|
|      dhd|   781|    2019|        1|         prt4|
|      dhd|   781|    2019|        2| prti6, prti7|
|      dhd|   872|    2019|       12|         rgvd|
+---------+------+--------+---------+-------------+

With this, we can get the final result simply by running a window over these records:

month_range = 6
w = Window().partitionBy("client_id", "value1")\
        .orderBy("month_val")\
        .rangeBetween(-(month_range+1),0)

df3 = df2.withColumn("last_0_month", F.collect_list(F.col("init_list")).over(w))\
        .orderBy("value1", "year_val", "month_val")

df3.show(100,False)

This gives us:

+---------+------+--------+---------+-------------+-------------------------------------------------------------------+
|client_id|value1|year_val|month_val|init_list    |last_0_month                                                       |
+---------+------+--------+---------+-------------+-------------------------------------------------------------------+
|dhd      |575   |2020    |1        |tygp         |[tygp]                                                             |
|dhd      |589   |2020    |1        |ecdu         |[ecdu]                                                             |
|dhd      |619   |2019    |12       |bhnd         |[bhnd]                                                             |
|dhd      |781   |2019    |1        |prt4         |[prt4]                                                             |
|dhd      |781   |2019    |2        |prti6, prti7 |[prt4, prti6, prti7]                                               |
|dhd      |781   |2019    |3        |prti8, prti11|[prt4, prti6, prti7, prti8, prti11]                                |
|dhd      |781   |2019    |4        |prti9        |[prt4, prti6, prti7, prti8, prti11, prti9]                         |
|dhd      |781   |2019    |5        |prti10       |[prt4, prti6, prti7, prti8, prti11, prti9, prti10]                 |
|dhd      |781   |2019    |6        |pr4          |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4]            |
|dhd      |781   |2019    |7        |prt1         |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1]      |
|dhd      |781   |2019    |8        |pr2          |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1, pr2] |
|dhd      |781   |2019    |9        |prti4        |[prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1, pr2, prti4]|
|dhd      |781   |2019    |10       |prti3        |[prti8, prti11, prti9, prti10, pr4, prt1, pr2, prti4, prti3]       |
|dhd      |781   |2019    |11       |prti2        |[prti9, prti10, pr4, prt1, pr2, prti4, prti3, prti2]               |
|dhd      |781   |2019    |12       |prti, prti1  |[prti10, pr4, prt1, pr2, prti4, prti3, prti2, prti, prti1]         |
|dhd      |821   |2020    |1        |rdsr         |[rdsr]                                                             |
|dhd      |872   |2019    |12       |rgvd         |[rgvd]                                                             |
+---------+------+--------+---------+-------------+-------------------------------------------------------------------+

Limitations:

Unfortunately, the a_date field is lost in the second part, and for a sliding-window operation whose range is defined on the ordering column, orderBy cannot take multiple columns (note that the orderBy in the window definition is on month_val only). So this exact solution will not work for data spanning multiple years. However, you can easily add something like a month_id column that merges the year and month values into a single column, and then use it in the orderBy clause.
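As a hedged illustration of that month_id idea (my own elaboration, not code from the original answer): the merged index must keep consecutive months exactly one apart for rangeBetween to work, which year * 100 + month (as used in the question's update) does not:

```python
# Sketch: a rangeBetween-safe month index must be linear across year
# boundaries. year * 100 + month jumps by 89 from December to January,
# while year * 12 + month always steps by exactly 1.
def ym_gapped(y, m):
    # as in the question's update: year_month = year * 100 + month
    return y * 100 + m

def ym_linear(y, m):
    # consecutive months differ by 1, even across New Year
    return y * 12 + m

assert ym_gapped(2020, 1) - ym_gapped(2019, 12) == 89
assert ym_linear(2020, 1) - ym_linear(2019, 12) == 1
```

In Spark this would be something like `F.year("a_date") * 12 + F.month("a_date")` as the orderBy column; the original answer only names the idea, so the exact expression is an assumption.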

If you want multiple windows, you can turn month_range into a list and loop over it in the last snippet to cover all the ranges.
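A minimal sketch of that loop idea (the column names and offsets mirror the update's last snippet; the loop itself is assumed, not code from the original):

```python
# Sketch: plan one output column and one rangeBetween offset pair per
# desired window size, looping month_range over a list.
month_ranges = [1, 3, 6]
plans = [("last_%d_month" % n, (-n, 0)) for n in month_ranges]

assert plans == [("last_1_month", (-1, 0)),
                 ("last_3_month", (-3, 0)),
                 ("last_6_month", (-6, 0))]
```

Each (name, (lo, hi)) pair would then feed one withColumn(name, F.collect_list("init_list").over(w)) call, where w is the same partitioned/ordered window with .rangeBetween(lo, hi).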

Although the last column (last_0_month) looks like an array, it actually contains comma-joined strings left over from the earlier agg operation. You may want to clean that up as well.
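A hedged sketch of that cleanup (the helper name is mine, not from the answer):

```python
# Sketch: each element of last_0_month is a comma-joined string such as
# "prti6, prti7"; split every element and flatten to one name per entry.
def split_and_flatten(values):
    return [name.strip() for item in values for name in item.split(",")]

assert split_and_flatten(["prt4", "prti6, prti7"]) == ["prt4", "prti6", "prti7"]
```

In PySpark (3.1+), something like `F.flatten(F.transform("last_0_month", lambda x: F.split(x, ", ")))` should achieve the same, though that exact expression is an assumption on my part.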

Votes: 1
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/60160677
