火花火车测试分裂

时间:2016-10-12 09:02:38

标签: apache-spark apache-spark-mllib train-test-split

我很好奇,如果在最新的2.0.1版本中有类似于sklearn' s http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html的apache-spark。

到目前为止,我只能找到https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling,它似乎不适合将严重不平衡的数据集拆分为火车/测试样本。

4 个答案:

答案 0 :(得分:4)

虽然这个答案并不是针对Spark的,但是在Apache光束中我这样做是为了将66%的火车拆分并测试33%(只是一个说明性的例子,你可以自定义下面的partition_fn更复杂并接受这样的参数来指定对某些事物的桶数或偏差选择的数量或确保随机化是跨维度的公平等等):

raw_data = p | 'Read Data' >> Read(...)

clean_data = (raw_data
              | "Clean Data" >> beam.ParDo(CleanFieldsFn())


def partition_fn(element):
    return random.randint(0, 2)

random_buckets = (clean_data | beam.Partition(partition_fn, 3))

clean_train_data = ((random_buckets[0], random_buckets[1])
                    | beam.Flatten())

clean_eval_data = random_buckets[2]

答案 1 :(得分:2)

假设我们有这样的数据集:

+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  1|  1.0|
|  2|  0.0|
|  3|  1.0|
|  4|  0.0|
|  5|  1.0|
|  6|  0.0|
|  7|  1.0|
|  8|  0.0|
|  9|  1.0|
+---+-----+

此数据集完美平衡,但此方法也适用于不平衡数据。

现在,让我们使用其他信息来扩充此DataFrame,这些信息对于决定哪些行应该用于训练集非常有用。步骤如下:

  • 在某些ratio的情况下,确定每个标签的多少个示例应该是列车集的一部分。
  • 随机播放DataFrame的行。
  • 使用窗口功能按label对DataFrame进行分区和排序,然后使用row_number()对每个标签的观察进行排名。

我们最终得到以下数据框:

+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
|  6|  0.0|         1|
|  2|  0.0|         2|
|  0|  0.0|         3|
|  4|  0.0|         4|
|  8|  0.0|         5|
|  9|  1.0|         1|
|  5|  1.0|         2|
|  3|  1.0|         3|
|  1|  1.0|         4|
|  7|  1.0|         5|
+---+-----+----------+

注意:行被洗牌(请参阅:id列中的随机顺序),按标签分区(请参阅:label列)并排名。

让我们假设我们想要分成80%。在这种情况下,我们希望四个1.0标签和四个0.0标签转到训练数据集,一个1.0标签和一个0.0标签转到测试数据集。我们在row_number列中有此信息,因此现在我们可以在用户定义的函数中使用它(如果row_number小于或等于4,示例将转到训练集)。

应用UDF后,生成的数据框如下:

+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
|  6|  0.0|         1|      true|
|  2|  0.0|         2|      true|
|  0|  0.0|         3|      true|
|  4|  0.0|         4|      true|
|  8|  0.0|         5|     false|
|  9|  1.0|         1|      true|
|  5|  1.0|         2|      true|
|  3|  1.0|         3|      true|
|  1|  1.0|         4|      true|
|  7|  1.0|         5|     false|
+---+-----+----------+----------+

现在,要获得列车/测试数据,必须要做的事情:

val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)

对于某些非常大的数据集,这些排序和分区步骤可能会过高,所以我建议首先尽可能地过滤数据集。实际计划如下:

== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
   +- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(label#5, 200)
         +- *(1) Project [id#4, label#5]
            +- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
               +- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
                  +- LocalTableScan [id#4, label#5, _nondeterministic#9

这里有完整的工作示例(使用Spark 2.3.0和Scala 2.11.12测试):

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}

class StratifiedTrainTestSplitter {

  def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
    df.groupBy(label).count().createOrReplaceTempView("labelCounts")
    val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
    import ss.implicits._
    ss.sql(query)
      .select("ratioLabel", "trainExamples")
      .map((r: Row) => r.getDouble(0) -> r.getLong(1))
      .collect()
      .toMap
  }

  def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
    val w = Window.partitionBy(col(label)).orderBy(col(label))

    val rowNumPartitioner = row_number().over(w)

    val dfRowNum = df.sort(rand).select(col("*"), rowNumPartitioner as "row_number")

    dfRowNum.show()

    val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

    dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
  }


}

object StratifiedTrainTestSplitter {

  def getDf(ss: SparkSession): DataFrame = {
    val data = Seq(
      (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
    )
    ss.createDataFrame(data).toDF("id", "label")
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[1]"))
      .getOrCreate()

    val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)

    df.cache()

    df.where(col("isTrainSet") === true).show()
    df.where(col("isTrainSet") === false).show()
  }
}

注意:在这种情况下,标签为Double。如果您的标签为String,那么您必须在这里和那里切换类型。

答案 2 :(得分:2)

Spark支持https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html

中概述的分层样本
df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0)

答案 3 :(得分:2)

也许当OP发布此问题时此方法不可用,但我将其留在此处以供将来参考:

# splitting dataset into train and test set
(train test) = df.randomSplit([0.7, 0.3], seed=42)
相关问题