Simon Delecourt

Reputation: 1599

Avoid TensorFlow session graph growth

I have a function which uses the TensorFlow backend from Keras. In a loop, I add operations to the session graph and then run the session. The problem is that the graph seems to grow considerably across calls to the function: after 4-5 calls, each evaluation takes about twice as long.

This is the function:

def attack_fgsm(self, x, y, epsilon=1e-2):
    sess = K.get_session()
    nabla_x = np.zeros(x.shape)

    for (weak_classi, alpha) in zip(self.models, self.alphas):
        # This K.gradients call adds new ops to the session graph on every invocation:
        grads = K.gradients(K.categorical_crossentropy(y, weak_classi.model.output), weak_classi.model.input)[0]
        grads = sess.run(grads, feed_dict={weak_classi.model.input: x})
        nabla_x += alpha*grads

    x_adv = x + epsilon*np.sign(nabla_x)

    return x_adv

So the question is: how can I restructure this function so that the graph doesn't keep growing?
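
For reference, the growth can be confirmed by counting the ops in the backend graph before and after a call. A minimal check, assuming the TF 1.x graph-mode setup above (adaboost being the instance from my script):

import keras.backend as K

def count_graph_ops():
    # Number of nodes currently in the default session graph.
    return len(K.get_session().graph.get_operations())

before = count_graph_ops()
x_adv = adaboost.attack_fgsm(x_test, y_test, epsilon=1e-2)
print("graph grew by", count_graph_ops() - before, "ops")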

After some research, it seems I need to use placeholders to overcome the problem, so I came up with this:

def attack_fgsm(self, x, y, epsilon=1e-2):
    sess = K.get_session()
    nabla_x = np.zeros(x.shape)
    y_ph = K.placeholder(y.shape)
    model_in = K.placeholder(x.shape, dtype="float")

    for (weak_classi, alpha) in zip(self.models, self.alphas):
        grads = K.gradients(K.categorical_crossentropy(y_ph, weak_classi.model.output), weak_classi.model.input)[0]
        grads = sess.run(grads, feed_dict={y_ph:y, model_in:x})
        nabla_x += alpha*grads

    x_adv = x + epsilon*np.sign(nabla_x)
    #K.clear_session()
    return x_adv

This leads to:

Traceback (most recent call last):
  File "/home/simond/adversarialboosting/src/scripts/robustness_study.py", line 93, in <module>
    x_att_ada = adaboost.attack_fgsm(x_test, y_test, epsilon=eps)
  File "/home/simond/adversarialboosting/src/classes/AdvBoostM1.py", line 308, in attack_fgsm
    grads = sess.run(grads, feed_dict={y_ph:y, model_in:x})
  File "/home/simond/miniconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/simond/miniconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1158, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/simond/miniconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 474, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/simond/miniconda3/envs/keras/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

Upvotes: 0

Views: 52

Answers (1)

Dr. Snoopy

Reputation: 56377

The problem is that this line of code runs each time you call the function:

grads = K.gradients(K.categorical_crossentropy(y, weak_classi.model.output), weak_classi.model.input)[0]

Each call adds a new symbolic gradient computation to the graph, but that only needs to happen once per weak_classi instance. You can therefore split the work into two parts. This part should run only once, say, at initialization (the placeholder is stored on self so it can be fed later):

# num_classes: number of output classes (assumed known here).
self.y_ph = K.placeholder(shape=(None, num_classes))
self.weak_classi_grads = []
for weak_classi in self.models:
    grads = K.gradients(K.categorical_crossentropy(self.y_ph, weak_classi.model.output),
                        weak_classi.model.input)[0]
    self.weak_classi_grads.append(grads)

Then you can rewrite your evaluation function as:

def attack_fgsm(self, x, y, epsilon=1e-2):
    sess = K.get_session()
    nabla_x = np.zeros(x.shape)

    # Only runs existing ops; nothing new is added to the graph here.
    for (weak_classi, alpha, grads) in zip(self.models, self.alphas, self.weak_classi_grads):
        grad_vals = sess.run(grads, feed_dict={weak_classi.model.input: x, self.y_ph: y})
        nabla_x += alpha * grad_vals

    x_adv = x + epsilon * np.sign(nabla_x)

    return x_adv

This way the graph contains only one gradient-computation op per model, and each call just runs the session to evaluate those gradients on different inputs.
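
As a variation, the stored gradient tensors can be wrapped in K.function callables at initialization, which hides the session and feed_dict plumbing. A minimal sketch under the same assumptions (a stored self.y_ph label placeholder, TF 1.x graph mode):

self.grad_fns = []
for weak_classi in self.models:
    grads = K.gradients(K.categorical_crossentropy(self.y_ph, weak_classi.model.output),
                        weak_classi.model.input)[0]
    # K.function compiles a graph-mode callable mapping inputs to outputs.
    self.grad_fns.append(K.function([weak_classi.model.input, self.y_ph], [grads]))

Each attack call then becomes nabla_x += alpha * grad_fn([x, y])[0] inside the loop, with no session handling at all.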

Upvotes: 1
