K S

Reputation: 45

Storing pickle object in Google Cloud Storage using Tensorflow.io.gfile

I am trying to store a pickled object in a Google Cloud Storage bucket as part of a machine learning pipeline [tutorial][1] from Google that I am following. I reduced the code to a minimal example that still throws the same error. In my actual code, the object is a class instance.

import tensorflow as tf
import dill as pickle
obj = {'foo': 'bar'}
with tf.io.gfile.GFile(filename, 'wb') as f:
    pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

When I run this, I get

TypeError: Expected binary or unicode string, got <memory at 0x7fdc66d7d7c0>

In the tutorial this step worked fine because it used TensorFlow 1 and tf.io.gfile.Open(), which was removed in TensorFlow 2 and replaced by the call above. Plain open() also works, but that does not help me write to a bucket. I also tried

with tf.io.gfile.GFile(filename, 'wb') as f:
    f.write(obj)

but it returns the same error. Please let me know what I am doing wrong, or whether there is an alternative way to store a pickled object directly in a bucket. Many thanks for your help!

[1]: https://cloud.google.com/dataflow/docs/samples/molecules-walkthrough#overview

Upvotes: 2

Views: 1217

Answers (2)

tianzong

Reputation: 1

On Python 3.7.9, you may also need the pickle5 backport:

import pickle5 as pickle
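If the same script has to run on several Python versions, the import could be made conditional. This is a minimal sketch, assuming that pickle5 is only wanted on interpreters older than 3.8 and that falling back to the standard-library pickle is acceptable when the backport is not installed:

```python
import sys

# Assumption: pickle5 is only relevant before Python 3.8, where the
# standard-library pickle does not yet support protocol 5.
if sys.version_info < (3, 8):
    try:
        import pickle5 as pickle  # backport package, must be installed separately
    except ImportError:
        import pickle  # fall back to the standard library
else:
    import pickle

# Round trip with whichever module was selected above.
data = pickle.dumps({"foo": "bar"}, protocol=pickle.HIGHEST_PROTOCOL)
restored = pickle.loads(data)
```

Either branch binds the name pickle, so the rest of the code does not need to know which module was picked.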

Upvotes: 0

kalu

Reputation: 2682

I don't know why the original code sample doesn't work, but this did the trick for me:

import codecs
import dill as pickle
import tensorflow as tf


def dump(obj, filename):
    """Wrapper to dump an object to a file."""
    with tf.io.gfile.GFile(filename, "wb") as f:
        # Pickle to bytes, then base64-encode and decode to a str,
        # so that write() receives a plain string.
        pickled = codecs.encode(
            pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
            "base64").decode()
        f.write(pickled)


def load(filename):
    """Wrapper to load an object from a file."""
    with tf.io.gfile.GFile(filename, "rb") as f:
        pickled = f.read()
        # Undo the base64 encoding before unpickling.
        return pickle.loads(codecs.decode(pickled, "base64"))

versions: Python 3.7.9, TensorFlow 2.3.0
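The error message in the question shows a memoryview (<memory at 0x...>) reaching write(), so another approach that may be worth trying is to serialize to a bytes object with dumps() first and write that in a single call, without the base64 round trip. A minimal sketch under stated assumptions: the standard-library pickle stands in for dill so that the snippet runs without extra dependencies, and BytesOnlyWriter is a hypothetical stand-in for a file object that rejects anything other than bytes or str. It is not part of TensorFlow, and the sketch is untested against a real bucket.

```python
import io
import pickle  # stand-in for dill, which offers the same dumps()/loads() calls


class BytesOnlyWriter:
    """Hypothetical file-like object that only accepts bytes or str."""

    def __init__(self):
        self._buf = io.BytesIO()

    def write(self, data):
        if not isinstance(data, (bytes, str)):
            raise TypeError(f"Expected binary or unicode string, got {data!r}")
        if isinstance(data, str):
            data = data.encode("utf-8")
        return self._buf.write(data)

    def getvalue(self):
        return self._buf.getvalue()


def dump_as_bytes(obj, f):
    """Pickle obj to a bytes object first, then write it in one call."""
    f.write(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))


writer = BytesOnlyWriter()
dump_as_bytes({"foo": "bar"}, writer)
restored = pickle.loads(writer.getvalue())
```

With TensorFlow, the same helper would presumably be called with the handle returned by tf.io.gfile.GFile(filename, "wb") in place of writer, and loading would read the bytes back in "rb" mode and pass them to loads().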

Upvotes: 1
