Reputation: 31
I am training a deep learning model on stacks of images where the number of images per stack varies. The input shape is [Batch, None, 256, 256, 1], where None is the variable stack size.
I use tf.RaggedTensor.merge_dims(0, 1) to convert the ragged tensor to a shape of [None, 256, 256, 1] so that it can be fed into a pretrained Keras CNN model.
However, doing the merge inside a Keras layer results in the following error: TypeError: the object of type 'RaggedTensor' has no len()
When I apply .merge_dims outside of the Keras layer and pass the resulting tensors to the same pretrained model, I do not get this error.
import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
    varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
    image = tf.random.normal((varShape, 256, 256, 1))
    image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
    yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, output_signature=(tf.RaggedTensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, ragged_rank=1)))
ds = ds.repeat().batch(8)
print(next(iter(ds)).shape)

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1),
        dtype=tf.float32,
        ragged_rank=1))
ResNet50 = tf.keras.applications.ResNet50(
    include_top=True,
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
    x = x.merge_dims(0, 1)
    return x

x = tf.keras.layers.Lambda(merge)(inputs)
merged_inputs = x
# x = ResNet50(x) # Uncommenting this will result in `model` producing an error when run for inference.
model = tf.keras.Model(inputs, x)

# Run inference
data = next(iter(ds))
model(data).shape # Will be an error if ResNet50 is used
Here is a Colab notebook that demonstrates the problem: https://colab.research.google.com/drive/1kN78mf4_oNqxWOluV054NlqmakC5msli?usp=sharing
Upvotes: 3
Views: 2131
Reputation: 17219
I'm not sure whether the following workaround is stable for complex network designs, but here are some pointers. You get the "Ragged Tensors have no len()" error because of the ResNet model: it expects a dense tensor, not a ragged_tensor. I'm not sure, however, whether ResNet(weights=None) can take a ragged_tensor directly. So if we convert the ragged data to a dense tensor right before it is fed to ResNet, it should not complain. Below is the full working code based on this idea. Note that a more efficient approach probably exists.
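As a smaller illustration of that conversion, here is a minimal sketch that roughly mirrors the question's data layout with tiny 4x4 images (an assumption made only to keep it light). The helper name to_dense is hypothetical, not a TensorFlow function. It suggests why the merged value may still need to_tensor(): when an inner dimension is also ragged, merge_dims(0, 1) can still return a RaggedTensor.

```python
import tensorflow as tf

# Assumption: five tiny 4x4 single-channel images instead of 256x256 ones,
# grouped into two stacks of sizes 2 and 3.
images = tf.random.normal((5, 4, 4, 1))

# Roughly mirror the question's pipeline: each stack is made ragged
# (ragged_rank=1) before batching, so the batch has two ragged dimensions.
inner = tf.RaggedTensor.from_tensor(images, ragged_rank=1)
batch = tf.RaggedTensor.from_row_lengths(inner, row_lengths=[2, 3])

# Merge the batch dimension and the stack dimension.
merged = batch.merge_dims(0, 1)


def to_dense(x):
    """Hypothetical helper: return a dense tf.Tensor whether x is ragged or dense."""
    return x.to_tensor() if isinstance(x, tf.RaggedTensor) else x


dense = to_dense(merged)
print(type(merged).__name__, "->", type(dense).__name__, dense.shape)
```

Checking the type before calling to_tensor() keeps the helper usable whether or not merge_dims has already produced a dense tensor.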
Data
import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
    varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
    image = tf.random.normal((varShape, 256, 256, 1))
    image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
    yield image

ds = tf.data.Dataset.from_generator(
    synthetic_gen,
    output_signature=(tf.RaggedTensorSpec(
        shape=(None, 256, 256, 1),
        dtype=tf.float32, ragged_rank=1
    ))
)
ds = ds.repeat().batch(8)

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1),
        dtype=tf.float32,
        ragged_rank=1))
ResNet50 = tf.keras.applications.ResNet50(
    include_top=True,
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
    x = x.merge_dims(0, 1)
    return x
Here we convert the ragged_tensor to a tensor before passing the data to ResNet.
class RagModel(tf.keras.Model):
    def __init__(self):
        super(RagModel, self).__init__()
        # base models
        self.a = tf.keras.layers.Lambda(merge)
        # convert: tensor = ragged_tensor.to_tensor()
        self.b = tf.keras.layers.Lambda(lambda x: x.to_tensor())
        self.c = ResNet50

    def call(self, inputs, training=None, plot=False, **kwargs):
        x = self.a(inputs)
        x = self.b(x) if not plot else x
        x = self.c(x)
        return x

    # a helper function to plot
    def build_graph(self):
        x = tf.keras.Input(type_spec=tf.RaggedTensorSpec(
            shape=(8, None, 256, 256, 1),
            dtype=tf.float32, ragged_rank=1)
        )
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x, plot=True))

x_model = RagModel()
data = next(iter(ds)); print(data.shape)
x_model(data).shape
(8, None, 256, 256, 1)
TensorShape([39, 1000])
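The leading 39 in that output is the total number of images across the 8 stacks in this particular batch, because merging dims 0 and 1 concatenates the stacks along one axis. Below is a plain-Python sketch of that flattening, with nested lists standing in for ragged tensors. The helper names merge_outer_dims and split_by_row_lengths are hypothetical and are not TensorFlow functions. It also shows how the row lengths let per-image outputs be regrouped by stack afterwards.

```python
from itertools import accumulate


def merge_outer_dims(stacks):
    """Flatten variable-length stacks into one flat list plus the row lengths."""
    flat = [item for stack in stacks for item in stack]
    row_lengths = [len(stack) for stack in stacks]
    return flat, row_lengths


def split_by_row_lengths(flat, row_lengths):
    """Regroup a flat list into stacks, inverting merge_outer_dims."""
    ends = list(accumulate(row_lengths))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]


# Hypothetical batch: three stacks holding 2, 1 and 3 image ids.
stacks = [["a0", "a1"], ["b0"], ["c0", "c1", "c2"]]
flat, row_lengths = merge_outer_dims(stacks)
print(len(flat), row_lengths)

# Stand-in for the CNN: one "prediction" per image in the flat list.
predictions = [f"pred({image})" for image in flat]
per_stack = split_by_row_lengths(predictions, row_lengths)
print([len(group) for group in per_stack])
```

In TensorFlow itself, the same bookkeeping should be possible with the ragged input's row_lengths() and tf.RaggedTensor.from_row_lengths, though that is not part of the code above.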
tf.keras.utils.plot_model(x_model.build_graph(),
                          show_shapes=True, show_layer_names=True)
x_model.build_graph().summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_4 (InputLayer)         [(8, None, 256, 256, 1)]  0
_________________________________________________________________
lambda_2 (Lambda)            (None, 256, 256, 1)       0
_________________________________________________________________
resnet50 (Functional)        (None, 1000)              25630440
=================================================================
Total params: 25,630,440
Trainable params: 25,577,320
Non-trainable params: 53,120
_________________________________________________________________
Upvotes: 2