使用keras创建带有自定义过滤器的卷积网络

我正在尝试使用keras创建一个卷积网络,其中包括以下代码:

from keras.layers import Input, LSTM, concatenatefrom keras.models import Modelfrom keras.utils.vis_utils import model_to_dotfrom IPython.display import display, SVGinputs = Input(shape=(None, 4))filter_unit = LSTM(1)conv = concatenate([filter_unit(inputs[..., 0:2]),                    filter_unit(inputs[..., 2:4])])model = Model(inputs=inputs, outputs=conv)SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))

我尝试沿特征维度切片输入张量,以便将(人为缩小的)输入分成两部分,分别用于两个过滤单元。在这个例子中,过滤器是一个单一的LSTM单元。我希望能够用任意的模型替换LSTM单元。

然而,在model = ...这一行上出现了错误:

---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-6-a9f7f2ffbe17> in <module>()      9 conv = concatenate([filter_unit(inputs[..., 0:2]),     10                     filter_unit(inputs[..., 2:4])])---> 11 model = Model(inputs=inputs, outputs=conv)     12 SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))~/.local/opt/anaconda3/envs/trafficprediction/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)     86                 warnings.warn('Update your `' + object_name +     87                               '` call to the Keras 2 API: ' + signature, stacklevel=2)---> 88             return func(*args, **kwargs)     89         wrapper._legacy_support_signature = inspect.getargspec(func)     90         return wrapper~/.local/opt/anaconda3/envs/trafficprediction/lib/python3.6/site-packages/keras/engine/topology.py in __init__(self, inputs, outputs, name)   1703         nodes_in_progress = set()   1704         for x in self.outputs:-> 1705             build_map_of_graph(x, finished_nodes, nodes_in_progress)   1706    1707         for node in reversed(nodes_in_decreasing_depth):~/.local/opt/anaconda3/envs/trafficprediction/lib/python3.6/site-packages/keras/engine/topology.py in build_map_of_graph(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)   1693                 tensor_index = node.tensor_indices[i]   1694                 build_map_of_graph(x, finished_nodes, nodes_in_progress,-> 1695                                    layer, node_index, tensor_index)   1696    1697             finished_nodes.add(node)~/.local/opt/anaconda3/envs/trafficprediction/lib/python3.6/site-packages/keras/engine/topology.py in build_map_of_graph(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)   1693                 tensor_index = node.tensor_indices[i]   1694                 build_map_of_graph(x, finished_nodes, nodes_in_progress,-> 1695                                    layer, node_index, tensor_index)   1696    1697             finished_nodes.add(node)~/.local/opt/anaconda3/envs/trafficprediction/lib/python3.6/site-packages/keras/engine/topology.py in build_map_of_graph(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)   1663             """   1664             if not layer or node_index is None or tensor_index is None:-> 1665                 layer, node_index, tensor_index = tensor._keras_history   1666             node = layer.inbound_nodes[node_index]   1667 AttributeError: 'Tensor' object has no attribute '_keras_history'

如果将LSTM替换为Dense,也会出现相同的问题。我不太明白这个错误信息的含义。我做错了什么?

关于相同错误有一个问题(链接如下),但我不清楚如何使用Lambda层,或者这是否是正确的解决方案。

AttributeError: ‘Tensor’ object has no attribute ‘_keras_history’


回答:

问题出在输入的切片方式上。LSTM层期望接收一个Layer对象作为输入,而您提供的是一个Tensor对象。您可以尝试添加一个lambda层(在这个例子中可能需要两个)来切片输入,以便为LSTM层提供输入。类似于:

y = Lambda(lambda x: x[:,0,:,:], output_shape=(1,) + input_shape[2:])(x)

然后这个y层将作为后续层的(切片后的)输入。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注