Reputation: 18107
_pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
A super simple example app that tries to run some calculations in parallel. It works sometimes, but most of the time it crashes with the exception above.
I don't think I have nested RDDs, but the part about not being able to use the SparkContext on workers is worrisome, since I think I need it to achieve some level of parallelism. If I can't use the SparkContext in the worker threads, how do I get the computational results back?
At this point I still expect the jobs to run serially, and I was going to enable the parallel run after this. But I can't even get the serial multi-threaded version to run...
from pyspark import SparkContext
import threading

THREADED = True  # Set this to False and it always works, but it is sequential

content_file = "file:///usr/local/Cellar/apache-spark/3.0.0/README.md"
sc = SparkContext("local", "first app")
content = sc.textFile(content_file).cache()  # For the non-threaded version


class Worker(threading.Thread):
    def __init__(self, letter, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.letter = letter

    def run(self):
        print(f"Starting: {self.letter}")
        nums[self.letter] = content.filter(lambda s: self.letter in s).count()  # SPOILER: self.letter turns out to be the problem
        print(f"{self.letter}: {nums[self.letter]}")


nums = {}
if THREADED:
    threads = []
    for char in range(ord('a'), ord('z') + 1):
        letter = chr(char)
        threads.append(Worker(letter, name=letter))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
else:
    for char in range(ord('a'), ord('z') + 1):
        letter = chr(char)
        nums[letter] = content.filter(lambda s: letter in s).count()
        print(f"{letter}: {nums[letter]}")
print(nums)
Even when I change the code to use one thread at a time:
threads = []
for char in range(ord('a'), ord('z') + 1):
    letter = chr(char)
    thread = Worker(letter, name=letter)
    threads.append(thread)
    thread.start()
    thread.join()
It raises the same exception, I guess because it is trying to get the results back in a worker thread and not the main thread (where the SparkContext is declared).
I need to be able to wait on several values simultaneously if Spark is going to provide any benefit here.
The real problem I'm trying to solve looks like this:
__________RESULT__________
   ^         ^         ^
   A         B         C
 a1 ^ a2   b1 ^ b2   c1 ^ c2 ...
To get my result, I want to calculate A, B, and C in parallel, and each of those pieces will have to calculate a1, a2, a3, ... in parallel. I'm breaking it into threads so I can request multiple values simultaneously, so that Spark can run the computation in parallel.
I created the sample above simply because I want to get the threading right; I'm not trying to figure out how to count the number of lines containing a given character. It just seemed like a super simple way to vet the threading aspect.
This little change fixes things right up. Referencing self.letter inside the lambda was what blew up; copying it into a local variable before the filter call removed the crash:
def run(self):
    print(f"Starting: {self.letter}")
    letter = self.letter  # a plain string, safe to ship to the workers
    nums[self.letter] = content.filter(lambda s: letter in s).count()
    print(f"{self.letter}: {nums[self.letter]}")
Upvotes: 0
Views: 4189
Reputation: 2386
The exception says:
It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation
In your case, the reference to the SparkContext is held by the following line:
nums[self.letter] = self.content.filter(lambda s: self.letter in s).count()
In this line, you define a filter (which counts as a transformation) using the following lambda expression:
lambda s: self.letter in s
The problem with this expression is that you reference the member variable letter of the object reference self. To make this reference available during the execution of your batch, Spark needs to serialize the object self. But this object holds not only the member letter, but also content, which is a Spark RDD (and every Spark RDD holds a reference to the SparkContext it was created from).
To make the lambda serializable, you have to make sure it doesn't reference anything that is not serializable. The easiest way to achieve that, given your example, is to define a local variable based on the member letter:
def run(self):
    print(f"Starting: {self.letter}")
    letter = self.letter
    nums[self.letter] = self.content.filter(lambda s: letter in s).count()
    print(f"{self.letter}: {nums[self.letter]}")
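To see what each lambda actually drags along, you can inspect its closure cells, which is roughly what a closure serializer has to ship. This is a minimal, stdlib-only sketch: the Worker below is a simplified stand-in for the question's class, and object() is a placeholder for the RDD, not real Spark code.

```python
import threading


class Worker(threading.Thread):
    def __init__(self, letter, content):
        super().__init__()
        self.letter = letter
        self.content = content  # placeholder for the RDD (and thus the SparkContext)

    def bad_predicate(self):
        # The lambda closes over `self`, so serializing it means serializing
        # the whole Worker, including self.content.
        return lambda s: self.letter in s

    def good_predicate(self):
        # Copy the plain string into a local first; the lambda now closes
        # over that string only.
        letter = self.letter
        return lambda s: letter in s


def captured_values(fn):
    """Return the objects held in a function's closure cells."""
    return [cell.cell_contents for cell in (fn.__closure__ or ())]


worker = Worker("a", content=object())
print(captured_values(worker.bad_predicate()))   # a list holding the Worker itself
print(captured_values(worker.good_predicate()))  # ['a']
```

The bad version captures the whole object graph reachable from self; the good version captures a single string.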
To understand why we can't do this, we have to understand what Spark does with every transformation in the background.
Whenever you have some piece of code like this:
sc = SparkContext(<connection information>)
You're creating a "Connection" to the Spark-Master. It may be a simple in-process local Spark-Master or a Spark-Master running on a whole different server.
Given the SparkContext object, we can define where our pipeline should get its data from. For this example, let's say we want to read our data from a text file (just like in your question):
rdd = sc.textFile("file:///usr/local/Cellar/apache-spark/3.0.0/README.md")
As I mentioned before, the SparkContext is more or less a "connection" to the Spark-Master. The URL we specify as the location of our text file must be accessible from the Spark cluster (for a local file:// path, at the same path on every worker node), not just from the system you're executing the Python script on!
Based on the Spark RDD we created, we can now define how the data should be processed. Let's say we want to count only lines that contain a given string, "Hello World":
linesThatContainHelloWorld = rdd.filter(lambda line: "Hello World" in line).count()
What Spark does once we call a terminal function (Spark calls these actions: a computation that yields a result, like count() in this case) is serialize the function we passed to filter and transfer the serialized data to the Spark-Workers (which may run on a totally different server). These Spark-Workers then deserialize that function to be able to execute it.
That means that this piece of code, lambda line: "Hello World" in line, will not actually be executed inside the Python process you're currently in, but on the Spark-Workers.
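The driver-to-worker round trip can be imitated with the standard library's pickle. This is only a sketch: PySpark uses its bundled cloudpickle, which can serialize lambdas by value, whereas the stdlib pickle used here stores a module-level function by reference, so a named function stands in for the lambda and the round trip only works because "driver" and "worker" share one process. run_on_worker is a hypothetical helper, not a Spark API.

```python
import pickle


def contains_hello_world(line):
    # Stand-in for: lambda line: "Hello World" in line
    return "Hello World" in line


def run_on_worker(payload, partition):
    # "Worker" side: rebuild the function from bytes, then apply it to the
    # lines of one partition and count the matches.
    fn = pickle.loads(payload)
    return sum(1 for line in partition if fn(line))


# "Driver" side: the function is turned into bytes before it is shipped.
payload = pickle.dumps(contains_hello_world)

partitions = [["Hello World", "foo"], ["bar", "say Hello World!"]]
total = sum(run_on_worker(payload, part) for part in partitions)
print(total)  # 2
```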
Things start to get trickier (for Spark) whenever we reference a variable from the upper scope inside one of our transformations:
stringThatALineShouldContain = "Hello World"
linesThatContainHelloWorld = rdd.filter(lambda line: stringThatALineShouldContain in line).count()
Now, Spark has to serialize not only the given function, but also the variable stringThatALineShouldContain that it references from the upper scope. In this simple example, this is no problem, since stringThatALineShouldContain is a plain string and therefore serializable.
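One way to see that the captured value travels with the function: bind it explicitly with functools.partial (a stand-in for the closure cloudpickle would serialize) and look at the resulting bytes. A stdlib-only sketch, not PySpark's actual serializer; line_contains is a hypothetical helper.

```python
import functools
import pickle


def line_contains(needle, line):
    return needle in line


stringThatALineShouldContain = "Hello World"

# The partial object bundles the function reference with the captured value,
# much like a closure bundles a lambda with the variables it references.
task = functools.partial(line_contains, stringThatALineShouldContain)
payload = pickle.dumps(task)

# The captured string is part of the serialized bytes...
print(b"Hello World" in payload)  # True

# ...and comes back intact after deserialization.
restored = pickle.loads(payload)
print(restored("say Hello World!"), restored("goodbye"))  # True False
```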
But whenever we try to access something that is not serializable, or that simply holds a reference to something that is not serializable, Spark will complain.
For example:
stringThatALineShouldContain = "Hello World"
badExample = (sc, stringThatALineShouldContain) # tuple holding a reference to the SparkContext
linesThatContainHelloWorld = rdd.filter(lambda line: badExample[1] in line).count()
Since the function now references badExample, Spark tries to serialize this variable and complains that it holds a reference to the SparkContext.
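The failure can be reproduced without Spark. PySpark's SparkContext raises this exception from its __getnewargs__ hook, which pickle (protocol 2 and up) consults while serializing an instance. FakeSparkContext below is a hypothetical stand-in that mimics that behaviour with a shortened message; try_pickle is likewise a hypothetical helper.

```python
import pickle


class FakeSparkContext:
    """Hypothetical stand-in: refuses to be pickled, like SparkContext."""

    def __getnewargs__(self):
        # pickle calls this while serializing the object, so it aborts here.
        raise Exception("SparkContext can only be used on the driver")


sc = FakeSparkContext()
stringThatALineShouldContain = "Hello World"
badExample = (sc, stringThatALineShouldContain)  # tuple holding the "context"


def try_pickle(obj):
    try:
        pickle.dumps(obj)
        return "ok"
    except Exception as exc:
        return f"failed: {exc}"


print(try_pickle(stringThatALineShouldContain))  # ok
print(try_pickle(badExample))     # failed: SparkContext can only be used on the driver
print(try_pickle(badExample[1]))  # ok: extract the plain string before the lambda uses it
```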
This applies not only to the SparkContext, but to everything that is not serializable, such as database connection objects, file handles, and many more.
If, for any reason, you have to do something like this, you should only reference an object that contains the information needed to create that unserializable object. So instead of this:
dbConnection = MySQLConnection("mysql.example.com")  # Not sure if this class exists, only for the example
rdd.filter(lambda line: dbConnection.insertIfNotExists("INSERT INTO table (col) VALUES (?)", line))
do something like this:
# note that this is still "bad code", since the connection is never closed. But I hope you get the idea
class LazyMySQLConnection:
    connectionString = None
    actualConnection = None

    def __init__(self, connectionString):
        self.connectionString = connectionString

    def __getstate__(self):
        # tell pickle (more precisely cloudpickle, which PySpark uses for transformations)
        # that the actualConnection member is not part of the state;
        # pop() instead of del, since the key only exists once a connection was created
        state = dict(self.__dict__)
        state.pop("actualConnection", None)
        return state

    def getOrCreateConnection(self):
        if self.actualConnection is None:
            self.actualConnection = MySQLConnection(self.connectionString)
        return self.actualConnection
lazyDbConnection = LazyMySQLConnection("mysql.example.com")
rdd.filter(lambda line: lazyDbConnection.getOrCreateConnection().insertIfNotExists("INSERT INTO table (col) VALUES (?)", line))
# remember, the lambda we supplied for the filter will be executed on the Spark-Workers, so the connection will be established from each Spark-Worker!
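To check that the lazy-wrapper pattern survives a pickle round trip, here is a stdlib-only variant in which a threading.Lock plays the role of the unpicklable database connection (locks cannot be pickled). LazyResource and make_connection are hypothetical names, not part of any MySQL client library.

```python
import pickle
import threading


def make_connection(connection_string):
    # Hypothetical stand-in for MySQLConnection(connection_string):
    # like a real connection, a lock cannot be pickled.
    return threading.Lock()


class LazyResource:
    def __init__(self, connection_string):
        self.connection_string = connection_string
        self.actual_connection = None

    def __getstate__(self):
        # Ship only the recipe (the connection string), never the live object.
        state = dict(self.__dict__)
        state["actual_connection"] = None
        return state

    def get_or_create_connection(self):
        if self.actual_connection is None:
            self.actual_connection = make_connection(self.connection_string)
        return self.actual_connection


lazy = LazyResource("mysql.example.com")
driver_side = lazy.get_or_create_connection()  # live connection on the "driver"

# Pickling works even though a live, unpicklable connection exists...
copy_on_worker = pickle.loads(pickle.dumps(lazy))
# ...and the "worker" opens its own connection on first use.
worker_side = copy_on_worker.get_or_create_connection()
print(worker_side is driver_side)  # False
```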
Upvotes: 2
Reputation: 2386
You're trying to use (Py)Spark in a way it is not intended to be used. You're mixing plain-Python data processing with Spark processing where you could rely entirely on Spark.
The idea behind Spark (and other data processing frameworks) is that you define how your data should be processed, and all the multithreading and distribution is just an independent "configuration".
Also, I don't really see what you would gain by using multiple threads. Every thread would read the complete file, filter it by its assigned letter, and count the matching lines.
This would (if it worked) yield a correct result, sure, but it is inefficient, since many threads would be fighting over read operations on that file (remember, every thread has to read the COMPLETE file first to be able to filter based on its assigned letter).
Work with Spark, not against it, to get the most out of it.
from pyspark import SparkContext

content_file = "file:///usr/local/Cellar/apache-spark/3.0.0/README.md"
sc = SparkContext("local", "first app")
rdd = sc.textFile(content_file)  # read from this file
# forward every distinct letter of each line to the next operator
# (set() makes each line count at most once per letter, as in the question)
rdd = rdd.flatMap(lambda line: set(line))
# initialize the letter range "outside" of Spark so we reduce the runtime overhead
relevantLetterRange = {chr(char) for char in range(ord('a'), ord('z') + 1)}
rdd = rdd.filter(lambda letter: letter in relevantLetterRange)
rdd = rdd.keyBy(lambda letter: letter)  # key by the letter itself
countsByKey = rdd.countByKey()  # count by key: letter -> number of lines containing it
You can of course simply write this as one chain:
from pyspark import SparkContext

content_file = "file:///usr/local/Cellar/apache-spark/3.0.0/README.md"
sc = SparkContext("local", "first app")
relevantLetterRange = {chr(char) for char in range(ord('a'), ord('z') + 1)}
countsByKey = (
    sc.textFile(content_file)
    .flatMap(lambda line: set(line))
    .filter(lambda letter: letter in relevantLetterRange)
    .keyBy(lambda letter: letter)
    .countByKey()
)
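If the goal is the question's result (the number of lines containing each letter), each letter must be counted at most once per line. A local, Spark-free sketch on an in-memory list of lines checks that a single pass agrees with the question's one-filter-per-letter approach; collections.Counter plays the role of countByKey(), and the two function names are hypothetical.

```python
from collections import Counter

LETTERS = {chr(c) for c in range(ord("a"), ord("z") + 1)}


def per_letter_passes(lines):
    # The question's approach: one full pass over the data per letter.
    return {letter: sum(1 for line in lines if letter in line) for letter in sorted(LETTERS)}


def single_pass(lines):
    # One pass: emit each distinct letter of a line, then count by key.
    counts = Counter(
        letter for line in lines for letter in set(line) if letter in LETTERS
    )
    # Fill in zeros so the result has the same keys as per_letter_passes.
    return {letter: counts.get(letter, 0) for letter in sorted(LETTERS)}


sample = ["apache spark", "hello world", "a b c"]
assert per_letter_passes(sample) == single_pass(sample)
print(single_pass(sample)["a"])  # 2
```

The single pass reads each line once regardless of how many letters are tracked, which is the efficiency argument made above.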
Upvotes: 1