Antoine23

Reputation: 79

Prefetch Dataset dtype format

I'm getting the error below. It looks like it's caused by the dtype not being the same in train_step_signature (tf.int64) and in the dataset (int32)? Is there a way to change the dtype of the PrefetchDataset from int32 to tf.int64? Thanks


ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[ 77 111 110 ...   0   0   0]
 [ 83 105 110 ...   0   0   0]
 [ 71  97 115 ...   0   0   0]
 ...
 [ 80 114 111 ...   0   0   0]
 [ 70 114  97 ...   0   0   0]
 [ 65 110 233 ...   0   0   0]], shape=(64, 605), dtype=int32),
    tf.Tensor(
[[ 68 101 115 ...   0   0   0]
 [ 76 101  32 ...   0   0   0]
 [ 76 101  32 ...   0   0   0]
 ...
 [ 68  97 110 ...   0   0   0]
 [ 85 110 101 ...   0   0   0]
 [ 68  97 110 ...   0   0   0]], shape=(64, 936), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None, None), dtype=tf.int64, name=None),
    TensorSpec(shape=(None, None), dtype=tf.int64, name=None)).
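A minimal sketch for confirming the mismatch, assuming a pipeline roughly like the asker's: `element_spec` on any `tf.data.Dataset`, including one returned by `prefetch`, reports the dtype of each component it yields. The two arrays are hypothetical stand-ins for the padded int32 token IDs shown in the error.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: int32 token IDs padded with 0, like in the error
inputs = np.array([[77, 111, 110, 0], [83, 105, 110, 0]], dtype=np.int32)
targets = np.array([[68, 101, 115], [76, 101, 32]], dtype=np.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices((inputs, targets))
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

# element_spec describes each component; both are int32 here,
# which does not match a TensorSpec(..., dtype=tf.int64) signature
for spec in dataset.element_spec:
    print(spec.dtype.name)
```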


Upvotes: 2

Views: 46

Answers (1)

AloneTogether

Reputation: 26708

Maybe try casting both elements to tf.int64 with map:

# Cast both elements of each (x, y) pair to int64 to match the train_step signature
dataset = dataset.map(lambda x, y: (tf.cast(x, dtype=tf.int64), tf.cast(y, dtype=tf.int64)))
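A self-contained sketch of the casting approach end to end, under stated assumptions: the data is hypothetical, and `train_step` is a placeholder whose body just sums its inputs instead of running a model. `map` can be called on a prefetched dataset directly, but here the cast is placed before `prefetch` so the buffered batches are already int64.

```python
import numpy as np
import tensorflow as tf

# Signature in the same form as the question: two int64 matrices of any shape
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    # Hypothetical placeholder body; a real train_step would run the model here
    return tf.reduce_sum(inp) + tf.reduce_sum(tar)

# Hypothetical int32 data, as produced by the original pipeline
inputs = np.array([[77, 111, 110, 0], [83, 105, 110, 0]], dtype=np.int32)
targets = np.array([[68, 101, 115], [76, 101, 32]], dtype=np.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices((inputs, targets))
    .batch(2)
    # Cast each (x, y) pair to int64 before prefetching
    .map(lambda x, y: (tf.cast(x, tf.int64), tf.cast(y, tf.int64)))
    .prefetch(tf.data.AUTOTUNE)
)

# The batches now satisfy the input_signature, so no ValueError is raised
for inp, tar in dataset:
    result = train_step(inp, tar)
```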

Upvotes: 1
