Wayne

Reputation: 31

Ragged Tensors have no len() after conversion to Tensor

I am training a deep learning model on stacks of images where the number of images per stack varies. The batch shape is [batch, None, 256, 256, 1], where the None dimension (images per stack) differs between stacks.

I use tf.RaggedTensor.merge_dims(0, 1) to merge the batch and stack dimensions, converting the ragged tensor to shape [None, 256, 256, 1] so it can be fed into a pretrained Keras CNN.
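A minimal sketch of what merge_dims(0, 1) does to a batch with ragged_rank=1. It assumes small hypothetical 8x8 images instead of 256x256 so it runs quickly, and builds the batch with tf.RaggedTensor.from_row_lengths so that only the stack axis is ragged. The leading dimension of the merged result is the total number of images across all stacks.

```python
import tensorflow as tf

# Hypothetical batch: 2 stacks holding 2 and 3 images of shape (8, 8, 1).
# Built from the flat image tensor plus per-stack row lengths,
# so only the stack dimension (axis 1) is ragged.
images = tf.random.normal((5, 8, 8, 1))
batch = tf.RaggedTensor.from_row_lengths(images, row_lengths=[2, 3])
print(batch.shape)  # (2, None, 8, 8, 1)

# Merge the batch axis (0) and the ragged stack axis (1) into a single axis.
merged = batch.merge_dims(0, 1)
print(merged.shape)  # (5, 8, 8, 1)

# The merged length equals the total number of images in the batch.
total = int(tf.reduce_sum(batch.row_lengths()))
print(total)  # 5
```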

However, doing this inside the Keras model (in a Lambda layer) results in the following error: TypeError: object of type 'RaggedTensor' has no len()

When I apply .merge_dims outside the Keras model and pass the resulting tensors to the same pretrained model, I do not get this error.
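That observation points to the simplest workaround: merge before the data reaches the model, so the CNN only ever receives a dense tensor. A minimal sketch, assuming small 32x32 images and a tiny hypothetical CNN standing in for ResNet50. The isinstance check is only a defensive guard in case the merged result is still ragged in your setup, and the last step shows one way to regroup per-image outputs by stack.

```python
import tensorflow as tf

H = W = 32  # small stand-in for 256x256 so the sketch runs quickly

# Hypothetical ragged batch: 2 stacks holding 3 and 1 images.
values = tf.random.normal((4, H, W, 1))
batch = tf.RaggedTensor.from_row_lengths(values, row_lengths=[3, 1])

# Merge the batch and stack axes *outside* the model.
images = batch.merge_dims(0, 1)
if isinstance(images, tf.RaggedTensor):  # defensive guard
    images = images.to_tensor()

# Tiny hypothetical CNN standing in for ResNet50; it only sees dense input.
cnn = tf.keras.Sequential([
    tf.keras.Input(shape=(H, W, 1)),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

logits = cnn(images)
print(logits.shape)  # (4, 10)

# Regroup the per-image outputs back into per-stack groups.
per_stack = tf.RaggedTensor.from_row_lengths(logits, batch.row_lengths())
print(per_stack.row_lengths().numpy().tolist())  # [3, 1]
```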

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, output_signature=(tf.RaggedTensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, ragged_rank=1)))
ds = ds.repeat().batch(8)
print(next(iter(ds)).shape)

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x
x = tf.keras.layers.Lambda(merge)(inputs)
merged_inputs = x
# x = ResNet50(x) # Uncommenting this will result in `model` producing an error when run for inference.

model = tf.keras.Model(inputs, x)

# Run inference
data = next(iter(ds))
model(data).shape # Will be an error if ResNet50 is used

Here is a colab notebook that demonstrates the problem. https://colab.research.google.com/drive/1kN78mf4_oNqxWOluV054NlqmakC5msli?usp=sharing

Upvotes: 3

Views: 2131

Answers (1)

Innat

Reputation: 17219

I'm not sure whether the following workaround holds up for more complex network designs, but here are some pointers. The reason you got

TypeError: object of type 'RaggedTensor' has no len()

is that the ResNet model expects a regular (dense) tensor, not a ragged tensor. I'm not sure whether ResNet50(weights=None) can accept a ragged tensor directly at all. So if we convert the ragged data to a dense tensor right before it is fed to ResNet, it should no longer complain. Below is full working code based on this idea, though a more efficient approach is probably possible.
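The conversion step can be written as a small helper that only densifies when needed. This is a sketch: ensure_dense is a hypothetical name, not a TensorFlow API. It calls RaggedTensor.to_tensor() on ragged input (padding shorter rows with zeros) and returns dense tensors unchanged, so it does not fail if the input is already dense.

```python
import tensorflow as tf

def ensure_dense(x):
    """Hypothetical helper: densify ragged input, pass dense tensors through."""
    if isinstance(x, tf.RaggedTensor):
        # to_tensor() pads shorter rows with zeros up to the longest row.
        return x.to_tensor()
    return x

# Ragged input: two rows of different lengths become a zero-padded tensor.
rt = tf.ragged.constant([[1.0, 2.0], [3.0]])
print(ensure_dense(rt).numpy().tolist())  # [[1.0, 2.0], [3.0, 0.0]]

# Dense input is returned as the same object.
dense = tf.ones((2, 3))
print(ensure_dense(dense) is dense)  # True
```

In a model like the one below, it could be wrapped as tf.keras.layers.Lambda(ensure_dense) in place of the plain to_tensor() lambda.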


Data

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, 
                                    output_signature=(tf.RaggedTensorSpec(
                                        shape=(None, 256, 256, 1), 
                                        dtype=tf.float32, ragged_rank=1
                                        )
                                    )
                                )
ds = ds.repeat().batch(8)

Basic Model

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x

Ragged Model

Here we convert ragged_tensor to tensor before passing the data to ResNet.

class RagModel(tf.keras.Model):
    def __init__(self):
        super(RagModel, self).__init__()
        # base models 
        self.a = tf.keras.layers.Lambda(merge)
        # convert: tensor = ragged_tensor.to_tensor()
        self.b = tf.keras.layers.Lambda(lambda x: x.to_tensor())
        self.c = ResNet50
    
    def call(self, inputs, training=None, plot=False, **kwargs):
        x = self.a(inputs)
        x = self.b(x) if not plot else x  # conversion is skipped when build_graph() traces the model for plotting
        x = self.c(x)
        return x
    
    # a helper function to plot 
    def build_graph(self):
        x = tf.keras.Input(type_spec=tf.RaggedTensorSpec(
            shape=(8, None, 256, 256, 1),
            dtype=tf.float32, ragged_rank=1)
        )
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x, plot=True))
   
x_model = RagModel()

Run

data = next(iter(ds))
print(data.shape)
x_model(data).shape

(8, None, 256, 256, 1)
TensorShape([39, 1000])

Plot

tf.keras.utils.plot_model(x_model.build_graph(), 
              show_shapes=True, show_layer_names=True)


x_model.build_graph().summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(8, None, 256, 256, 1)]  0         
_________________________________________________________________
lambda_2 (Lambda)            (None, 256, 256, 1)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 1000)              25630440  
=================================================================
Total params: 25,630,440
Trainable params: 25,577,320
Non-trainable params: 53,120
_________________________________________________________________

Upvotes: 2
