Làm thế nào để chọn hàng đầu tiên của mỗi nhóm?


143

Tôi có một DataFrame được tạo như sau:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Kết quả trông như sau:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Như bạn có thể thấy, DataFrame được sắp xếp theo Hourthứ tự tăng dần, sau đó theo TotalValuethứ tự giảm dần.

Tôi muốn chọn hàng trên cùng của mỗi nhóm, nghĩa là

  • từ nhóm Giờ == 0 chọn (0, cat26,30.9)
  • từ nhóm Giờ == 1 chọn (1, cat67,28,5)
  • từ nhóm Giờ == 2 chọn (2, cat56,39.6)
  • và như thế

Vì vậy, đầu ra mong muốn sẽ là:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Cũng có thể có ích khi có thể chọn N hàng trên cùng của mỗi nhóm.

Bất kỳ sự trợ giúp nào cũng được đánh giá cao.

Câu trả lời:


231

Các chức năng của cửa sổ :

Một cái gì đó như thế này nên thực hiện các mẹo:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Phương pháp này sẽ không hiệu quả trong trường hợp sai lệch dữ liệu quan trọng.

Tập hợp SQL đơn giản theo saujoin :

Ngoài ra, bạn có thể tham gia với khung dữ liệu tổng hợp:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Nó sẽ giữ các giá trị trùng lặp (nếu có nhiều hơn một danh mục mỗi giờ với cùng một giá trị). Bạn có thể loại bỏ chúng như sau:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Sử dụng đặt hàng quastructs :

Mặc dù gọn gàng, nhưng không được kiểm tra kỹ, thủ thuật không yêu cầu tham gia hoặc chức năng cửa sổ:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Với API dữ liệu (Spark 1.6+, 2.0+):

Tia lửa 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 trở lên :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Hai phương pháp cuối cùng có thể tận dụng kết hợp phía bản đồ và không yêu cầu xáo trộn hoàn toàn, do đó, phần lớn thời gian sẽ thể hiện hiệu suất tốt hơn so với các chức năng cửa sổ và tham gia. Những gậy này cũng được sử dụng với Structured Streaming ở completedchế độ đầu ra.

Đừng sử dụng :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Nó có vẻ hoạt động (đặc biệt là trong localchế độ) nhưng không đáng tin cậy (xem SPARK-16207 , tín dụng cho Tzach Zohar để liên kết vấn đề JIRA có liên quanSPARK-30335 ).

Lưu ý tương tự áp dụng cho

df.orderBy(...).dropDuplicates(...)

trong đó sử dụng kế hoạch thực hiện tương đương.


3
Có vẻ như từ spark 1.6, nó là row_number () thay vì rowNumber
Adam Szałucha

Về việc Đừng sử dụng df.orderBy (...). GropBy (...). Trong hoàn cảnh nào chúng ta có thể dựa vào orderBy (...)? hoặc nếu chúng ta không thể chắc chắn nếu orderBy () sẽ cho kết quả chính xác, chúng ta có những lựa chọn thay thế nào?
Ignacio Alorre

Tôi có thể đang xem xét một cái gì đó, nhưng nói chung nên tránh dùng GroupByKey , thay vào đó nên sử dụng lessByKey. Ngoài ra, bạn sẽ tiết kiệm được một dòng.
Thomas

3
@Thomas tránh nhómBy / groupByKey chỉ khi giao dịch với RDD, bạn sẽ nhận thấy rằng api Dataset thậm chí không có chức năng lessByKey.
soote


16

Đối với Spark 2.0.2 với việc nhóm theo nhiều cột:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

8

Đây là một chính xác cùng của zero323 's câu trả lời nhưng trong SQL truy vấn cách.

Giả sử rằng khung dữ liệu được tạo và đăng ký là

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Chức năng cửa sổ:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Tập hợp SQL đơn giản theo sau là tham gia:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Sử dụng thứ tự trên các cấu trúc:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

DataSets cáchkhông làm giống như trong câu trả lời ban đầu


2

Mẫu được nhóm theo các khóa => làm một cái gì đó cho mỗi nhóm, ví dụ: giảm => quay lại khung dữ liệu

Tôi nghĩ rằng trừu tượng hóa Dataframe là một chút rườm rà trong trường hợp này vì vậy tôi đã sử dụng chức năng RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)

1

Giải pháp bên dưới chỉ thực hiện một nhómBy và trích xuất các hàng của khung dữ liệu của bạn có chứa maxValue trong một lần chụp. Không cần tham gia thêm, hoặc Windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}

Nhưng nó xáo trộn mọi thứ đầu tiên. Nó hầu như không phải là một cải tiến (có thể không tệ hơn các chức năng của cửa sổ, tùy thuộc vào dữ liệu).
Alper t. Turker

bạn có một nhóm đầu tiên, điều đó sẽ kích hoạt xáo trộn. Nó không tệ hơn chức năng cửa sổ bởi vì trong chức năng cửa sổ, nó sẽ đánh giá cửa sổ cho từng hàng đơn trong khung dữ liệu.
elghoto

1

Một cách hay để làm điều này với api khung dữ liệu là sử dụng logic argmax như vậy

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

0

Ở đây bạn có thể làm như thế này -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show

-2

Chúng ta có thể sử dụng hàm cửa sổ xếp hạng () (trong đó bạn sẽ chọn thứ hạng = 1) thứ hạng chỉ cần thêm một số cho mỗi hàng của một nhóm (trong trường hợp này sẽ là giờ)

đây là một ví dụ (từ https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-fifts.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.