MessiSkillz

Reputation: 23

tf.math.segment_mean operation not working with XLA_GPU_JIT in Google Colab

I am using Google Colab Pro with a T4 GPU for testing. My model is a GNN that is used to find relational data between chess pieces on a board. The code runs perfectly fine on a CPU, but when I run it on the GPU I get the error below.

This is the layer where the problem is coming from:

import tensorflow as tf


class ScatterAndAggregateLayer(tf.keras.layers.Layer):
    """
    OP: e -> v'
    IN: V_set, E_set, node_ids

    Scatter E-set entries into buckets defined by node_ids and aggregate
    accordingly. V-set is passed only for shape preservation.
    """
    def __init__(self, agg_method='sum', **kwargs):
        super(ScatterAndAggregateLayer, self).__init__(**kwargs)
        self.agg_method = agg_method
        self.agg_method_func = None
        self.extended_pad = None

    def build(self, input_shape):
        self.agg_method_func = {
            'sum': tf.math.segment_sum,
            'mean': tf.math.segment_mean,
            'max': tf.math.segment_max,
        }[self.agg_method]

        # Extra no-op padding dims for E-set inputs with rank > 3.
        self.extended_pad = [[0, 0] for _ in range(max(0, len(input_shape[1]) - 3))]

    def call(self, inputs):
        with tf.device('/CPU:0'):
            V_set, E_set, node_ids = inputs

            # Segment ops only emit rows up to max(node_ids) + 1, so pad the
            # result back to the full node count of V_set.
            residue_pad = (
                tf.reduce_sum(tf.ones_like(V_set[0, ..., :1], dtype=tf.int32))
                - tf.maximum(tf.reduce_max(node_ids[0]) + 1, 0)
            )
            return tf.pad(
                self.agg_method_func(E_set[0], node_ids[0]),
                paddings=[[0, residue_pad], [0, 0]] + self.extended_pad,
            )[tf.newaxis, ...]
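
To make it easier to see what the layer is doing, here is a tiny standalone toy example of the aggregation step (numbers made up, not from my data): each edge-feature row is bucketed by its node id and averaged.

import tensorflow as tf

# Three edge rows bucketed into two nodes; tf.math.segment_* expects sorted ids.
E = tf.constant([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])
node_ids = tf.constant([0, 0, 1])
print(tf.math.segment_mean(E, node_ids))  # [[2. 3.], [5. 6.]]

This works fine eagerly and on CPU; the problem only seems to show up once the op ends up inside the XLA-compiled graph on the GPU.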

I used a very simple model.fit call to train the model:

model.fit(dataset,
          steps_per_epoch = 40000,
          callbacks=[tensorboard_callback])

and I got this error message during training:

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/hybridmodel/pytorch/combined_model.py", line 495, in <module>
    model.fit(dataset,
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean defined at (most recent call last):
<stack traces unavailable>
Detected at node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean defined at (most recent call last):
<stack traces unavailable>
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_153843[] on XLA_GPU_JIT: SegmentMean (No registered 'SegmentMean' OpKernel for XLA_GPU_JIT devices compatible with node {{node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean}}){{node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean}}
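
In case it helps narrow this down, I think the same error should be reproducible outside the model with a jit-compiled function (untested sketch; I'm assuming the jit_compile flag is what puts the graph on XLA_GPU_JIT):

import tensorflow as tf

@tf.function(jit_compile=True)
def agg(data, ids):
    # Same op the layer uses for agg_method='mean'
    return tf.math.segment_mean(data, ids)

data = tf.random.uniform((6, 4))
ids = tf.constant([0, 0, 1, 1, 2, 2])
# I expect this to hit the same "No registered 'SegmentMean' OpKernel" error on a GPU runtime.
print(agg(data, ids))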

It seems like the tf.math.segment_* operations aren't supported on XLA_GPU_JIT. Could someone please explain why that is and how I can fix it?
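
One workaround I'm considering (I'm not sure it's numerically equivalent for my padding logic, or whether XLA_GPU_JIT actually has a kernel for it) is switching to the unsorted variants, which take an explicit num_segments and would also make the residue_pad unnecessary, I think:

import tensorflow as tf

E = tf.constant([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])
node_ids = tf.constant([0, 0, 1])
num_nodes = 4  # pad out to the full node count directly instead of using residue_pad
print(tf.math.unsorted_segment_mean(E, node_ids, num_segments=num_nodes))
# -> shape (4, 2); rows for nodes with no incoming edges should come out as zeros, I think.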

Upvotes: 0

Views: 38

Answers (0)
