Spark从一行中提取值

时间:2015-10-08 06:08:24

标签: scala apache-spark apache-spark-sql

我有以下数据框

val transactions_with_counts = sqlContext.sql(
  """SELECT user_id AS user_id, category_id AS category_id,
  COUNT(category_id) FROM transactions GROUP BY user_id, category_id""")

我试图将行转换为Rating对象,但由于x(0)返回一个数组,因此失败

val ratings = transactions_with_counts
  .map(x => Rating(x(0).toInt, x(1).toInt, x(2).toInt))
  

错误:值toInt不是Any

的成员

3 个答案:

答案 0 :(得分:57)

让我们从一些虚拟数据开始:

val transactions = Seq((1, 2), (1, 4), (2, 3)).toDF("user_id", "category_id")

val transactions_with_counts = transactions
  .groupBy($"user_id", $"category_id")
  .count

transactions_with_counts.printSchema

// root
// |-- user_id: integer (nullable = false)
// |-- category_id: integer (nullable = false)
// |-- count: long (nullable = false)

有几种方法可以访问Row值并保留预期类型:

  1. 模式匹配

    import org.apache.spark.sql.Row
    
    transactions_with_counts.map{
      case Row(user_id: Int, category_id: Int, rating: Long) =>
        Rating(user_id, category_id, rating)
    } 
    
  2. get*getInt等类型getLong方法:

    transactions_with_counts.map(
      r => Rating(r.getInt(0), r.getInt(1), r.getLong(2))
    )
    
  3. getAs方法,可以同时使用名称和索引:

    transactions_with_counts.map(r => Rating(
      r.getAs[Int]("user_id"), r.getAs[Int]("category_id"), r.getAs[Long](2)
    ))
    

    它可用于正确提取用户定义的类型,包括mllib.linalg.Vector。显然,按名称访问需要一个架构。

  4. 转换为静态类型Dataset(Spark 1.6+ / 2.0 +):

    transactions_with_counts.as[(Int, Int, Long)]
    

答案 1 :(得分:7)

使用数据集,您可以按如下方式定义评级:

case class Rating(user_id: Int, category_id:Int, count:Long)

这里的评级类有一个列名'count'而不是'rating',建议为零。因此,评级变量分配如下:

val transactions_with_counts = transactions.groupBy($"user_id", $"category_id").count

val rating = transactions_with_counts.as[Rating]

这样你就不会遇到Spark中的运行时错误,因为你的错误 评级类列名称与Spark在运行时生成的“计数”列名称相同。

答案 2 :(得分:0)

要访问 Dataframe 行的值,您需要将{strong> Dataframe 的rdd.collect与for循环一起使用。

考虑您的数据框如下所示。

val df = Seq(
      (1,"James"),    
      (2,"Albert"),
      (3,"Pete")).toDF("user_id","name")

在您的数据框顶部使用rdd.collectrow变量将包含rdd行类型的 Dataframe 的每一行。要从一行中获取每个元素,请使用row.mkString(","),它将以逗号分隔的值包含每一行的值。使用split函数(内置函数),您可以访问带有索引的rdd行的每个列值。

for (row <- df.rdd.collect)
{   
    var user_id = row.mkString(",").split(",")(0)
    var category_id = row.mkString(",").split(",")(1)       
}

dataframe.foreach循环相比,上面的代码看起来更大一些,但是通过使用上面的代码,您将对逻辑有更多的控制。

相关问题