グループ化されたSparkデータフレームの列でtrueのレコードの数を数えたいのですが、Pythonでそれを行う方法がわかりません。たとえば、region
のデータがあります。 、salary
およびIsUnemployed
列(ブール値としてIsUnemployed
)。各地域の失業者の数を確認したい。filter
、次にgroupby
を実行できることはわかっているが、以下のように2つの集計を同時に生成したい。
from pyspark.sql import functions as F
data.groupby("Region").agg(F.avg("Salary"), F.count("IsUnemployed"))
おそらく最も単純な解決策は、CAST
(TRUE
-> 1、FALSE
-> 0のCスタイル)とSUM
です。
(data
.groupby("Region")
.agg(F.avg("Salary"), F.sum(F.col("IsUnemployed").cast("long"))))
もう少し普遍的で慣用的な解決策はCASE WHEN
とCOUNT
:
(data
.groupby("Region")
.agg(
F.avg("Salary"),
F.count(F.when(F.col("IsUnemployed"), F.col("IsUnemployed")))))
しかし、ここでは明らかにやり過ぎです。