Reputation: 10153
I need to convert a column of a Spark DataFrame to a list, to use later for matplotlib:
df.toPandas()[col_name].values.tolist()
It looks like there is a high performance overhead: this operation takes around 18 seconds. Is there another way to do this, or to improve the performance?
Upvotes: 5
Views: 28501
Reputation: 31
You can use an iterator, toLocalIterator, to save memory. The iterator will consume only as much memory as the largest partition in the DataFrame. And if you need to use the result only once, the iterator is a perfect fit for this case.
d = [['Bender', 12], ['Flex', 123],['Fry', 1234]]
df = spark.createDataFrame(d, ['name', 'value'])
df.show()
+------+-----+
| name|value|
+------+-----+
|Bender| 12|
| Flex| 123|
| Fry| 1234|
+------+-----+
values = [row.value for row in df.toLocalIterator()]
print(values)
>>> [12, 123, 1234]
Also, the toPandas() method should only be used if the resulting pandas DataFrame is expected to be small, as all of the data is loaded into the driver's memory.
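If you do stick with toPandas(), one thing worth trying (assuming Spark 2.3+ with pyarrow installed) is Arrow-based transfer, which often cuts the conversion time substantially. A minimal sketch, using the original question's df and col_name:

```python
# Hedged sketch: enable Arrow-based columnar transfer before calling toPandas().
# The config key below is the Spark 3.x name; on Spark 2.3/2.4 it was
# "spark.sql.execution.arrow.enabled". Requires the pyarrow package.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Same conversion as in the question, now going through Arrow
values = df.toPandas()[col_name].values.tolist()
```

If Arrow transfer fails for a given schema, Spark silently falls back to the non-Arrow path, so enabling it is generally safe to try.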
Upvotes: 2
Reputation: 573
You can do it this way:
>>> [list(row) for row in df.collect()]
Example:
>>> d = [['Alice', 1], ['Bob', 2]]
>>> df = spark.createDataFrame(d, ['name', 'age'])
>>> df.show()
+-----+---+
| name|age|
+-----+---+
|Alice| 1|
| Bob| 2|
+-----+---+
>>> to_list = [list(row) for row in df.collect()]
>>> print(to_list)
[[u'Alice', 1], [u'Bob', 2]]
Upvotes: 15
Reputation: 330073
If you really need a local list, there is not much you can do here, but one improvement is to collect only a single column, not the whole DataFrame:
df.select(col_name).rdd.flatMap(lambda x: x).collect()
(In Spark 1.x, DataFrame exposed flatMap directly; since Spark 2.0 you have to go through .rdd first.)
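The flatMap step just unwraps each single-column Row into its bare value; the same flattening can be done on the driver with a plain list comprehension over the collected rows. A minimal sketch of that step, using plain tuples as stand-ins for the Row objects returned by collect() (pyspark Rows behave like tuples, so indexing with [0] works the same way; no running Spark session assumed):

```python
# Stand-ins for the result of df.select(col_name).collect()
collected = [("Alice",), ("Bob",), ("Carol",)]

# Flatten the single-column rows into a plain list of values
values = [row[0] for row in collected]
print(values)  # -> ['Alice', 'Bob', 'Carol']
```

With a real DataFrame this would read `[row[0] for row in df.select(col_name).collect()]`, which avoids the RDD round-trip entirely.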
Upvotes: 8