web-dev-qa-db-ja.com

Sparkトレーニングテスト分割

最新の2.0.1のApache-sparkで、sklearnの http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html に似たものがあるかどうか知りたいです。リリース。

これまでのところ、私は https://spark.Apache.org/docs/latest/mllib-statistics.html#stratified-sampling しか見つけることができませんでした。 train/testサンプルに。

15
Georg Heiler

Sparkは https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html で概説されているように層別化されたサンプルをサポートします

df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0)
3
Georg Heiler

この答えはSparkに固有のものではありませんが、Apacheビームでこれを行って、列車を66%分割して33%をテストします(単なる例ですが、以下のpartition_fnをより洗練されたものにカスタマイズし、バケットまたは何かにバイアスをかけるか、ランダム化がディメンション全体で公平であることを保証します):

raw_data = p | 'Read Data' >> Read(...)

clean_data = (raw_data
              | "Clean Data" >> beam.ParDo(CleanFieldsFn())


def partition_fn(element):
    return random.randint(0, 2)

random_buckets = (clean_data | beam.Partition(partition_fn, 3))

clean_train_data = ((random_buckets[0], random_buckets[1])
                    | beam.Flatten())

clean_eval_data = random_buckets[2]
4
Reinaldo Aguiar

次のようなデータセットがあるとします。

_+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  1|  1.0|
|  2|  0.0|
|  3|  1.0|
|  4|  0.0|
|  5|  1.0|
|  6|  0.0|
|  7|  1.0|
|  8|  0.0|
|  9|  1.0|
+---+-----+
_

このデータセットは完全にバランスが取れていますが、このアプローチはバランスが取れていないデータでも機能します。

次に、このDataFrameを、どの行をトレーニングセットに入れるかを決定するのに役立つ追加情報で拡張します。手順は次のとおりです。

  • ratioを指定して、すべてのラベルの例がトレインセットの一部となる数を決定します。
  • DataFrameの行をシャッフルします。
  • ウィンドウ関数を使用してlabelでデータフレームを分割および順序付けし、次にrow_number()を使用して各ラベルの観測をランク付けします。

最終的に、次のデータフレームになります。

_+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
|  6|  0.0|         1|
|  2|  0.0|         2|
|  0|  0.0|         3|
|  4|  0.0|         4|
|  8|  0.0|         5|
|  9|  1.0|         1|
|  5|  1.0|         2|
|  3|  1.0|         3|
|  1|  1.0|         4|
|  7|  1.0|         5|
+---+-----+----------+
_

注:行はシャッフルされ(参照:id列のランダムな順序)、ラベルで分割され(参照:label列)、ランク付けされます。

80%に分割したいとします。この場合、4つの_1.0_ラベルと4つの_0.0_ラベルをトレーニングデータセットに移動し、1つの_1.0_ラベルと1つの_0.0_ラベルをテストデータセットに移動します。この情報は_row_number_列にあるので、ユーザー定義関数で簡単に使用できます(_row_number_が4以下の場合、例はトレーニングセットになります)。

UDFを適用すると、結果のデータフレームは次のようになります。

_+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
|  6|  0.0|         1|      true|
|  2|  0.0|         2|      true|
|  0|  0.0|         3|      true|
|  4|  0.0|         4|      true|
|  8|  0.0|         5|     false|
|  9|  1.0|         1|      true|
|  5|  1.0|         2|      true|
|  3|  1.0|         3|      true|
|  1|  1.0|         4|      true|
|  7|  1.0|         5|     false|
+---+-----+----------+----------+
_

ここで、トレーニング/テストデータを取得するには、次のことを行う必要があります。

_val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)
_

これらの並べ替えと分割の手順は、一部の非常に大きなデータセットでは禁止される場合があるため、最初にデータセットをできるだけフィルタリングすることをお勧めします。物理的な計画は次のとおりです。

_== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
   +- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(label#5, 200)
         +- *(1) Project [id#4, label#5]
            +- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
               +- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
                  +- LocalTableScan [id#4, label#5, _nondeterministic#9
_

以下は完全に機能する例です(Spark 2.3.0およびScala 2.11.12)でテスト):

_import org.Apache.spark.SparkConf
import org.Apache.spark.sql.expressions.Window
import org.Apache.spark.sql.{DataFrame, Row, SparkSession}
import org.Apache.spark.sql.functions.{col, row_number, udf, Rand}

class StratifiedTrainTestSplitter {

  def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
    df.groupBy(label).count().createOrReplaceTempView("labelCounts")
    val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
    import ss.implicits._
    ss.sql(query)
      .select("ratioLabel", "trainExamples")
      .map((r: Row) => r.getDouble(0) -> r.getLong(1))
      .collect()
      .toMap
  }

  def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
    val w = Window.partitionBy(col(label)).orderBy(col(label))

    val rowNumPartitioner = row_number().over(w)

    val dfRowNum = df.sort(Rand).select(col("*"), rowNumPartitioner as "row_number")

    dfRowNum.show()

    val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

    dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
  }


}

object StratifiedTrainTestSplitter {

  def getDf(ss: SparkSession): DataFrame = {
    val data = Seq(
      (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
    )
    ss.createDataFrame(data).toDF("id", "label")
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[1]"))
      .getOrCreate()

    val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)

    df.cache()

    df.where(col("isTrainSet") === true).show()
    df.where(col("isTrainSet") === false).show()
  }
}
_

注:この場合、ラベルはDoublesです。ラベルがStringsの場合は、あちこちでタイプを切り替える必要があります。

4