I am using Google Colab Pro with a T4 GPU for testing. My model is a GNN that learns the relations between chess pieces on a board. When I ran my code on a CPU it worked perfectly fine, but when I ran it on a GPU I got the error below.
This is the layer where the problem is coming from:
class ScatterAndAggregateLayer(tf.keras.layers.Layer):
    """
    OP: e -> v'
    IN: V_set, E_set, node_ids
    Scatter E-set entries into buckets defined by node_ids and aggregate accordingly. V-set is passed only for shape
    preservation.
    """
    def __init__(self, agg_method='sum', **kwargs):
        super(ScatterAndAggregateLayer, self).__init__(**kwargs)
        self.agg_method = agg_method
        self.agg_method_func = None
        self.extended_pad = None

    def build(self, input_shape):
        # Segment reduction used to aggregate edges per node (tf.math.segment_* expects sorted ids).
        self.agg_method_func = {
            'sum': tf.math.segment_sum,
            'mean': tf.math.segment_mean,
            'max': tf.math.segment_max,
        }[self.agg_method]
        # No-op paddings for any feature axes beyond the first (E_set rank > 3).
        self.extended_pad = [[0, 0] for _ in range(max(0, len(input_shape[1]) - 3))]

    def call(self, inputs):
        with tf.device('/CPU:0'):
            V_set, E_set, node_ids = inputs
            # Rows to pad: nodes in V_set[0] minus segments produced (max node id + 1).
            residue_pad = (
                tf.reduce_sum(tf.ones_like(V_set[0, ..., :1], dtype=tf.int32))
                - tf.maximum(tf.reduce_max(node_ids[0]) + 1, 0)
            )
            return tf.pad(
                self.agg_method_func(E_set[0], node_ids[0]),
                paddings=[[0, residue_pad], [0, 0]] + self.extended_pad,
            )[tf.newaxis, ...]
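To make the intended behaviour concrete, here is a rough NumPy sketch of what the layer computes for a single graph (batch axis dropped). Each edge's feature row is aggregated into the row of its node id, and empty buckets are 0 (tf.math.segment_* returns 0 for empty segments, and tf.pad adds zero rows after the largest id), so the output always has one row per node. The function name scatter_and_aggregate_np is hypothetical, and the sketch assumes 2-D edge features with integer node ids in [0, num_nodes).

```python
import numpy as np


def scatter_and_aggregate_np(num_nodes, edge_feats, node_ids, method="sum"):
    """Hypothetical helper: aggregate per-edge rows into per-node rows (one graph).

    num_nodes:  output length (the node count that V_set provides in the layer)
    edge_feats: array of shape (num_edges, features)
    node_ids:   int array of shape (num_edges,), values in [0, num_nodes)
    """
    if method not in ("sum", "mean", "max"):
        raise ValueError(f"unknown aggregation method: {method}")
    edge_feats = np.asarray(edge_feats, dtype=float)
    node_ids = np.asarray(node_ids)
    feats = edge_feats.shape[1]
    # How many edges land in each node's bucket.
    counts = np.bincount(node_ids, minlength=num_nodes)

    if method == "max":
        out = np.full((num_nodes, feats), -np.inf)
        np.maximum.at(out, node_ids, edge_feats)  # unbuffered per-edge max
    else:
        out = np.zeros((num_nodes, feats))
        np.add.at(out, node_ids, edge_feats)  # unbuffered per-edge sum
        if method == "mean":
            out /= np.maximum(counts, 1)[:, None]  # avoid 0/0 for empty buckets

    # Empty buckets are 0, matching tf.math.segment_* plus the zero padding.
    out[counts == 0] = 0.0
    return out
```

The important property is that the output length is fixed by num_nodes rather than by the values inside node_ids.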
I used a very simple model.fit call to train the model:
model.fit(dataset,
          steps_per_epoch=40000,
          callbacks=[tensorboard_callback])
and I got this error message during training:
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/hybridmodel/pytorch/combined_model.py", line 495, in <module>
    model.fit(dataset,
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean defined at (most recent call last):
<stack traces unavailable>
Detected at node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean defined at (most recent call last):
<stack traces unavailable>
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_153843[] on XLA_GPU_JIT: SegmentMean (No registered 'SegmentMean' OpKernel for XLA_GPU_JIT devices compatible with node {{node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean}}){{node functional_12_1/scatter_and_aggregate_layer_1/SegmentMean}}
It seems like tf.math.segment_* operations aren't supported on XLA_GPU_JIT. Could someone please explain why this is and how I can fix it?
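One possible direction, sketched under assumptions rather than as a confirmed fix: tf.math.segment_* returns a tensor whose length depends on the values in node_ids (max id + 1), and the layer then pads by a data-dependent amount. XLA compiles graphs with static shapes, which is a likely reason no XLA_GPU_JIT kernel is available here. The tf.math.unsorted_segment_* variants take an explicit num_segments argument, so the output length can be fixed from V_set's shape and the tf.pad step is no longer needed. The class name below is hypothetical; the sketch assumes batch size 1 (as the original [0] indexing does), node ids in [0, num_nodes), and float features. Whether every op lowers to XLA on a given TF/Keras version should be checked on the target setup.

```python
import tensorflow as tf


class UnsortedScatterAndAggregateLayer(tf.keras.layers.Layer):
    """Hypothetical variant of ScatterAndAggregateLayer with a fixed segment count.

    Assumes batch size 1, node_ids values in [0, num_nodes), and V_set of shape
    (1, num_nodes, ...), so num_nodes determines the output length.
    """

    def __init__(self, agg_method="sum", **kwargs):
        super().__init__(**kwargs)
        self.agg_method = agg_method
        self.agg_method_func = {
            "sum": tf.math.unsorted_segment_sum,
            "mean": tf.math.unsorted_segment_mean,
            # Unlike segment_max, empty segments get the dtype's lowest value, not 0.
            "max": tf.math.unsorted_segment_max,
        }[agg_method]

    def call(self, inputs):
        V_set, E_set, node_ids = inputs
        # Prefer the static node count so num_segments is a compile-time constant.
        num_nodes = V_set.shape[1]
        if num_nodes is None:
            num_nodes = tf.shape(V_set)[1]
        aggregated = self.agg_method_func(E_set[0], node_ids[0], num_segments=num_nodes)
        return aggregated[tf.newaxis, ...]  # restore the batch axis
```

The unsorted ops also do not require node_ids to be sorted. If rewriting the layer is not an option, passing jit_compile=False to model.compile should stop Keras from XLA-compiling the train step (Keras 3 defaults to jit_compile="auto"), at the cost of XLA's speed-up. The error suggests the whole train step is compiled as one XLA graph, which is presumably why the tf.device('/CPU:0') block does not help.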