Siladittya

Reputation: 1205

tf.keras.backend.clip not giving correct results

tf.keras.backend.clip is not clipping the tensors

When I use tf.keras.backend.clip inside this function

def grads_ds(model_ds, ds_inputs,y_true,cw):
    print(y_true)
    with tf.GradientTape() as ds_tape:
        y_pred = model_ds(ds_inputs)
        print(y_pred.numpy())
        logits_1 = -1*y_true*K.log(y_pred)*cw[:,0]
        logits_0 = -1*(1-y_true)*K.log(1-y_pred)*cw[:,1]
        loss = logits_1 + logits_0
        loss_value_ds = K.sum(loss)

    ds_grads = ds_tape.gradient(loss_value_ds,model_ds.trainable_variables,unconnected_gradients=tf.UnconnectedGradients.NONE)
    for g in ds_grads:
        g = tf.keras.backend.clip(g,min_grad,max_grad)
    return loss_value_ds, ds_grads

The values of the gradients remain the same (unclipped).

When I use tf.keras.backend.clip the same way inside the custom training loop,

for g in ds_grads:
    g = tf.keras.backend.clip(g,min_grad,max_grad)

it doesn't work. The gradients applied to the variables are not clipped.

However, if I print g within the loop, then it shows the clipped value.

I can't understand where the problem is.

Upvotes: 0

Views: 534

Answers (1)

strider0160

Reputation: 549

This is because g in your loop is just a name bound to the current element of the list. Assigning to g rebinds that name to a new object (i.e., you are not modifying the object it previously referred to), so the list still holds the original, unclipped tensor. Consider this example, where I want to set all the values in lst to 5. Guess what happens when you run this code?

lst = [1,2,3,4]
for ele in lst:
    ele = 5
print(lst)

Nothing! You get the exact same list back. Within the loop, however, ele is 5, which matches what you saw when printing g. Because the values in the list are immutable (integers here, tensors in your case), they cannot be changed in place either; the only option is to replace the entries in the list.

However, you can modify mutable objects in place:

lst = [[2], [2], [2]]
for ele in lst:
    ele.append(3)
print(lst)

The above code will make each element [2, 3] as expected.

One way to solve your problem is to assign back into the list by index:

lst = [1,2,3,4]
for itr in range(len(lst)):
    lst[itr] = 5
print(lst)
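
Applied to the clipping case, the same idea might look like the sketch below. To keep it runnable without TensorFlow, clip here is a hypothetical stand-in for tf.keras.backend.clip that works on plain floats, and the values of ds_grads, min_grad and max_grad are made up for illustration.

```python
def clip(value, low, high):
    """Hypothetical stand-in for tf.keras.backend.clip, operating on plain floats."""
    return max(low, min(value, high))


min_grad, max_grad = -1.0, 1.0   # assumed clipping bounds
ds_grads = [-3.5, 0.25, 2.0]     # placeholder numbers standing in for gradient tensors

# Rebinding the loop variable: the list is left untouched.
unchanged = list(ds_grads)
for g in unchanged:
    g = clip(g, min_grad, max_grad)

# Writing back by index: each list entry is replaced by its clipped value.
by_index = list(ds_grads)
for i in range(len(by_index)):
    by_index[i] = clip(by_index[i], min_grad, max_grad)

# Building a new list with a comprehension gives the same result.
clipped = [clip(g, min_grad, max_grad) for g in ds_grads]

print(unchanged)  # still the original values
print(by_index)   # clipped values
print(clipped)    # clipped values
```

In the function from the question, this would presumably become ds_grads = [tf.keras.backend.clip(g, min_grad, max_grad) for g in ds_grads] just before the return statement. Note that with unconnected_gradients set to NONE, a gradient can come back as None, and such entries would need to be skipped.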

Upvotes: 1
