I'm using Spark 3.1.1 with Java 8, and I'm trying to split a Dataset&lt;Row&gt; based on the value of one of its numeric columns (greater or less than a threshold). The split is only possible when certain string columns of the rows have the same value. I tried this:
Iterator<Row> iter2 = partition.toLocalIterator();
while (iter2.hasNext()) {
    Row item = iter2.next();
    // getColVal is a helper that returns the value of the given column
    String numValue = getColVal(item, dim);
    if (Integer.parseInt(numValue) < threshold) {
        pl.add(item);
    } else {
        pr.add(item);
    }
}
But how can I check, before splitting, that the other (string) column values of the rows involved are the same, so that the split can be performed?
PS: I tried grouping by the columns before splitting, like this:
Dataset<Row> newDataset = oldDataset.groupBy("col1", "col4").agg(col("col1"));
but it didn't work.
Thanks for your help.
EDIT:
The sample dataset I want to split is:
abc,9,40,A
abc,7,50,A
cde,4,20,B
cde,3,25,B
If the threshold is 30, the first two rows and the last two rows will form two datasets, because their first and fourth columns are identical; otherwise the split is not possible.
EDIT: the resulting output would be:
abc,9,40,A
abc,7,50,A
cde,4,20,B
cde,3,25,B
Posted on 2021-04-14 02:33:53
I mostly use pyspark, but you can adapt this to your environment.
import pandas as pd
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()

## could add some conditional logic or just always output 2 data frames where
## one would be empty
print("pdf - two dataframe")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print( pdf )
## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf)
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc| 9| 40| A|
# | abc| 7| 50| A|
# | cde| 4| 20| B|
# | cde| 3| 25| B|
# +----+----+----+----+
## filter
pl = sdf.filter('col3 <= 30')\
.groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
pr = sdf.filter('col3 > 30')\
.groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
print("pl")
pl.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | cde| B| 7|
# +----+----+-----+
print("pr")
pr.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | abc| A| 16|
# +----+----+-----+
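The "conditional logic" hinted at in the comment above could, for instance, verify up front that every (col1, col4) group falls entirely on one side of the threshold, so the split is actually valid in the question's sense. A minimal sketch of that check in plain pandas (the threshold value and the no-straddling rule are assumptions based on the question's sample data, not part of the original answer):

```python
import pandas as pd

pdf = pd.DataFrame({'col1': ['abc', 'abc', 'cde', 'cde'],
                    'col2': [9, 7, 4, 3],
                    'col3': [40, 50, 20, 25],
                    'col4': ['A', 'A', 'B', 'B']})
threshold = 30

# For each (col1, col4) group, the split is only valid if all rows fall
# on the same side of the threshold (the > test yields a single distinct value).
sides = pdf.groupby(['col1', 'col4'])['col3'].apply(
    lambda s: (s > threshold).nunique() == 1)
split_possible = bool(sides.all())
print(split_possible)  # prints: True (abc/A all > 30, cde/B all <= 30)
```

If any group straddles the threshold, `split_possible` becomes False and you could skip the split instead of producing the two frames.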
print("pdf - one dataframe")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[11,29,20,25],'col4':['A','A','B','B']})
print( pdf )
## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf)
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc| 9| 11| A|
# | abc| 7| 29| A|
# | cde| 4| 20| B|
# | cde| 3| 25| B|
# +----+----+----+----+
pl = sdf.filter('col3 <= 30')\
.groupBy("col1","col4").agg( F.sum('col2').alias('sumC2') )
pr = sdf.filter('col3 > 30')\
.groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
print("pl")
pl.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | abc| A| 16|
# | cde| B| 7|
# +----+----+-----+
print("pr")
pr.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# +----+----+-----+
Filtering by a dynamic mean:
print("pdf - filter by mean")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print( pdf )
## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf)
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc| 9| 40| A|
# | abc| 7| 50| A|
# | cde| 4| 20| B|
# | cde| 3| 25| B|
# +----+----+----+----+
w = Window.partitionBy("col1").orderBy("col2")
## add another column, the mean of col2 partitioned by col1
sdf = sdf.withColumn('mean_c2', F.mean('col2').over(w))
## filter by the dynamic mean
pr = sdf.filter('col2 > mean_c2')
pr.show()
# +----+----+----+----+-------+
# |col1|col2|col3|col4|mean_c2|
# +----+----+----+----+-------+
# | cde| 4| 20| B| 3.5|
# | abc| 9| 40| A| 8.0|
# +----+----+----+----+-------+
https://stackoverflow.com/questions/67075707
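For quick local prototyping, the same dynamic-mean filter can be reproduced in plain pandas with `groupby(...).transform('mean')`, without a Spark session (a sketch over the same sample data; an illustration only, not part of the original answer):

```python
import pandas as pd

pdf = pd.DataFrame({'col1': ['abc', 'abc', 'cde', 'cde'],
                    'col2': [9, 7, 4, 3],
                    'col3': [40, 50, 20, 25],
                    'col4': ['A', 'A', 'B', 'B']})

# mean of col2 within each col1 group, broadcast back to every row
# (plays the role of F.mean('col2').over(Window.partitionBy('col1')))
pdf['mean_c2'] = pdf.groupby('col1')['col2'].transform('mean')

# keep only rows strictly above their group mean, mirroring the Window filter
pr = pdf[pdf['col2'] > pdf['mean_c2']]
# keeps (abc, 9, mean 8.0) and (cde, 4, mean 3.5), matching pr.show() above
```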