Reputation: 11
I'm getting an AssertionError for this function. How do I resolve it?
def one_hot_matrix(label, depth=6):
    one_hot = tf.one_hot(label, depth, axis=0)
    one_hot = tf.reshape(one_hot, (-1, 1))
    return one_hot
def one_hot_matrix_test(target):
    label = tf.constant(1)
    depth = 4
    result = target(label, depth)
    print("Test 1:", result)
    assert result.shape[0] == depth, "Use the parameter depth"
    assert np.allclose(result, [0., 1., 0., 0.]), "Wrong output. Use tf.one_hot"
    label_2 = [2]
    result = target(label_2, depth)
    print("Test 2:", result)
    assert result.shape[0] == depth, "Use the parameter depth"
    assert np.allclose(result, [0., 0., 1., 0.]), "Wrong output. Use tf.reshape as instructed"
    print("\033[92mAll test passed")
Upvotes: 1
Views: 828
Reputation: 4073
Pay attention to the shape argument: it should have a single dimension, with no value after the comma.
one_hot = tf.reshape(tf.one_hot(label, depth, axis=0), shape=[depth,])
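For context, here is a minimal sketch of the whole function with that reshape applied. It assumes TensorFlow 2.x in eager mode, and the two calls at the bottom simply mirror the labels used in the question's test. Note that reshaping to [depth,] only works when a single label is passed in.

```python
import numpy as np
import tensorflow as tf

def one_hot_matrix(label, depth=6):
    # axis=0 puts the one-hot dimension first: shape (depth,) for a scalar
    # label, (depth, 1) for a one-element list such as [2].
    one_hot = tf.one_hot(label, depth, axis=0)
    # Flatten to a 1-D tensor of length depth (valid for a single label only).
    return tf.reshape(one_hot, shape=[depth,])

# Same inputs as the two cases in the question's test
result_1 = one_hot_matrix(tf.constant(1), 4)
result_2 = one_hot_matrix([2], 4)
print(result_1.shape, result_2.shape)
```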
Upvotes: 0
Reputation: 1
This is the answer (use depth rather than a hard-coded 4, so it works for any depth):
one_hot = tf.reshape(tf.one_hot(label, depth, axis=0), [depth,])
Upvotes: 0
Reputation: 21
Most likely you got the assertion error because of the shape of the result.
To fix that, make sure you use
one_hot = tf.reshape(one_hot, (depth,))
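To see why the shape matters, here is a small sketch that uses plain NumPy arrays as stand-ins for the tensors (an assumption made for illustration; the names flat, column and expected are made up here). A (depth, 1) column still passes the shape[0] check, but np.allclose broadcasts it against the flat expected list, so the comparison fails.

```python
import numpy as np

depth = 4
expected = np.array([0., 1., 0., 0.])

# NumPy stand-in for tf.one_hot(1, depth, axis=0): a flat vector, shape (4,)
flat = np.eye(depth)[1]

# What reshaping with (-1, 1) produces: a column, shape (4, 1)
column = flat.reshape(-1, 1)

# shape[0] is 4 in both cases, so the first assert passes either way
print(flat.shape, column.shape)

# (4, 1) against (4,) broadcasts to a 4x4 grid, comparing every element
# of the column with every element of expected, so it is not all-close
column_ok = np.allclose(column, expected)
flat_ok = np.allclose(flat, expected)
print(column_ok, flat_ok)
```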
Upvotes: 2