Thomas Simonini

Reputation: 93

Incompatible shapes when output * actions_one_hot

I'm trying to implement a Deep Q Network that plays Doom (vizdoom).

However, I've been stuck since yesterday on a problem with one-hot encoding and its consequences. I have 3 possible actions, encoded like this:

[[True, False, False], [False, True, False], [False, False, True]] size = [Batch_size, 3]

When I one_hot encode this action array, I obtain an array of size [BatchSize, 3, 3].

As a consequence, when I want to calculate my Q-value estimation:

Q = tf.reduce_sum(tf.multiply(self.output, self.actions_one_hot), axis=1)

The tf.multiply(self.output, self.actions_one_hot) produces an error:

InvalidArgumentError: Incompatible shapes: [10,3] vs. [10,3,3] [[Node: DQNetwork/Mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DQNetwork/dense/BiasAdd, DQNetwork/one_hot)]]

I understand that these two tensors have incompatible shapes for multiplication, but I don't understand what I must do to make them compatible.

To be more clear this is the notebook with each part explained:

I'm sure that I made a really stupid mistake but I don't see it.

Thanks for your help!

Upvotes: 0

Views: 218

Answers (1)

silgon

Reputation: 7191

You have to make the shapes compatible for tf.multiply because it performs element-wise multiplication.

However, I think the real problem is how you're using one_hot. A one_hot function usually converts integer indices into one-hot vectors. For example, if you have 3 possible actions in your action space, (0, 1, 2), the one_hot function will translate them to [[1,0,0],[0,1,0],[0,0,1]]. The problem is that you are passing vectors that are already one-hot to another one_hot function, which adds an extra dimension. If you use the actions directly in the multiplication, both tensors will have the same shape.
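To illustrate where the extra dimension comes from, here is a minimal NumPy sketch. The `one_hot` helper is a hypothetical stand-in written for this example (it is not the TensorFlow function), assuming only that a one-hot function appends a new axis of size `depth` to whatever shape it receives:

```python
import numpy as np

def one_hot(indices, depth):
    """Hypothetical stand-in for a one-hot function: appends an axis of size `depth`."""
    return np.eye(depth, dtype=np.float32)[np.asarray(indices, dtype=np.int64)]

# Integer action indices, shape [batch_size]
action_ids = np.array([0, 1, 2, 1])
encoded_once = one_hot(action_ids, depth=3)

# Boolean actions that are already one-hot, shape [batch_size, 3]
actions_bool = np.array([[True, False, False],
                         [False, True, False],
                         [False, False, True],
                         [False, True, False]])
# Encoding them again treats each True/False as an index (1 or 0)
encoded_twice = one_hot(actions_bool, depth=3)

print(encoded_once.shape)   # (4, 3)
print(encoded_twice.shape)  # (4, 3, 3)
```

The second shape has the same form as the [10,3,3] in the error message: a batch of one-hot vectors went in, and each element was expanded again.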

Long story short, you're using the one_hot function twice. If you already have a vector like [True, False, False], you already have a one-hot encoding.
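A minimal sketch of the fix, again in NumPy so that it runs on its own. The `output` values are made up for illustration, and the assumption is that the actions arrive as a boolean array of shape [batch_size, 3], as in the question:

```python
import numpy as np

# Made-up network outputs (Q-values per action), shape [batch_size, 3]
output = np.array([[0.5, 1.0, -0.2],
                   [0.1, 0.3, 0.9]], dtype=np.float32)

# Actions already one-hot encoded as booleans, shape [batch_size, 3]
actions = np.array([[True, False, False],
                    [False, False, True]])

# Cast to float instead of one-hot encoding a second time; shapes now match
actions_one_hot = actions.astype(np.float32)

# The element-wise product zeroes out the actions that were not taken;
# summing over axis 1 leaves one Q-value per sample
q = np.sum(output * actions_one_hot, axis=1)
print(q)  # [0.5 0.9]
```

In the TensorFlow graph from the question, the corresponding change would presumably be to replace the one_hot call with a cast such as tf.cast(actions, tf.float32), so that self.actions_one_hot keeps the shape [batch_size, 3].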

Upvotes: 2
