Reputation: 561
I am trying to implement an LSTM Model as a model_fn input to an Estimator. My X is only a .txt with a time series of prices. Before going into my first hidden layer, I try to define the lstm cell as:
def lstm_cell():
    return tf.contrib.rnn.BasicLSTMCell(
        size, forget_bias=0.0, state_is_tuple=True)
attn_cell = lstm_cell
if is_training and keep_prob < 1:
    def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=keep_prob)
cell = tf.contrib.rnn.MultiRNNCell(
    [attn_cell() for _ in range(num_layers)], state_is_tuple=True)
initial_state = cell.zero_state(batch_size, data_type())
inputs = tf.unstack(X, num=num_steps, axis=0)
outputs = []
outputs, state = tf.nn.dynamic_rnn(cell, inputs,
                                   initial_state=initial_state)
This then is supposed to go into:
first_hidden_layer = tf.contrib.layers.relu(outputs, 1000)
Unfortunately, it throws an error: "ValueError: Dimension must be 1 but is 3 for 'transpose' (op: 'Transpose') with input shapes: [1], [3]." I gather that my problem is the "inputs" tensor. According to its description, the inputs argument is supposed to be a tensor of shape [batch_size, max_time, ...], but I have no idea how to translate this into the above structure since, through the estimator, only input values X and target values y are fed to the system. So my question is how to create a tensor that can serve as the inputs argument to the dynamic_rnn function.
Thanks a lot.
Upvotes: 1
Views: 1402
Reputation: 1206
I believe you don't need the line:
inputs = tf.unstack(X, num=num_steps, axis=0)
You can supply X directly to dynamic_rnn, since dynamic_rnn doesn't take a list of tensors; it takes one tensor where the time axis is dimension 0 (if time_major == True) or dimension 1 (if time_major == False).
Actually, it seems that X has only 2 dimensions, since inputs is a list of 1-dimensional tensors (as indicated by the error message). So you should replace the unstack line with:
inputs = tf.expand_dims(X, axis=2)
This will add a 3rd dimension of size 1, which dynamic_rnn needs.
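To make the shape reasoning concrete, here is a small sketch that uses NumPy as a stand-in for the TensorFlow ops, so it runs without a TensorFlow 1.x install. It assumes X arrives as [batch_size, num_steps] with one price per step; the sizes and the names unstacked and inputs_time_major are made up for illustration. np.expand_dims plays the role of tf.expand_dims, and list(X) mimics what unstacking along axis 0 produces.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, num_steps = 4, 10

# Assumed layout of X: one price per time step, no feature axis.
X = np.arange(batch_size * num_steps, dtype=np.float32).reshape(batch_size, num_steps)

# Unstacking along axis 0 yields a list of rank-1 arrays,
# which matches the [1] input shape in the error message.
unstacked = list(X)

# Adding a trailing feature axis of size 1 gives the rank-3 layout
# [batch_size, max_time, features] used when time_major is False.
inputs = np.expand_dims(X, axis=2)

# The time-major layout [max_time, batch_size, features] swaps the first two axes.
inputs_time_major = np.transpose(inputs, (1, 0, 2))

print(len(unstacked), unstacked[0].shape)
print(inputs.shape)
print(inputs_time_major.shape)
```

If X were actually laid out as [num_steps, batch_size], the same expand_dims call would already produce the time-major layout, so it is worth checking the real shape of X before choosing time_major.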
Upvotes: 1