如何解决Pyspark滚动平均值从第一行开始
我正在尝试计算Pyspark的滚动平均值。我可以使用它,但是它的行为似乎与我预期的不同。滚动平均值从第一行开始。
例如:
columns = ['month','day','value']
data = [('JAN','01','20000'),('JAN','02','40000'),'03','30000'),'04','25000'),'05','5000'),'06','15000'),('FEB','10000'),'50000'),'100000'),'60000'),'1000'),]
df_test = sc.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)
df_test.withColumn('rolling_average',f.avg('value').over(win)).show()
+-----+---+------+------------------+
|month|day| value| rolling_average|
+-----+---+------+------------------+
| JAN| 01| 20000| 20000.0|
| JAN| 02| 40000| 30000.0|
| JAN| 03| 30000| 30000.0|
| JAN| 04| 25000|31666.666666666668|
| JAN| 05| 5000| 20000.0|
| JAN| 06| 15000| 15000.0|
| FEB| 01| 10000| 10000.0|
| FEB| 02| 50000| 30000.0|
| FEB| 03|100000|53333.333333333336|
| FEB| 04| 60000| 70000.0|
| FEB| 05| 1000|53666.666666666664|
| FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
这将更符合我的期望。有没有办法得到这种行为?
+-----+---+------+------------------+
|month|day| value| rolling_average|
+-----+---+------+------------------+
| JAN| 01| 20000| null|
| JAN| 02| 40000| null|
| JAN| 03| 30000| 30000.0|
| JAN| 04| 25000|31666.666666666668|
| JAN| 05| 5000| 20000.0|
| JAN| 06| 15000| 15000.0|
| FEB| 01| 10000| null|
| FEB| 02| 50000| null|
| FEB| 03|100000|53333.333333333336|
| FEB| 04| 60000| 70000.0|
| FEB| 05| 1000|53666.666666666664|
| FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
默认行为的问题是,我需要另一列来跟踪滞后应从何处开始。
解决方法
尝试使用 row_number()
窗口功能,然后使用 when + otherwise 语句替换null。
- 要更改
lag start
,然后更改when
语句col("rn") <= <value>
的值。
Example:
columns = ['month','day','value']
data = [('JAN','01','20000'),('JAN','02','40000'),'03','30000'),'04','25000'),'05','5000'),'06','15000'),('FEB','10000'),'50000'),'100000'),'60000'),'1000'),]
df_test = sc.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)
win1 = Window.partitionBy('month').orderBy('day')
df_test.withColumn('rolling_average',f.avg('value').over(win)).\
withColumn("rn",row_number().over(win1)).\
withColumn("rolling_average",when(col("rn") <= 2,lit(None)).\
otherwise(col("rolling_average"))).\
drop("rn").\
show()
#+-----+---+------+------------------+
#|month|day| value| rolling_average|
#+-----+---+------+------------------+
#| FEB| 01| 10000| null|
#| FEB| 02| 50000| null|
#| FEB| 03|100000|53333.333333333336|
#| FEB| 04| 60000| 70000.0|
#| FEB| 05| 1000|53666.666666666664|
#| FEB| 06| 10000|23666.666666666668|
#| JAN| 01| 20000| null|
#| JAN| 02| 40000| null|
#| JAN| 03| 30000| 30000.0|
#| JAN| 04| 25000|31666.666666666668|
#| JAN| 05| 5000| 20000.0|
#| JAN| 06| 15000| 15000.0|
#+-----+---+------+------------------+
,
@ 484的更多简化版本。
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.partitionBy('month').orderBy('day')
w2 = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)
df.withColumn("rolling_average",f.when(f.row_number().over(w1) > f.lit(2),f.avg('value').over(w2))).show(10,False)
p.s。请不要将此标记为答案:)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。