次のようにしてDataFrameを生成しました。
df.groupBy($"Hour", $"Category")
.agg(sum($"value") as "TotalValue")
.sort($"Hour".asc, $"TotalValue".desc))
結果は次のようになります。
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 0| cat13| 22.1|
| 0| cat95| 19.6|
| 0| cat105| 1.3|
| 1| cat67| 28.5|
| 1| cat4| 26.8|
| 1| cat13| 12.6|
| 1| cat23| 5.3|
| 2| cat56| 39.6|
| 2| cat40| 29.7|
| 2| cat187| 27.9|
| 2| cat68| 9.8|
| 3| cat8| 35.6|
| ...| ....| ....|
+----+--------+----------+
ご覧のとおり、DataFrameはHour
の昇順で並べられ、次にTotalValue
の降順で並べられています。
各グループの一番上の行を選択します。
したがって、望ましい出力は次のようになります。
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
| ...| ...| ...|
+----+--------+----------+
各グループの上位N行も選択できると便利です。
任意の助けは大歓迎です。
ウィンドウ関数:
このような何かがトリックをするべきです:
import org.Apache.spark.sql.functions.{row_number, max, broadcast}
import org.Apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
この方法は、データが大きく歪んでいる場合は非効率的です。
プレーンなSQL集計とそれに続くjoin
:
あるいは、集約データフレームを使って結合することもできます。
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
重複する値が保持されます(同じ合計値で1時間に複数のカテゴリがある場合)。次のようにしてこれらを削除することができます。
dfTopByJoin
.groupBy($"hour")
.agg(
first("category").alias("category"),
first("TotalValue").alias("TotalValue"))
structs
に対する順序付けを使用します。
うまくテストされていませんが、結合やウィンドウ関数を必要としないトリックです。
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
DataSet APIの場合(Spark 1.6以降、2.0以降):
スパーク1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0以降:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
最後の2つの方法は、マップサイドコンバインを活用でき、フルシャッフルを必要としないため、ほとんどの場合、ウィンドウ関数や結合と比較してより優れたパフォーマンスを示すはずです。これらはcompleted
出力モードの構造化ストリーミングでも使用できます。
使わないでください:
df.orderBy(...).groupBy(...).agg(first(...), ...)
動作するように見えるかもしれませんが(特にlocal
モードで)、信頼できません( SPARK-16207 )。 Tzach Zohar関連するJIRA関連の問題をリンクする のクレジット。
同じことが当てはまります。
df.orderBy(...).dropDuplicates(...)
これは内部的に同等の実行計画を使用します。
複数の列でグループ化したSpark 2.0.2の場合:
import org.Apache.spark.sql.functions.row_number
import org.Apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
これは、 zero32 's answer とまったく同じですが、SQLクエリの方法です。
データフレームが以下のように作成および登録されたと仮定します。
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0 |cat26 |30.9 |
//|0 |cat13 |22.1 |
//|0 |cat95 |19.6 |
//|0 |cat105 |1.3 |
//|1 |cat67 |28.5 |
//|1 |cat4 |26.8 |
//|1 |cat13 |12.6 |
//|1 |cat23 |5.3 |
//|2 |cat56 |39.6 |
//|2 |cat40 |29.7 |
//|2 |cat187 |27.9 |
//|2 |cat68 |9.8 |
//|3 |cat8 |35.6 |
//+----+--------+----------+
ウィンドウ関数:
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
単純なSQL集計とそれに続く結合:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
"(select Hour, Category, TotalValue from table tmp1 " +
"join " +
"(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
"on " +
"tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
"group by tmp3.Hour")
.show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
構造体の順序付けを使う:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
データセットの方法およびしないは、元の答えと同じです。
以下の解決策は、groupByを1回だけ実行し、maxValueを含むデータフレームの行をワンショットで抽出します。それ以上の結合、つまりWindowsは必要ありません。
import org.Apache.spark.sql.Row
import org.Apache.spark.sql.catalyst.encoders.RowEncoder
import org.Apache.spark.sql.DataFrame
//df is the dataframe with Day, Category, TotalValue
implicit val dfEnc = RowEncoder(df.schema)
val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
データフレームを複数の列でグループ化する必要がある場合は、これが役立ちます。
val keys = List("Hour", "Category");
val selectFirstValueOfNoneGroupedColumns =
df.columns
.filterNot(keys.toSet)
.map(_ -> "first").toMap
val grouped =
df.groupBy(keys.head, keys.tail: _*)
.agg(selectFirstValueOfNoneGroupedColumns)
これが同様の問題を抱えている人に役立つことを願っています
パターンはキーでグループ化されます=>各グループに何かをします。 reduce =>データフレームに戻る
この場合、データフレームの抽象化は少し面倒だと思いましたので、RDD機能を使用しました。
val rdd: RDD[Row] = originalDf
.rdd
.groupBy(row => row.getAs[String]("grouping_row"))
.map(iterableTuple => {
iterableTuple._2.reduce(reduceFunction)
})
val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
データフレームapiを使ってこれを行う良い方法は、argmaxロジックを次のように使用することです。
val df = Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
df.groupBy($"Hour")
.agg(max(struct($"TotalValue", $"Category")).as("argmax"))
.select($"Hour", $"argmax.*").show
+----+----------+--------+
|Hour|TotalValue|Category|
+----+----------+--------+
| 1| 28.5| cat67|
| 3| 35.6| cat8|
| 2| 39.6| cat56|
| 0| 30.9| cat26|
+----+----------+--------+
ここであなたはこのようにすることができます -
val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")
data.withColumnRenamed("_1","Hour").show