Reputation: 620
I am building a model in TensorFlow 2.0 (upgrading is not an option because of compatibility with my version of CUDA, which I do not have permission to change). I am using tf.distribute.MirroredStrategy() to train my model on 2 GPUs. As part of the model, I am trying to instantiate a custom dense layer whose weights are the transpose of the weights of a different dense layer. My code involves this line to build the custom layer:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class DenseTied(Layer):
    # Really long class, full code can be found at link below
    def build(self, input_shape):
        self.kernel = K.transpose(self.tied_to.kernel)
I am then using this in a model as follows:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

def build_model(input_shape):
    model_input = Input(shape=input_shape)
    dense1 = Dense(6144, activation='relu')
    # The tied layer reuses the transpose of dense1's kernel
    dense_tied1 = DenseTied(49152, tied_to=dense1)
    x = dense1(model_input)
    model_output = dense_tied1(x)
    model = Model(model_input, model_output)
    model.compile(optimizer='adam', loss='mse')
    return model
When trying to build this model I get an error: AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_distribute_strategy'.
I have been tracking this down for a while now and I pinpointed that the issue is in the line
self.kernel = K.transpose(self.tied_to.kernel)
It seems that self.tied_to.kernel is of type <class 'tensorflow.python.distribute.values.MirroredVariable'>, but after calling K.transpose() on it the resulting output is of type <class 'tensorflow.python.framework.ops.EagerTensor'>. I tried following the instructions here but it did not work: I get AttributeError: 'MirroredStrategy' object has no attribute 'run', even though the docs list that method. So I think my version of TensorFlow may be too old for it.
How can I update a mirrored variable in Tensorflow 2.0?
Also if you want to see the full custom layer code, I am trying to implement the dense tied layer described here.
Upvotes: 2
Views: 1778
Reputation: 1087
As of now, the documentation is for TensorFlow 2.3. If you are using 2.0, it should be strategy.experimental_run_v2 instead of strategy.run.
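If the same script has to work on 2.0 and on newer releases, a small wrapper can pick whichever method the installed version exposes. This is only a sketch: run_on_replicas is a made-up helper name, not part of TensorFlow, and it assumes both methods take fn, args, and kwargs as in their documented signatures.

```python
def run_on_replicas(strategy, fn, args=(), kwargs=None):
    """Run fn on each replica using whichever method this TF version has.

    Assumption: `strategy` is a tf.distribute strategy such as
    MirroredStrategy. Older releases (e.g. 2.0) name the method
    `experimental_run_v2`; newer releases name it `run`.
    """
    # Prefer the newer name when it exists.
    runner = getattr(strategy, "run", None)
    if runner is None:
        # Fall back to the name used by TF 2.0.
        runner = strategy.experimental_run_v2
    return runner(fn, args=args, kwargs=kwargs)
```

You would then call something like run_on_replicas(strategy, step_fn, args=(batch,)), where step_fn and batch stand in for your own per-replica function and its input.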
Upvotes: 3