Expanding a column that contains an array of structs into new columns

Asked: 2018-01-29 22:43:22

Tags: apache-spark pyspark

I have a DataFrame with a single column that is an array of structs:

df.printSchema()
root
 |-- dataCells: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- value: string (nullable = true)

Some sample data looks like this:

df.first()
Row(dataCells=[Row(label="firstName", value="John"), Row(label="lastName", value="Doe"), Row(label="Date", value="1/29/2018")])

I'm trying to figure out how to reshape this DataFrame by turning each struct into its own named column. I'd like to end up with a DataFrame like this:

------------------------------------
| firstName | lastName | Date      |
------------------------------------
| John      | Doe      | 1/29/2018 |
| ....      | ...      | ...       |

I've tried everything I can think of, but I haven't been able to figure this out.

2 Answers:

Answer 0 (score: 3):

First, explode the array and select *:

from pyspark.sql import Row
from pyspark.sql.functions import explode, first, col, monotonically_increasing_id

df = spark.createDataFrame([
  Row(dataCells=[Row(label="firstName", value="John"), Row(label="lastName", value="Doe"), Row(label="Date", value="1/29/2018")])
])

long = (df
   .withColumn("id", monotonically_increasing_id())  # synthetic row id so exploded rows can be regrouped
   .select("id", explode("dataCells").alias("col"))  # one output row per struct in the array
   .select("id", "col.*"))                           # flatten each struct into label/value columns

Then pivot the labels into columns:

long.groupBy("id").pivot("label").agg(first("value")).show()
# +-----------+---------+---------+--------+                                      
# |         id|     Date|firstName|lastName|
# +-----------+---------+---------+--------+
# |25769803776|1/29/2018|     John|     Doe|
# +-----------+---------+---------+--------+
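
If the generated id column is only a helper for regrouping the exploded rows, it can be dropped once the pivot is done (a small follow-up, assuming the id is not needed downstream):

long.groupBy("id").pivot("label").agg(first("value")).drop("id").show()
# +---------+---------+--------+
# |     Date|firstName|lastName|
# +---------+---------+--------+
# |1/29/2018|     John|     Doe|
# +---------+---------+--------+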

You could also do:

from pyspark.sql.functions import udf

# convert the array of (label, value) structs into a map<string, string>
@udf("map<string,string>")
def as_map(x):
    return dict(x)

# pull each key out of the map as its own column
cols = [col("dataCells")[c].alias(c) for c in ["Date", "firstName", "lastName"]]
df.select(as_map("dataCells").alias("dataCells")).select(cols).show()

# +---------+---------+--------+
# |     Date|firstName|lastName|
# +---------+---------+--------+
# |1/29/2018|     John|     Doe|
# +---------+---------+--------+
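
On Spark 2.4+, the same array-to-map conversion can be done without a Python UDF by using the built-in map_from_entries function (a sketch reusing cols from above, under the assumption that each struct's first field is treated as the map key and the second as the value):

from pyspark.sql.functions import map_from_entries

df.select(map_from_entries("dataCells").alias("dataCells")).select(cols).show()

# +---------+---------+--------+
# |     Date|firstName|lastName|
# +---------+---------+--------+
# |1/29/2018|     John|     Doe|
# +---------+---------+--------+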

Answer 1 (score: 1):

Another approach I tried, without a UDF:

>>> df.show()
+--------------------+
|           dataCells|
+--------------------+
|[[firstName,John]...|
+--------------------+

>>> from pyspark.sql import functions as F

## size of array with maximum length in column 
>>> arr_len = df.select(F.max(F.size('dataCells')).alias('len')).first().len

## get values from struct 
>>> df1 = df.select([df.dataCells[i].value for i in range(arr_len)])
>>> df1.show()
+------------------+------------------+------------------+
|dataCells[0].value|dataCells[1].value|dataCells[2].value|
+------------------+------------------+------------------+
|              John|               Doe|         1/29/2018|
+------------------+------------------+------------------+

>>> oldcols = df1.columns

## get the labels from struct
>>> cols = df.select([df.dataCells[i].label.alias('col_%s'%i) for i in range(arr_len)]).dropna().first()
>>> cols
Row(dataCells[0].label=u'firstName', dataCells[1].label=u'lastName', dataCells[2].label=u'Date')
>>> newcols = [cols[i] for i in range(arr_len)]
>>> newcols
[u'firstName', u'lastName', u'Date']

## use the labels to rename the columns
>>> from functools import reduce   # built in on Python 2; this import is needed on Python 3
>>> df2 = reduce(lambda data, idx: data.withColumnRenamed(oldcols[idx], newcols[idx]), range(len(oldcols)), df1)
>>> df2.show()
+---------+--------+---------+
|firstName|lastName|     Date|
+---------+--------+---------+
|     John|     Doe|1/29/2018|
+---------+--------+---------+
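
Since newcols lines up with the columns of df1 by position, the rename loop can also be replaced with a single toDF call (an alternative sketch of the same step):

>>> df1.toDF(*newcols).show()
+---------+--------+---------+
|firstName|lastName|     Date|
+---------+--------+---------+
|     John|     Doe|1/29/2018|
+---------+--------+---------+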