Reputation: 9321
I'm working on this image classification problem with keras. I'm trying to use subclassing API's
to do almost everything. I've created my custom
conv blocks which looks as follows:
class ConvBlock(keras.layers.Layer):
def __init__(self, in_features, kernel_size=(3, 3)):
super(ConvBlock, self).__init__()
self.conv = keras.layers.Conv2D(in_features, kernel_size, padding="same")
self.bn = keras.layers.BatchNormalization()
self.relu = keras.activations.relu
def call(self, x, training=False):
x = self.conv(x)
x = self.bn(x, training=training)
return self.relu(x)
After that i've created my simple Sequential
Model for testing which looks as follows:
seq_model = keras.Sequential([
ConvBlock(64),
ConvBlock(128),
ConvBlock(64),
keras.layers.Flatten(),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(5, activation='softmax'),
], name="seq_model")
seq_model.build((None, 96, 96, 1))
seq_model.summary()
So far so good, if i call .compile()
, .train()
and .evaluate()
on this seq_model
it works. The problem comes when i try to call .compile()
, .train()
and .evaluate()
using my custom .compile()
, .train()
and .evaluate()
. The following code shows how i created them:
class Model(keras.Model):
def __init__(self, model):
super().__init__()
self.model = model
# .compile()
def compile(self, loss, optimizer, metrics):
super().compile()
self.loss = loss
self.optimizer = optimizer
self.custom_metrics = metrics
# .fit()
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
pred = self.model(x, training=True)
loss = self.loss(y, pred)
gradients = tape.gradient(loss, self.trainable_variables)
optimizer.apply_gradients(zip(gradients, self.trainable_variables))
self.custom_metrics.update_state(y, pred)
return {"loss": loss, "accuracy": self.custom_metrics.result()}
# .evaluate()
def test_step(self, data):
x, y = data
pred = self.model(x, training=False)
loss = self.loss(y, pred)
self.custom_metrics.update_state(y, pred)
return {"loss": loss, "accuracy": self.custom_metrics.result()}
This is how I'm calling it.
yoga_model = Model(seq_model)
yoga_model.compile(
loss = keras.losses.CategoricalCrossentropy(from_logits=False),
optimizer = keras.optimizers.Adam(lr=0.001),
metrics = keras.metrics.CategoricalAccuracy(name="acc")
)
yoga_model.fit(train_ds, epochs=1, verbose=1)
Please help. A help input will be appreciated.
Upvotes: 1
Views: 2166
Reputation: 17219
In your custom model with subclassed API, implement the call
method as follows:
from tensorflow import keras
class Model(keras.Model):
def __init__:
self.model = model
def train_step:
def test_step:
def compile:
# implement the call method
def call(self, inputs, *args, **kwargs):
return self.model(inputs)
Upvotes: 3