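TensorFlow: switching from BasicRNNCell to LSTMCell

I have built an RNN with BasicRNNCell and now I want to use LSTMCell instead, but the switch does not seem to be straightforward. What should I change?

First, I define all the placeholders and variables: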

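import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.contrib)

# Inputs: [batch, time, embedding]; integer labels: [batch, time]
X_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length, embedding_size])
Y_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
# Initial RNN state, fed by the caller
init_state = tf.placeholder(tf.float32, [batch_size, state_size])

# Weights and biases mapping the state to num_classes outputs
W = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b = tf.Variable(np.zeros((batch_size, num_classes)), dtype=tf.float32)
W2 = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b2 = tf.Variable(np.zeros((batch_size, num_classes)), dtype=tf.float32)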

Then I split the labels:

# One [batch_size] label tensor per time step
labels_series = tf.unstack(Y_placeholder, axis=1)
inputs_series = X_placeholder
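For reference, this is what unstacking the labels along axis 1 produces, sketched with NumPy as a stand-in for tf.unstack (the sizes are made up for illustration): a Python list with one [batch_size] array per time step. Taking slices along the time axis already yields time-major pieces, so no transpose is needed first.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, truncated_backprop_length = 4, 3

# Integer labels laid out like Y_placeholder: [batch_size, truncated_backprop_length]
Y = np.arange(batch_size * truncated_backprop_length).reshape(batch_size, truncated_backprop_length)

# NumPy analogue of tf.unstack(Y, axis=1): one [batch_size] array per time step
labels_series = [Y[:, t] for t in range(Y.shape[1])]

# Equivalent formulation: move the time axis to the front and iterate over it
labels_series_alt = list(np.moveaxis(Y, 1, 0))

print(len(labels_series), labels_series[0].shape)  # 3 (4,)
```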

Then I define my RNN:

cell = tf.contrib.rnn.BasicLSTMCell(state_size, state_is_tuple=False)
states_series, current_state = tf.nn.dynamic_rnn(cell, inputs_series, initial_state=init_state)
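Conceptually, dynamic_rnn with batch-major inputs walks the time axis, calls the cell once per step with the previous state, and stacks the per-step outputs back along axis 1. The NumPy sketch below mimics that loop with a vanilla tanh cell in the spirit of BasicRNNCell; it is not TensorFlow's implementation, and the helper names (basic_rnn_step, run_rnn) and sizes are hypothetical. The thing to notice is that the state threaded between steps must have exactly the width the cell expects, which for a vanilla cell is state_size.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes for illustration only.
batch_size, time_steps, embedding_size, state_size = 2, 5, 3, 4

# Weights of a vanilla RNN cell: concat([x_t, h]) @ W + b
W = rng.standard_normal((embedding_size + state_size, state_size))
b = np.zeros(state_size)

def basic_rnn_step(x_t, h):
    """One vanilla RNN step: the output and the new state are the same array."""
    h_new = np.tanh(np.concatenate([x_t, h], axis=1) @ W + b)
    return h_new, h_new

def run_rnn(step, inputs, initial_state):
    """Loop over axis 1 (time) of batch-major inputs, like dynamic_rnn with time_major=False."""
    state = initial_state
    outputs = []
    for t in range(inputs.shape[1]):
        output, state = step(inputs[:, t, :], state)
        outputs.append(output)
    # Stack per-step outputs back into [batch, time, output_size]
    return np.stack(outputs, axis=1), state

inputs = rng.standard_normal((batch_size, time_steps, embedding_size))
init_state = np.zeros((batch_size, state_size))  # [batch_size, state_size] fits a vanilla cell
states_series, current_state = run_rnn(basic_rnn_step, inputs, init_state)
print(states_series.shape, current_state.shape)  # (2, 5, 4) (2, 4)
```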

The error I get is:

InvalidArgumentError                      Traceback (most recent call last)
/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    669           node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
--> 670           status)
    671   except errors.InvalidArgumentError as err:

/home/deepnlp2017/anaconda3/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
     65             try:
---> 66                 next(self.gen)
     67             except StopIteration:

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py in raise_exception_on_not_ok_status()
    468           compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 469           pywrap_tensorflow.TF_GetCode(status))
    470   finally:

InvalidArgumentError: Dimensions must be equal, but are 50 and 100 for 'rnn/while/basic_lstm_cell/mul' (op: 'Mul') with input shapes: [32,50], [32,100].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-19-2ac617f4dde4> in <module>()
      4 #cell = tf.contrib.rnn.BasicRNNCell(state_size)
      5 cell = tf.contrib.rnn.BasicLSTMCell(state_size, state_is_tuple = False)
----> 6 states_series, current_state = tf.nn.dynamic_rnn(cell, inputs_series, initial_state = init_state)

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py in dynamic_rnn(cell, inputs, sequence_length, initial_state, dtype, parallel_iterations, swap_memory, time_major, scope)
    543         swap_memory=swap_memory,
    544         sequence_length=sequence_length,
--> 545         dtype=dtype)
    546 
    547     # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py in _dynamic_rnn_loop(cell, inputs, initial_state, parallel_iterations, swap_memory, sequence_length, dtype)
    710       loop_vars=(time, output_ta, state),
    711       parallel_iterations=parallel_iterations,
--> 712       swap_memory=swap_memory)
    713 
    714   # Unpack final output if not using output tuples.

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name)
   2624     context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
   2625     ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
-> 2626     result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
   2627     return result
   2628 

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
   2457       self.Enter()
   2458       original_body_result, exit_vars = self._BuildLoop(
-> 2459           pred, body, original_loop_vars, loop_vars, shape_invariants)
   2460     finally:
   2461       self.Exit()

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
   2407         structure=original_loop_vars,
   2408         flat_sequence=vars_for_body_with_tensor_arrays)
-> 2409     body_result = body(*packed_vars_for_body)
   2410     if not nest.is_sequence(body_result):
   2411       body_result = [body_result]

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py in _time_step(time, output_ta_t, state)
    695           skip_conditionals=True)
    696     else:
--> 697       (output, new_state) = call_cell()
    698 
    699     # Pack state if using state tuples

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py in <lambda>()
    681 
    682     input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
--> 683     call_cell = lambda: cell(input_t, state)
    684 
    685     if sequence_length is not None:

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py in __call__(self, inputs, state, scope)
    182       i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    183 
--> 184       new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
    185                self._activation(j))
    186       new_h = self._activation(new_c) * sigmoid(o)

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
    882       if not isinstance(y, sparse_tensor.SparseTensor):
    883         y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
--> 884       return func(x, y, name=name)
    885 
    886   def binary_op_wrapper_sparse(sp_x, y):

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py in _mul_dispatch(x, y, name)
   1103   is_tensor_y = isinstance(y, ops.Tensor)
   1104   if is_tensor_y:
-> 1105     return gen_math_ops._mul(x, y, name=name)
   1106   else:
   1107     assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py in _mul(x, y, name)
   1623     A `Tensor`. Has the same type as `x`.
   1624   """
-> 1625   result = _op_def_lib.apply_op("Mul", x=x, y=y, name=name)
   1626   return result
   1627 

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    761         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    762                          input_types=input_types, attrs=attr_protos,
--> 763                          op_def=op_def)
    764         if output_structure:
    765           outputs = op.outputs

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in create_op(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device)
   2395                     original_op=self._default_original_op, op_def=op_def)
   2396     if compute_shapes:
-> 2397       set_shapes_for_outputs(ret)
   2398     self._add_op(ret)
   2399     self._record_op_seen_by_control_dependencies(ret)

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in set_shapes_for_outputs(op)
   1755       shape_func = _call_cpp_shape_fn_and_require_op
   1756 
-> 1757   shapes = shape_func(op)
   1758   if shapes is None:
   1759     raise RuntimeError(

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in call_with_requiring(op)
   1705 
   1706   def call_with_requiring(op):
-> 1707     return call_cpp_shape_fn(op, require_shape_fn=True)
   1708 
   1709   _call_cpp_shape_fn_and_require_op = call_with_requiring

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in call_cpp_shape_fn(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    608     res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
    609                                   input_tensors_as_shapes_needed,
--> 610                                   debug_python_shape_fn, require_shape_fn)
    611     if not isinstance(res, dict):
    612       # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).

/home/deepnlp2017/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    673       missing_shape_fn = True
    674     else:
--> 675       raise ValueError(err.message)
    676 
    677   if missing_shape_fn:

ValueError: Dimensions must be equal, but are 50 and 100 for 'rnn/while/basic_lstm_cell/mul' (op: 'Mul') with input shapes: [32,50], [32,100].

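Answer:

You should include the error traceback; otherwise it is hard (or impossible) to help.

I reproduced the problem and found that it comes from unpacking the state, i.e. the line c, h = state inside the cell.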
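The line c, h = state only works when the state really is a pair. With state_is_tuple=True, BasicLSTMCell expects its state as a (c, h) pair (TF 1.x uses LSTMStateTuple for this), so feeding a single [batch_size, state_size] placeholder breaks at exactly that unpacking. The NumPy analogue below is only an illustration with hypothetical sizes: unpacking a 2-D array iterates over its rows, so it fails unless the batch happens to have exactly two rows. In TF 1.x graph mode a Tensor cannot be iterated at all, so the error you see there is different, but it is raised at the same unpacking step.

```python
import numpy as np

# Hypothetical sizes for illustration.
batch_size, state_size = 32, 50

# What state_is_tuple=True expects: a pair (c, h), each [batch_size, state_size]
tuple_state = (np.zeros((batch_size, state_size)), np.zeros((batch_size, state_size)))
c, h = tuple_state  # works: exactly two elements

# A single [batch_size, state_size] array in place of the pair
single_state = np.zeros((batch_size, state_size))
unpack_error = None
try:
    c, h = single_state  # iterates over the 32 rows, not over (c, h)
except ValueError as err:
    unpack_error = err  # "too many values to unpack"

print(type(unpack_error).__name__)  # ValueError
```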

Try setting state_is_tuple to False, i.e.

cell = tf.contrib.rnn.BasicLSTMCell(state_size, state_is_tuple=False)
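With state_is_tuple=False, BasicLSTMCell instead expects c and h concatenated into one tensor and splits it in half along axis 1, so the initial state has to be [batch_size, 2 * state_size] wide. The NumPy sketch below follows the same gate arithmetic as the new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j) line in the question's traceback; the helper name lstm_step_concat, the random weights and embedding_size are hypothetical. Feeding a [32, 100] state to a 100-unit cell reproduces the same 50-versus-100 mismatch, assuming state_size was 100 in the question, which is what the [32,50] and [32,100] shapes in the error imply.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_concat(x, state, num_units, forget_bias=1.0):
    """One BasicLSTMCell-style step with state_is_tuple=False semantics (NumPy sketch)."""
    # The concatenated state is split in half: c and h
    c, h = np.split(state, 2, axis=1)
    # Size the weights from whatever input width arrives (here, x plus h)
    W = rng.standard_normal((x.shape[1] + h.shape[1], 4 * num_units)) * 0.1
    concat = np.concatenate([x, h], axis=1) @ W
    i, j, f, o = np.split(concat, 4, axis=1)  # each [batch, num_units]
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, np.concatenate([new_c, new_h], axis=1)

batch_size, embedding_size, state_size = 32, 8, 100
x = rng.standard_normal((batch_size, embedding_size))

# Correct width: the concatenated state is 2 * state_size columns
good_state = np.zeros((batch_size, 2 * state_size))
out, new_state = lstm_step_concat(x, good_state, state_size)
print(out.shape, new_state.shape)  # (32, 100) (32, 200)

# Only state_size wide: c ends up [32, 50] while f is [32, 100]
bad_state = np.zeros((batch_size, state_size))
mismatch = None
try:
    lstm_step_concat(x, bad_state, state_size)
except ValueError as err:
    mismatch = err  # NumPy's broadcasting error, analogous to TF's "Dimensions must be equal"
print(mismatch)
```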

I'm not sure why this happens. Are you loading a previously saved model? Which TensorFlow version are you using?


More on TensorFlow RNN cells:

I suggest reading the WildML article, in particular the section titled "RNN CELLS, WRAPPERS AND MULTI-LAYER RNNS".

It lists:

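  • BasicRNNCell – A vanilla RNN cell.
  • GRUCell – A Gated Recurrent Unit (GRU) cell.
  • BasicLSTMCell – An LSTM cell based on "Recurrent Neural Network Regularization". No peephole connections or cell clipping.
  • LSTMCell – A more complex LSTM cell that allows for optional peephole connections and cell clipping.
  • MultiRNNCell – A wrapper that combines multiple cells into a multi-layer cell.
  • DropoutWrapper – A wrapper that adds dropout to the input and/or output connections of a cell.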

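Given that, I suggest switching from BasicRNNCell to BasicLSTMCell. "Basic" here means "use this unless you know what you are doing". If you want to try an LSTM without digging into the details, this is the way to go. It may be as simple as swapping the class, and voilà!

If that doesn't work, please share more of your code and the full error message.

Hope this helps.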
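Swapping the cell class is often the whole change, but if you feed your own initial state (as the question does with init_state), its shape has to follow the new cell. As a rough guide for TF 1.x cells without projection layers or wrappers: BasicRNNCell and GRUCell carry a single [batch_size, num_units] state, while the LSTM cells carry both c and h, either as a pair or concatenated into 2 * num_units columns. The helper below is hypothetical (not a TensorFlow API) and simply encodes that rule; in real TF 1.x code, cell.state_size and cell.zero_state(batch_size, dtype) give the right structure without hard-coding it.

```python
def initial_state_shape(cell_type, batch_size, num_units, state_is_tuple=True):
    """Hypothetical helper: initial-state shape(s) for a few TF 1.x cell types
    (no projection layers, no wrappers)."""
    if cell_type in ("BasicRNNCell", "GRUCell"):
        # A single hidden state h
        return (batch_size, num_units)
    if cell_type in ("BasicLSTMCell", "LSTMCell"):
        if state_is_tuple:
            # (c, h) pair, each [batch_size, num_units]
            return ((batch_size, num_units), (batch_size, num_units))
        # c and h concatenated along axis 1
        return (batch_size, 2 * num_units)
    raise ValueError("unsupported cell type: " + cell_type)

print(initial_state_shape("BasicRNNCell", 32, 100))                         # (32, 100)
print(initial_state_shape("BasicLSTMCell", 32, 100, state_is_tuple=False))  # (32, 200)
```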
