sagar_acharya

Reputation: 366

How to save and load selected and all variables in TensorFlow 2.0 using tf.train.Checkpoint?

How do I save selected variables (shown below) to a file in TensorFlow 2.0 and load them into variables defined in another script using tf.train.Checkpoint?

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # How do I save the 3rd, 55th and 60th elements of self.initList
        # and the 4th element of self.moreList?
        pass

Also, please show how to save all the variables and reload them using tf.train.Checkpoint. Thanks in advance.

Upvotes: 3

Views: 5542

Answers (2)

javidcf

Reputation: 59731

I'm not sure if this is what you mean, but you can create a tf.train.Checkpoint object specifically for the variables that you want to save and restore. See the following example:

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()
        self.ckpt = self.makeCheckpoint()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def makeCheckpoint(self):
        return tf.train.Checkpoint(
            init3=self.initList[3], init55=self.initList[55],
            init60=self.initList[60], more4=self.moreList[4])

    def saveVariables(self):
        self.ckpt.save('./ckpt')

    def restoreVariables(self):
        status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
        status.assert_consumed()  # Optional check

# Create variables
v1 = manyVariables()
# Assign fixed values
for i, v in enumerate(v1.initList):
    v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
    v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()

# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594  -0.18425298 -0.4898144  -1.2330629   0.08798431]
#  [ 1.5002227   0.99475247  0.7817361   0.3849587  -0.59548247]
#  [-0.57121766 -1.277224    0.6957546  -0.67618763  0.0510064 ]
#  [ 0.85491985  0.13310803 -0.93152267  0.10205163  0.57520276]
#  [-1.0606447  -0.16966362 -1.0448577   0.56799036 -0.90726566]]

# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]]
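
To confirm which variables actually ended up in the checkpoint, you can list its contents. This is only a minimal sketch under a few assumptions: the temporary directory and the names `saved` and `not_saved` are made up for illustration, and it relies on `tf.train.list_variables` and `tf.train.load_variable` to read the file back without creating any variables.

```python
import tempfile
import tensorflow as tf

# Two variables, but only one of them is placed in the checkpoint.
saved = tf.Variable(3.0 * tf.ones([5, 5]))
not_saved = tf.Variable(tf.zeros([5, 5]))

ckpt = tf.train.Checkpoint(init3=saved)
ckpt_dir = tempfile.mkdtemp()          # hypothetical location for this sketch
path = ckpt.save(ckpt_dir + '/ckpt')   # save() returns the numbered prefix it wrote

# list_variables returns (name, shape) pairs for everything stored in the file.
# The keys are derived from the keyword names passed to tf.train.Checkpoint.
names = [name for name, shape in tf.train.list_variables(path)]
for name in names:
    print(name)

# Look up the stored value for 'init3' without hard-coding the full key.
key = next(n for n in names if n.startswith('init3/'))
value = tf.train.load_variable(path, key)  # returns a NumPy array
print(value.shape)
```

Only variables reachable from the `tf.train.Checkpoint` object appear in the listing, which is what makes the selective save work.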

If you want to save all the variables in the lists, you could replace makeCheckpoint with something like this:

def makeCheckpoint(self):
    return tf.train.Checkpoint(
        **{f'init{i}': v for i, v in enumerate(self.initList)},
        **{f'more{i}': v for i, v in enumerate(self.moreList)})

Note that checkpoints can be nested, so more generally you could write a function that builds a checkpoint from a list of variables, for example:

def listCheckpoint(varList):
    # Use 'item{}'.format(i) if using Python <3.6
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})

Then you could just have this:

def makeCheckpoint(self):
    return tf.train.Checkpoint(init=listCheckpoint(self.initList),
                               more=listCheckpoint(self.moreList))
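
To make the "save everything, then load into variables defined elsewhere" case concrete, here is a rough end-to-end sketch using the nested layout. The names `saveAll`, `loadAll`, `partial`, `newInit`, `newMore` and `only3`, as well as the temporary directory, are assumptions for illustration. The last part shows that a checkpoint holding all variables can also be restored into just one of them, as long as the attribute names (`init`, `item3`) match; `expect_partial()` tells TensorFlow that leaving the rest unused is intentional.

```python
import tempfile
import tensorflow as tf

def listCheckpoint(varList):
    # One named child per list element: item0, item1, ...
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})

# Saving side: every variable goes into the checkpoint.
initList = [tf.Variable(float(i) * tf.ones([5, 5])) for i in range(100)]
moreList = [tf.Variable((100.0 + i) * tf.ones([3, 3])) for i in range(10)]
saveAll = tf.train.Checkpoint(init=listCheckpoint(initList),
                              more=listCheckpoint(moreList))
path = saveAll.save(tempfile.mkdtemp() + '/all_vars')  # hypothetical location

# Loading side (could be another script): fresh variables, same structure.
newInit = [tf.Variable(tf.zeros([5, 5])) for _ in range(100)]
newMore = [tf.Variable(tf.zeros([3, 3])) for _ in range(10)]
loadAll = tf.train.Checkpoint(init=listCheckpoint(newInit),
                              more=listCheckpoint(newMore))
loadAll.restore(path).assert_consumed()  # everything in the file was matched

# Restoring a single variable from the same file: only the names must match.
only3 = tf.Variable(tf.zeros([5, 5]))
partial = tf.train.Checkpoint(init=tf.train.Checkpoint(item3=only3))
partial.restore(path).expect_partial()   # the other saved values stay unused

print(newInit[55].numpy()[0, 0], newMore[4].numpy()[0, 0], only3.numpy()[0, 0])
```

Matching is done on the attribute names given to `tf.train.Checkpoint`, not on the Python variable names, so the loading code is free to organise its variables differently as long as the checkpoint structure lines up.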

Upvotes: 4

Eric

Reputation: 1213

The following code saves a list called variables to a file with a name of your choosing. The file is created in the current working directory (the same folder as your Python file if you run the script from there). The 'wb+' mode in the open function means binary read/write with truncation, so anything previously in the file is removed. Despite the .txt extension, the contents are binary pickle data rather than readable text. I use pickle to handle saving and loading the list.

import pickle

    def saveVariables(self, variables):  # 'variables' is a list of variables
        # 'wb+' truncates the file and opens it in binary mode
        with open("nameOfYourFile.txt", 'wb+') as file:
            pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        # Read the pickled list back from the given file
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables

To save specific variables, pass them as the variables argument to saveVariables, like so:

myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)

To retrieve the variables from the file, pass the same file name that saveVariables wrote to:

myVariables = retrieveVariables("nameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
sixtiethEl = myVariables[2]
fourthEl = myVariables[3]

You could add these functions anywhere in the class.

To access initList and moreList in your example, however, you should either return them from their functions (as I do with the variables list), make them global, or reach them through the instance (for example self.initList inside the class).
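
Whether tf.Variable objects themselves can be pickled is not something shown above, so a more conservative variant is to pickle their values as NumPy arrays and assign them back into variables that already exist. This is just a sketch under that assumption: the standalone functions take the file name and target variables as arguments (unlike the methods above), the file lives in a temporary directory, and the names `restored` and `selected.pkl` are made up for illustration.

```python
import os
import pickle
import tempfile
import tensorflow as tf

def saveVariables(variables, filename):
    # Store plain NumPy arrays rather than the tf.Variable objects.
    with open(filename, 'wb') as file:
        pickle.dump([v.numpy() for v in variables], file)

def retrieveVariables(filename, targets):
    # Load the arrays and assign them, in order, into existing variables.
    with open(filename, 'rb') as file:
        values = pickle.load(file)
    for target, value in zip(targets, values):
        target.assign(value)

initList = [tf.Variable(float(i) * tf.ones([5, 5])) for i in range(100)]
moreList = [tf.Variable((100.0 + i) * tf.ones([3, 3])) for i in range(10)]

filename = os.path.join(tempfile.mkdtemp(), 'selected.pkl')  # hypothetical path
saveVariables([initList[3], initList[55], initList[60], moreList[4]], filename)

# Variables defined "elsewhere", with matching shapes, to receive the values.
restored = [tf.Variable(tf.zeros([5, 5])) for _ in range(3)]
restored.append(tf.Variable(tf.zeros([3, 3])))
retrieveVariables(filename, restored)
print(restored[0].numpy()[0, 0], restored[3].numpy()[0, 0])
```

Unlike tf.train.Checkpoint, this matches values to variables purely by list position, so the saving and loading code must agree on the order.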

Upvotes: 2
