如何解决Pyspark:如何编码复杂数据帧算法问题根据条件求和
我有一个如下所示的数据框:
pass
TEST_schema = StructType([StructField("date",StringType(),True),\
StructField("Trigger",\
StructField("value",FloatType(),\
StructField("col1",IntegerType(),StructField("col2",StructField("want",True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02',-1,4,0.0),('2020-08-03',('2020-08-04','F',0.2,0.7),('2020-08-05',0.3,1,0.9),\
('2020-08-06',('2020-08-07',('2020-08-08',0.5,('2020-08-09',0.0)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data,TEST_schema)
TEST_df = TEST_df.withColumn("date",to_date("date",'yyyy-MM-dd'))
TEST_df.show()
+----------+-------+-----+----+----+
| date|Trigger|value|col1|col2|
+----------+-------+-----+----+----+
|2020-08-01| T| 0.0| 3| 5|
|2020-08-02| T| 0.0| -1| 4|
|2020-08-03| T| 0.0| -1| 3|
|2020-08-04| F| 0.2| 3| 3|
|2020-08-05| T| 0.3| 1| 4|
|2020-08-06| F| 0.2| -1| 3|
|2020-08-07| T| 0.2| -1| 4|
|2020-08-08| T| 0.5| -1| 5|
|2020-08-09| T| 0.0| -1| 5|
+----------+-------+-----+----+----+
:排序很好
date
:仅 T 或 F
Trigger
:任意随机的十进制(浮点)值
value
:代表天数,不能小于-1。** -1
col1
:代表天数,不能为负数。 col2> = 0
**计算逻辑**
如果为col2
,否则为col1 == -1,then return 0
,下图将有助于理解逻辑。
如果我们查看“红色”,则+3来自col1,即2020-08-01的Trigger == T
,这意味着我们跳了3行,同时也取了差col1==3
(在2020-08-01时) 1 表示对下一个值(col2 - col1) -1 = ( 5-3) -1 = 1.
求和。相同的逻辑适用于“蓝色”
“绿色”是指0.2 + 0.3 = 0.5
然后仅取trigger == "F"
(2020-08-04)时, 2 表示下两个值的和。这是(col2 -1)=3-1 =2
编辑:
如果我什么都不想要怎么办,假设我们有这个 df
0.2+0.3+0.2 = 0.7
当我们具有Trigger ==“ F”条件时,同样的逻辑适用,所以TEST_schema = StructType([StructField("date",\
StructField("col2",5),4),3),5)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data,'yyyy-MM-dd'))
TEST_df.show()
+----------+-----+----+
| date|value|col2|
+----------+-----+----+
|2020-08-01| 0.0| 5|
|2020-08-02| 0.0| 4|
|2020-08-03| 0.0| 3|
|2020-08-04| 0.2| 3|
|2020-08-05| 0.3| 4|
|2020-08-06| 0.2| 3|
|2020-08-07| 0.2| 4|
|2020-08-08| 0.5| 5|
|2020-08-09| 0.0| 5|
+----------+-----+----+
但在这种情况下没有条件。
解决方法
IIUC,我们可以使用Windows函数collect_list
获取所有相关行,按date
对结构数组进行排序,然后基于该数组的slice进行聚合。每个切片的 start_idx 和 span 可以基于以下定义:
- 如果 col1 = -1 ,则 start_idx = 1 和 span = 0 ,因此没有任何汇总
- 否则,如果 Trigger = 'F',则 start_idx = 1 和 span = col2
- else start_idx = col1 + 1 和 span = col2-col1
请注意,功能片的索引是基于 1的。
代码:
from pyspark.sql.functions import to_date,sort_array,collect_list,struct,expr
from pyspark.sql import Window
w1 = Window.orderBy('date').rowsBetween(0,Window.unboundedFollowing)
# columns used to do calculations,date must be the first field for sorting purpose
cols = ["date","value","start_idx","span"]
df_new = (TEST_df
.withColumn('start_idx',expr("IF(col1 = -1 OR Trigger = 'F',1,col1+1)"))
.withColumn('span',expr("IF(col1 = -1,IF(Trigger = 'F',col2,col2-col1))"))
.withColumn('dta',sort_array(collect_list(struct(*cols)).over(w1)))
.withColumn("want1",expr("aggregate(slice(dta,start_idx,span),0D,(acc,x) -> acc+x.value)"))
)
结果:
df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
| date|Trigger|value|col1|col2|want|start_idx|span| dta| want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01| T| 0.0| 3| 5| 0.5| 4| 2|[[2020-08-01,T,...|0.5000000149011612|
|2020-08-02| T| 0.0| -1| 4| 0.0| 1| 0|[[2020-08-02,...| 0.0|
|2020-08-03| T| 0.0| -1| 3| 0.0| 1| 0|[[2020-08-03,...| 0.0|
|2020-08-04| F| 0.2| 3| 3| 0.7| 1| 3|[[2020-08-04,F,...|0.7000000178813934|
|2020-08-05| T| 0.3| 1| 4| 0.9| 2| 3|[[2020-08-05,...|0.9000000059604645|
|2020-08-06| F| 0.2| -1| 3| 0.0| 1| 0|[[2020-08-06,...| 0.0|
|2020-08-07| T| 0.2| -1| 4| 0.0| 1| 0|[[2020-08-07,...| 0.0|
|2020-08-08| T| 0.5| -1| 5| 0.0| 1| 0|[[2020-08-08,...| 0.0|
|2020-08-09| T| 0.0| -1| 5| 0.0| 1| 0|[[2020-08-09,...| 0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
一些说明:
-
切片函数除定位数组外还需要两个参数。在我们的代码中,
start_idx
是起始索引,span
是切片的长度。在代码中,我使用 IF 语句根据原始帖子中的图表规格计算 start_idx 和 span 。 -
collect_list + sort_array 在窗口
w1
上生成的数组覆盖了从当前行到窗口末尾的行(请参见{ {1}}分配)。然后,我们使用 aggregate 函数内的 slice 函数仅检索必要的数组项。 -
SparkSQL内置函数aggregate采用以下形式:
w1
可以跳过第四个参数
aggregate(expr,start,merge,finish)
。在我们的情况下,可以将其重新格式化为(您可以复制以下内容以替换 exprfinish
中的代码):.withColumn('want1',expr(""" .... """)
aggregate 函数的工作方式类似于Python中的 reduce 函数,第二个参数是零值(
aggregate( /* targeting array,use slice function to take only part of the array `dta` */ slice(dta,/* start,zero_value used for reduce */ 0D,/* merge,similar to reduce function */ (acc,x) -> acc+x.value,/* finish,skipped in the post,but you can do some post-processing here,for example,round-up the result from merge */ acc -> round(acc,2) )
是0D
的快捷方式,是用于强制转换聚合变量double(0)
的数据类型。
如评论中所述 -
,如果 col2
其中 Trigger = 'T'和 col1 != -1 存在,它将在当前代码中产生负的 span 。在这种情况下,我们应该使用全尺寸的Window规范: acc
并使用 array_position 查找当前行的位置(refer to one of my recent posts),然后根据该位置计算 start_idx 。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。