Lerysson

Reputation: 31

How to preprocess images using PySpark?

I am working on a project where I need to set up a proof-of-concept big data architecture (AWS S3 + SageMaker) to 1) preprocess images using PySpark, 2) perform a PCA and 3) train some machine or deep learning models. My issue is understanding how to manipulate image data with PySpark, and I could not find satisfactory answers online.

So I think any answer/hint could interest a broad audience of beginners like me. A similar thread remains unanswered here.

Below is what I have tried so far (using Python 3.8 in a Jupyter Notebook):

from pyspark import SparkConf
from pyspark.sql import SparkSession
import sagemaker_pyspark
import botocore.session

session = botocore.session.get_session()
credentials = session.get_credentials()

conf = (SparkConf()
        .set("spark.driver.extraClassPath", ":".join(sagemaker_pyspark.classpath_jars())))

spark = (
    SparkSession
    .builder
    .config(conf=conf)
    .config('fs.s3a.access.key', credentials.access_key)
    .config('fs.s3a.secret.key', credentials.secret_key)
    .appName("test")
    .getOrCreate()
)

s3_url = "s3a://<MY_BUCKET>/dataset/*"
df = spark.read.format("image").load(s3_url)
print((df.count(), len(df.columns)))
print(df.printSchema())
df.select('image.nChannels', "image.width", "image.height", "image.data").show(truncate=True)

Output:

(60, 1)
root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)

None
+---------+-----+------+--------------------+
|nChannels|width|height|                data|
+---------+-----+------+--------------------+
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
|        3|  100|   100|[FF FF FF FF FF F...|
+---------+-----+------+--------------------+
only showing top 20 rows

So I got the images as raw bytes in the data column.
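If I understand the Spark image data source correctly, the data field holds the already decoded pixels (OpenCV-compatible, row-wise BGR order) rather than the original encoded file bytes, so a single row can apparently be turned into a NumPy array just by reshaping. A quick sketch based on the schema above:

import numpy as np

# assuming row.data contains decoded BGR pixels in row-major order
row = df.select("image.*").first()
arr = np.frombuffer(row.data, dtype=np.uint8).reshape(row.height, row.width, row.nChannels)
print(arr.shape)  # (100, 100, 3) for the images shown above

With that in mind, here is the pandas UDF I tried for the actual preprocessing: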

import numpy as np
import io
from PIL import Image
from pyspark.sql.functions import pandas_udf, PandasUDFType
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.resnet50 import preprocess_input


@pandas_udf('array<float>', PandasUDFType.SCALAR)
def preprocess(content):
    """
    Preprocesses raw image bytes for prediction.
    """
    img = Image.open(io.BytesIO(content))
    arr = img_to_array(img)
    return arr.flatten()


df_transformed = df.select(preprocess("image.data"))
type(df_transformed)
df_transformed.printSchema()
df_transformed.show()

Output:

root
 |-- preprocess(image.data): array (nullable = true)
 |    |-- element: float (containsNull = true)

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-24-e3999c55086a> in <module>
     20 type(df_transformed)
     21 df_transformed.printSchema()
---> 22 df_transformed.show()

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    376         """
    377         if isinstance(truncate, bool) and truncate:
--> 378             print(self._jdf.showString(n, 20, vertical))
    379         else:
    380             print(self._jdf.showString(n, int(truncate), vertical))

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o433.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 12.0 failed 1 times, most recent failure: Lost task 0.0 in stage 12.0 (TID 16, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
    return lambda *a: (verify_result_length(*a), arrow_return_type)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
    result = f(*a)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:121)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
    at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
    at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 283, in dump_stream
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 96, in <lambda>
    return lambda *a: (verify_result_length(*a), arrow_return_type)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in verify_result_length
    result = f(*a)
  File "/home/ec2-user/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-24-e3999c55086a>", line 14, in preprocess
TypeError: a bytes-like object is required, not 'Series'

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:121)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
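
From the traceback I understand that a scalar pandas UDF receives a whole pandas Series (one element per row) and must return a Series of the same length, so the bytes have to be handled element-wise. A rough, untested sketch of what I think the UDF should look like (it also assumes the data field already holds decoded pixels, so PIL is not needed):

import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType


@pandas_udf('array<float>', PandasUDFType.SCALAR)
def preprocess(data):
    # 'data' is a pandas Series whose elements are the raw pixel buffers of single images
    return data.apply(lambda b: np.frombuffer(b, dtype=np.uint8).astype(np.float32).tolist())


df_transformed = df.select(preprocess("image.data"))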

However, I managed to convert the images to NumPy arrays with the following workaround:

from pyspark.ml.image import ImageSchema
# https://stackoverflow.com/questions/67705881/unable-to-read-images-simultaneously-in-parallels-using-pyspark
df = df.select('image.*')

# Pre-caching the required schema. If you remove this line an error will be raised.
ImageSchema.imageFields

# Transforming images to np.array
arrays = df.rdd.map(ImageSchema.toNDArray).collect()

img = np.array(arrays)
print(img.shape)

Output: (60, 100, 100, 3)

On top of that, I need to perform a PCA to reduce the image dimensions.

Upvotes: 3

Views: 4690

Answers (1)

AdibP

Reputation: 2939

Try using ImageSchema and DenseVector inside a UDF and apply the function to the unpacked image column (struct format). The result is a dense-vector representation of each image.

df = spark.read.format("image").load(url)
df.show()

# +--------------------+
# |               image|
# +--------------------+
# |[file:///content/...|
# |[file:///content/...|
# +--------------------+

import pyspark.sql.functions as F
from pyspark.ml.image import ImageSchema
from pyspark.ml.linalg import DenseVector, VectorUDT

ImageSchema.imageFields

img2vec = F.udf(lambda x: DenseVector(ImageSchema.toNDArray(x).flatten()), VectorUDT())

df = df.withColumn('vecs', img2vec("image"))
df.show()

# +--------------------+--------------------+
# |               image|                vecs|
# +--------------------+--------------------+
# |[file:///content/...|[255.0,255.0,255....|
# |[file:///content/...|[248.0,248.0,248....|
# +--------------------+--------------------+
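
With every image represented as a DenseVector in the vecs column, the PCA step mentioned in the question should be possible with pyspark.ml.feature.PCA, roughly along these lines (k = 50 is only an illustrative value):

from pyspark.ml.feature import PCA

pca = PCA(k=50, inputCol="vecs", outputCol="pca_features")
model = pca.fit(df)
df_reduced = model.transform(df)
df_reduced.select("pca_features").show(truncate=True)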

Upvotes: 2
