Incompatible dimensions error when using LSTM/GRU together with Flatten

I'd like to use a promising neural network that I found on towardsdatascience for my case study.

My data has the following shapes:

X_train: (1200, 18, 15)
y_train: (1200, 18, 1)

The network contains layers such as GRU, Flatten and Dense:

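# Imports added for completeness; X_train is the user's training array (assumed to be defined).
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (AveragePooling1D, Bidirectional, Conv1D,
                                     Dense, Dropout, Flatten, GRU)


def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5, optimizer='Adam',
               learning_rate=0.001, activation='relu', loss='mse'):
    # Note: learning_rate is accepted but never used; 'Adam' runs with its default rate.
    model = Sequential()
    model.add(Bidirectional(GRU(layer1, return_sequences=True),
                            input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(AveragePooling1D(2))
    model.add(Conv1D(layer2, 3, activation=activation, padding='same',
                     name='extractor'))
    model.add(Flatten())
    model.add(Dense(layer3, activation=activation))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1))
    model.compile(optimizer=optimizer, loss=loss)
    return model


twds_model = twds_model()
print(twds_model.summary())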
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
bidirectional_4 (Bidirection (None, 18, 64)            9216
_________________________________________________________________
average_pooling1d_1 (Average (None, 9, 64)             0
_________________________________________________________________
extractor (Conv1D)           (None, 9, 32)             6176
_________________________________________________________________
flatten_1 (Flatten)          (None, 288)               0
_________________________________________________________________
dense_3 (Dense)              (None, 16)                4624
_________________________________________________________________
dropout_4 (Dropout)          (None, 16)                0
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 17
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None
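The summary already shows where the trouble comes from: the model ends in (None, 1), while each sample in y_train carries 18 targets. As a sanity check, the shapes and parameter counts can be reproduced by hand. The sketch below traces one sample through the layers in plain Python; the helper functions are hypothetical, and the GRU formula uses reset_after=False because that is the setting that reproduces the 9,216 parameters in the summary.

```python
# Hypothetical helpers that reproduce the arithmetic behind the model summary.

def gru_params(units, input_dim, reset_after=False):
    """Weights of one GRU: 3 gates x (input kernel + recurrent kernel + bias)."""
    bias = 2 * units if reset_after else units
    return 3 * (input_dim * units + units * units + bias)


def conv1d_params(filters, kernel_size, in_channels):
    return kernel_size * in_channels * filters + filters


def dense_params(units, in_features):
    return in_features * units + units


timesteps, features = 18, 15          # X_train.shape[1:], as given in the question
layer1, layer2, layer3 = 32, 32, 16   # default arguments of twds_model()

params = {}
shape = (timesteps, features)

# Bidirectional(GRU(32, return_sequences=True)): both directions are concatenated.
params["bidirectional"] = 2 * gru_params(layer1, features)
shape = (timesteps, 2 * layer1)                     # (18, 64)

# AveragePooling1D(2) halves the time axis.
shape = (shape[0] // 2, shape[1])                   # (9, 64)

# Conv1D(32, 3, padding='same') keeps the time axis and changes the channels.
params["extractor"] = conv1d_params(layer2, 3, shape[1])
shape = (shape[0], layer2)                          # (9, 32)

# Flatten folds time and channels into one vector: the time axis is gone.
shape = (shape[0] * shape[1],)                      # (288,)

params["dense_hidden"] = dense_params(layer3, shape[0])
shape = (layer3,)                                   # (16,); Dropout keeps this shape
params["dense_out"] = dense_params(1, shape[0])
shape = (1,)                                        # one value per sample

print("output shape per sample:", shape)
print(params, "total:", sum(params.values()))
```

The totals match the summary (20,033). The step to notice is Flatten: after it the 18-step time axis no longer exists, so the final Dense(1) can only emit one number per sample.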

Unfortunately, I'm stuck in a catch-22 of errors: the input and output shapes never match. This is the error I get with the setup above:

InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
     [[{{node loss_2/dense_4_loss/sub}}]]
     [[{{node loss_2/mul}}]]

Train on 10420 samples, validate on 1697 samples
Epoch 1/8
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    878           initial_epoch=initial_epoch,
    879           steps_per_epoch=steps_per_epoch,
--> 880           validation_steps=validation_steps)
    881
    882   def evaluate(self,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
    327
    328         # Get outputs.
--> 329         batch_outs = f(ins_batch)
    330         if not isinstance(batch_outs, list):
    331           batch_outs = [batch_outs]

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3074
   3075     fetched = self._callable_fn(*array_vals,
-> 3076                                 run_metadata=self.run_metadata)
   3077     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3078     return nest.pack_sequence_as(self._outputs_structure,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session, self._handle, args, status,
-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
     [[{{node loss_2/dense_4_loss/sub}}]]
     [[{{node loss_2/mul}}]]
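The failing node (dense_4_loss/sub) is the subtraction inside the MSE loss, where predictions and targets are combined element-wise. A rough NumPy analogue, assuming the same batch of 144 and 18 timesteps, shows why these two shapes cannot be combined and which shapes would line up. NumPy broadcasting is only a stand-in for TensorFlow's shape check here, and taking the last timestep as the single target (`y_true_last`) is a hypothetical choice, not something from the original code.

```python
import numpy as np

batch, timesteps = 144, 18

y_pred = np.zeros((batch, 1))               # output of the final Dense(1)
y_true = np.zeros((batch, timesteps, 1))    # one batch of y_train


def can_subtract(a, b):
    """True if a - b broadcasts, as an element-wise loss such as MSE requires."""
    try:
        np.subtract(a, b)
        return True
    except ValueError:
        return False


print(can_subtract(y_pred, y_true))          # False: (144, 1) vs (144, 18, 1)

# Option 1: keep the model and predict one value per sample,
# e.g. (hypothetically) the target at the last timestep.
y_true_last = y_true[:, -1, :]               # shape (144, 1)
print(can_subtract(y_pred, y_true_last))     # True

# Option 2: keep the targets and let the model emit one value per timestep.
y_pred_seq = np.zeros((batch, timesteps, 1))
print(can_subtract(y_pred_seq, y_true))      # True
```

Which option fits depends on what the 18 targets per sample mean. Option 2 would require a model whose output keeps all 18 timesteps, i.e. without the Flatten layer and without the pooling that halves 18 steps to 9.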

For completeness, here is the (expected) error after reshaping y_train to (1200*18, 1):

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-47-2a6d0761b794> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train_flat, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    774         steps=steps_per_epoch,
    775         validation_split=validation_split,
--> 776         shuffle=shuffle)
    777
    778     # Prepare validation data.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle)
   2434       # Check that all arrays have the same length.
   2435       if not self._distribution_strategy:
-> 2436         training_utils.check_array_lengths(x, y, sample_weights)
   2437         if self._is_graph_network and not self.run_eagerly:
   2438           # Additional checks to avoid users mistakenly using improper loss fns.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_utils.py in check_array_lengths(inputs, targets, weights)
    454                      'the same number of samples as target arrays. '
    455                      'Found ' + str(list(set_x)[0]) + ' input samples '
--> 456                      'and ' + str(list(set_y)[0]) + ' target samples.')
    457   if len(set_w) > 1:
    458     raise ValueError('All sample_weight arrays should have '

ValueError: Input arrays should have the same number of samples as target arrays. Found 12117 input samples and 218106 target samples.
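This second error is a sample-count check rather than a shape check. Reshaping y_train to (N*18, 1) multiplies the number of target rows by 18 while X_train keeps N rows. The numbers in the tracebacks fit this exactly, with N = 12117 (the arrays actually passed to fit() are larger than the 1200 samples quoted at the top). The split in the earlier "Train on 10420 samples, validate on 1697 samples" line also fits, assuming Keras computes the split point as int(n * (1 - validation_split)). Zero-filled arrays stand in for the real data:

```python
import numpy as np

n_samples, timesteps = 12117, 18   # sample count inferred from the tracebacks
validation_split = 0.14

# Placeholder arrays with the same layout as the question's data.
X_train = np.zeros((n_samples, timesteps, 15), dtype=np.float32)
y_train = np.zeros((n_samples, timesteps, 1), dtype=np.float32)

# Flattening only the targets turns every timestep into its own "sample".
y_train_flat = y_train.reshape(-1, 1)
print(X_train.shape[0], y_train_flat.shape[0])           # 12117 218106

# How fit(validation_split=0.14) divides the samples (assumed formula).
split_at = int(n_samples * (1.0 - validation_split))
print("train:", split_at, "validate:", n_samples - split_at)   # train: 10420 validate: 1697
```

Reshaping only the targets therefore cannot help: X and y must keep the same first dimension, so either the targets are reduced to one value per sample or the model is changed to output one value per timestep.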

These are the package versions I used:

Package                Version
---------------------- --------------------
-                      nsorflow-gpu
-ensorflow-gpu         1.13.1
-rotobuf               3.11.3
-umpy                  1.18.1
absl-py                0.9.0
antlr4-python3-runtime 4.8
asn1crypto             1.3.0
astor                  0.7.1
astropy                3.2.1
astunparse             1.6.3
attrs                  19.3.0
audioread              2.1.8
autopep8               1.5.3
backcall               0.1.0
beautifulsoup4         4.9.0
bezier                 0.8.0
bkcharts               0.2
bleach                 3.1.4
blis                   0.2.4
bokeh                  1.1.0
boto3                  1.9.253
botocore               1.12.253
Bottleneck             1.3.2
cachetools             4.1.0
certifi                2020.4.5.1
cffi                   1.14.0
chardet                3.0.4
click                  6.7
cloudpickle            0.5.3
cmdstanpy              0.4.0
color                  0.1
colorama               0.4.3
colorcet               0.9.1
convertdate            2.2.1
copulas                0.2.5
cryptography           2.8
ctgan                  0.2.1
cycler                 0.10.0
cymem                  2.0.2
Cython                 0.29.17
dash                   0.26.0
dash-core-components   0.27.2
dash-html-components   0.11.0
dash-renderer          0.13.2
dask                   0.18.1
dataclasses            0.6
datashader             0.7.0
datashape              0.5.2
datawig                0.1.10
deap                   1.3.0
decorator              4.4.2
defusedxml             0.6.0
deltapy                0.1.1
dill                   0.2.9
distributed            1.22.1
docutils               0.14
entrypoints            0.3
ephem                  3.7.7.1
et-xmlfile             1.0.1
exrex                  0.10.5
Faker                  4.0.3
fastai                 1.0.60
fastprogress           0.2.2
fbprophet              0.6
fire                   0.3.1
Flask                  1.0.2
Flask-Compress         1.4.0
future                 0.17.1
gast                   0.3.3
geojson                2.4.1
geomet                 0.2.0.post2
google-auth            1.14.0
google-auth-oauthlib   0.4.1
google-pasta           0.2.0
gplearn                0.4.1
graphviz               0.13.2
grpcio                 1.29.0
h5py                   2.10.0
HeapDict               1.0.0
holidays               0.10.2
holoviews              1.12.1
html2text              2018.1.9
hyperas                0.4.1
hyperopt               0.1.2
idna                   2.6
imageio                2.5.0
imbalanced-learn       0.3.3
imblearn               0.0
importlib-metadata     1.5.0
impyute                0.0.8
ipykernel              5.1.4
ipython                7.13.0
ipython-genutils       0.2.0
ipywidgets             7.5.1
itsdangerous           0.24
jdcal                  1.4
jedi                   0.16.0
Jinja2                 2.11.1
jmespath               0.9.5
joblib                 0.13.2
jsonschema             3.2.0
jupyter                1.0.0
jupyter-client         6.1.2
jupyter-console        6.0.0
jupyter-core           4.6.3
Keras                  2.2.5
Keras-Applications     1.0.8
Keras-Preprocessing    1.1.2
keras-rectified-adam   0.17.0
kiwisolver             1.2.0
korean-lunar-calendar  0.2.1
librosa                0.7.2
llvmlite               0.32.1
lml                    0.0.1
locket                 0.2.0
LunarCalendar          0.0.9
Markdown               2.6.11
MarkupSafe             1.1.1
matplotlib             3.2.1
missingpy              0.2.0
mistune                0.8.4
mkl-fft                1.0.15
mkl-random             1.1.0
mkl-service            2.3.0
mock                   4.0.2
msgpack                0.5.6
multipledispatch       0.6.0
murmurhash             1.0.2
mxnet                  1.4.1
nb-conda               2.2.1
nb-conda-kernels       2.2.3
nbconvert              5.6.1
nbformat               5.0.4
nbstripout             0.3.7
networkx               2.1
notebook               6.0.3
numba                  0.49.1
numexpr                2.7.1
numpy                  1.19.0
oauthlib               3.1.0
olefile                0.46
opencv-python          4.2.0.34
openpyxl               2.5.5
opt-einsum             3.2.1
packaging              20.3
pandas                 1.0.3
pandasvault            0.0.3
pandocfilters          1.4.2
param                  1.9.0
parso                  0.6.2
partd                  0.3.8
patsy                  0.5.1
pbr                    5.1.3
pickleshare            0.7.5
Pillow                 7.0.0
pip                    20.0.2
plac                   0.9.6
plotly                 4.7.1
plotly-express         0.4.1
preshed                2.0.1
prometheus-client      0.7.1
prompt-toolkit         3.0.4
protobuf               3.11.3
psutil                 5.4.7
py                     1.8.0
pyasn1                 0.4.8
pyasn1-modules         0.2.8
pycodestyle            2.6.0
pycparser              2.20
pyct                   0.4.5
pyensae                1.3.839
pyexcel                0.5.8
pyexcel-io             0.5.7
Pygments               2.6.1
pykalman               0.9.5
PyMeeus                0.3.7
pymongo                3.8.0
pyOpenSSL              19.1.0
pyparsing              2.4.7
pypi                   2.1
pyquickhelper          1.9.3418
pyrsistent             0.16.0
PySocks                1.7.1
pystan                 2.19.1.1
python-dateutil        2.8.1
pytz                   2019.3
pyviz-comms            0.7.2
PyWavelets             0.5.2
pywin32                227
pywinpty               0.5.7
PyYAML                 5.3.1
pyzmq                  18.1.1
qtconsole              4.4.4
rdt                    0.2.1
RegscorePy             1.1
requests               2.23.0
requests-oauthlib      1.3.0
resampy                0.2.2
retrying               1.3.3
rsa                    4.0
s3transfer             0.2.1
scikit-image           0.15.0
scikit-learn           0.23.2
scipy                  1.4.1
sdv                    0.3.2
seaborn                0.9.0
seasonal               0.3.1
Send2Trash             1.5.0
sentinelsat            0.12.2
setuptools             46.3.0
setuptools-git         1.2
six                    1.14.0
sklearn                0.0
sortedcontainers       2.0.4
SoundFile              0.10.3.post1
soupsieve              2.0
spacy                  2.1.8
srsly                  0.1.0
statsmodels            0.9.0
stopit                 1.1.2
sugartensor            1.0.0.2
ta                     0.5.25
tb-nightly             1.14.0a20190603
tblib                  1.3.2
tensorboard            1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator   1.13.0
tensorflow-gpu         1.13.1
termcolor              1.1.0
terminado              0.8.3
testpath               0.4.4
text-unidecode         1.3
texttable              1.4.0
tf-estimator-nightly   1.14.0.dev2019060501
Theano                 1.0.4
thinc                  7.0.8
threadpoolctl          2.1.0
toml                   0.10.1
toolz                  0.10.0
torch                  1.4.0
torchvision            0.5.0
tornado                6.0.4
TPOT                   0.10.2
tqdm                   4.45.0
traitlets              4.3.3
transforms3d           0.3.1
tsaug                  0.2.1
typeguard              2.7.1
typing                 3.6.6
update-checker         0.16
urllib3                1.22
utm                    0.4.2
wasabi                 0.2.2
wcwidth                0.1.9
webencodings           0.5.1
Werkzeug               1.0.1
wheel                  0.34.2
widgetsnbextension     3.5.1
win-inet-pton          1.1.0
wincertstore           0.2
wrapt                  1.11.2
xarray                 0.10.8
xlrd                   1.1.0
yahoo-historical       0.3.2
zict                   0.1.3
zipp                   2.2.0

Many thanks in advance for any hints that help get this code running!

EDIT

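After updating TensorFlow and Keras to the latest versions, I get the error below. It persists even after completely removing and reinstalling TensorFlow, CUDA 10.1 and cuDNN 8.0.2. The error occurs both with my original code and with someone's example code.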

UnknownError:  Fail to find the dnn implementation.
     [[{{node CudnnRNN}}]]
     [[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]

Function call stack:
train_function -> train_function -> train_function

None
Epoch 1/4
---------------------------------------------------------------------------
UnknownError                              Traceback (most recent call last)
<ipython-input-1-64eb8afffe02> in <module>
     27     print(twds_model.summary())
     28
---> 29     twds_model.fit(X_train, y_train, epochs=4)

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109
    110     # Running inside `run_distribute_coordinator` already.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096                 batch_size=batch_size):
   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)
   1099               if data_handler.should_sync:
   1100                 context.async_wait()

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781
    782       new_tracing_count = self._get_tracing_count()

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    838         # Lifting succeeded, so variables are initialized and we can run the
    839         # stateless function.
--> 840         return self._stateless_fn(*args, **kwds)
    841     else:
    842       canon_args, canon_kwds =

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2827     with self._lock:
   2828       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2830
   2831   @property

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1846                            resource_variable_ops.BaseResourceVariable))],
   1847         captured_inputs=self.captured_inputs,
-> 1848         cancellation_manager=cancellation_manager)
   1849
   1850   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1922       # No tape is watching; skip to running the function.
   1923       return self._build_call_outputs(self._inference_function.call(
-> 1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(
   1926         args,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    548               inputs=args,
    549               attrs=attrs,
--> 550               ctx=ctx)
    551         else:
    552           outputs = execute.execute_with_cancellation(

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

UnknownError:  Fail to find the dnn implementation.
     [[{{node CudnnRNN}}]]
     [[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]

Function call stack:
train_function -> train_function -> train_function

The corresponding package versions:

Package                  Version
------------------------ ----------------
-                        nsorflow-gpu
-ensorflow-gpu           2.3.0
-rotobuf                 3.11.3
absl-py                  0.9.0
antlr4-python3-runtime   4.8
asn1crypto               1.3.0
astor                    0.7.1
astropy                  3.2.1
astunparse               1.6.3
attrs                    19.3.0
audioread                2.1.8
autopep8                 1.5.3
backcall                 0.1.0
beautifulsoup4           4.9.0
bezier                   0.8.0
bkcharts                 0.2
bleach                   3.1.4
blis                     0.2.4
bokeh                    1.1.0
boto3                    1.9.253
botocore                 1.12.253
Bottleneck               1.3.2
cachetools               4.1.0
certifi                  2020.4.5.1
cffi                     1.14.0
chardet                  3.0.4
click                    6.7
cloudpickle              0.5.3
cmdstanpy                0.4.0
color                    0.1
colorama                 0.4.3
colorcet                 0.9.1
convertdate              2.2.1
copulas                  0.2.5
cryptography             2.8
ctgan                    0.2.1
cycler                   0.10.0
cymem                    2.0.2
Cython                   0.29.17
dash                     0.26.0
dash-core-components     0.27.2
dash-html-components     0.11.0
dash-renderer            0.13.2
dask                     0.18.1
dataclasses              0.6
datashader               0.7.0
datashape                0.5.2
datawig                  0.1.10
deap                     1.3.0
decorator                4.4.2
defusedxml               0.6.0
deltapy                  0.1.1
dill                     0.2.9
distributed              1.22.1
docutils                 0.14
entrypoints              0.3
ephem                    3.7.7.1
et-xmlfile               1.0.1
exrex                    0.10.5
Faker                    4.0.3
fastai                   1.0.60
fastprogress             0.2.2
fbprophet                0.6
fire                     0.3.1
Flask                    1.0.2
Flask-Compress           1.4.0
future                   0.17.1
gast                     0.3.3
geojson                  2.4.1
geomet                   0.2.0.post2
google-auth              1.14.0
google-auth-oauthlib     0.4.1
google-pasta             0.2.0
gplearn                  0.4.1
graphviz                 0.13.2
grpcio                   1.29.0
h5py                     2.10.0
HeapDict                 1.0.0
holidays                 0.10.2
holoviews                1.12.1
html2text                2018.1.9
hyperas                  0.4.1
hyperopt                 0.1.2
idna                     2.6
imageio                  2.5.0
imbalanced-learn         0.3.3
imblearn                 0.0
importlib-metadata       1.5.0
impyute                  0.0.8
ipykernel                5.1.4
ipython                  7.13.0
ipython-genutils         0.2.0
ipywidgets               7.5.1
itsdangerous             0.24
jdcal                    1.4
jedi                     0.16.0
Jinja2                   2.11.1
jmespath                 0.9.5
joblib                   0.13.2
jsonschema               3.2.0
jupyter                  1.0.0
jupyter-client           6.1.2
jupyter-console          6.0.0
jupyter-core             4.6.3
Keras                    2.4.3
Keras-Applications       1.0.8
Keras-Preprocessing      1.1.2
keras-rectified-adam     0.17.0
kiwisolver               1.2.0
korean-lunar-calendar    0.2.1
librosa                  0.7.2
llvmlite                 0.32.1
lml                      0.0.1
locket                   0.2.0
LunarCalendar            0.0.9
Markdown                 2.6.11
MarkupSafe               1.1.1
matplotlib               3.2.1
missingpy                0.2.0
mistune                  0.8.4
mkl-fft                  1.0.15
mkl-random               1.1.0
mkl-service              2.3.0
mock                     4.0.2
msgpack                  0.5.6
multipledispatch         0.6.0
murmurhash               1.0.2
mxnet                    1.4.1
nb-conda                 2.2.1
nb-conda-kernels         2.2.3
nbconvert                5.6.1
nbformat                 5.0.4
nbstripout               0.3.7
networkx                 2.1
notebook                 6.0.3
numba                    0.49.1
numexpr                  2.7.1
numpy                    1.18.5
oauthlib                 3.1.0
olefile                  0.46
opencv-python            4.2.0.34
openpyxl                 2.5.5
opt-einsum               3.2.1
packaging                20.3
pandas                   1.0.3
pandasvault              0.0.3
pandocfilters            1.4.2
param                    1.9.0
parso                    0.6.2
partd                    0.3.8
patsy                    0.5.1
pbr                      5.1.3
pickleshare              0.7.5
Pillow                   7.0.0
pip                      20.2.2
plac                     0.9.6
plotly                   4.7.1
plotly-express           0.4.1
preshed                  2.0.1
prometheus-client        0.7.1
prompt-toolkit           3.0.4
protobuf                 3.11.3
psutil                   5.4.7
py                       1.8.0
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pycodestyle              2.6.0
pycparser                2.20
pyct                     0.4.5
pyensae                  1.3.839
pyexcel                  0.5.8
pyexcel-io               0.5.7
Pygments                 2.6.1
pykalman                 0.9.5
PyMeeus                  0.3.7
pymongo                  3.8.0
pyOpenSSL                19.1.0
pyparsing                2.4.7
pypi                     2.1
pyquickhelper            1.9.3418
pyrsistent               0.16.0
PySocks                  1.7.1
pystan                   2.19.1.1
python-dateutil          2.8.1
pytz                     2019.3
pyviz-comms              0.7.2
PyWavelets               0.5.2
pywin32                  227
pywinpty                 0.5.7
PyYAML                   5.3.1
pyzmq                    18.1.1
qtconsole                4.4.4
rdt                      0.2.1
RegscorePy               1.1
requests                 2.23.0
requests-oauthlib        1.3.0
resampy                  0.2.2
retrying                 1.3.3
rsa                      4.0
s3transfer               0.2.1
scikit-image             0.15.0
scikit-learn             0.23.2
scipy                    1.4.1
sdv                      0.3.2
seaborn                  0.9.0
seasonal                 0.3.1
Send2Trash               1.5.0
sentinelsat              0.12.2
setuptools               46.3.0
setuptools-git           1.2
six                      1.14.0
sklearn                  0.0
sortedcontainers         2.0.4
SoundFile                0.10.3.post1
soupsieve                2.0
spacy                    2.1.8
srsly                    0.1.0
statsmodels              0.9.0
stopit                   1.1.2
sugartensor              1.0.0.2
ta                       0.5.25
tb-nightly               1.14.0a20190603
tblib                    1.3.2
tensorboard              2.3.0
tensorboard-plugin-wit   1.7.0
tensorflow-gpu           2.3.0
tensorflow-gpu-estimator 2.3.0
termcolor                1.1.0
terminado                0.8.3
testpath                 0.4.4
text-unidecode           1.3
texttable                1.4.0
Theano                   1.0.4
thinc                    7.0.8
threadpoolctl            2.1.0
toml                     0.10.1
toolz                    0.10.0
torch                    1.4.0
torchvision              0.5.0
tornado                  6.0.4
TPOT                     0.10.2
tqdm                     4.45.0
traitlets                4.3.3
transforms3d             0.3.1
tsaug                    0.2.1
typeguard                2.7.1
typing                   3.6.6
update-checker           0.16
urllib3                  1.22
utm                      0.4.2
wasabi                   0.2.2
wcwidth                  0.1.9
webencodings             0.5.1
Werkzeug                 1.0.1
wheel                    0.34.2
widgetsnbextension       3.5.1
win-inet-pton            1.1.0
wincertstore             0.2
wrapt                    1.11.2
xarray                   0.10.8
xlrd                     1.1.0
yahoo-historical         0.3.2
zict                     0.1.3
zipp                     2.2.0

Answer:

OK, here is what worked for me:

- TensorFlow 2.3.0
- Keras 2.4.2
- CUDA 10.1
- cuDNN 7.6.5

together with the following code snippet from this GitHub issue:

import os

# Set the environment variables before importing TensorFlow so they are in place when it initializes.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # Set to -1 if CPU should be used (CPU = -1, GPU = 0)

import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
cpus = tf.config.experimental.list_physical_devices('CPU')

if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
elif cpus:
    try:
        # No GPU visible: only report the available CPU devices
        logical_cpus = tf.config.experimental.list_logical_devices('CPU')
        print(len(cpus), "Physical CPU,", len(logical_cpus), "Logical CPU")
    except RuntimeError as e:
        print(e)
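A likely reason this combination works is the cuDNN version. TensorFlow's tested build configurations list CUDA 10.1 with cuDNN 7.6 for tensorflow-gpu 2.3.0 (and CUDA 10.0 with cuDNN 7.4 for 1.13.1), which matches the working setup above but not the cuDNN 8.0.2 install from the edit. The checker below is a hypothetical helper, and TESTED_BUILDS is a partial table typed in by hand, so it should be verified against the official table before relying on it:

```python
# Partial, hand-copied excerpt of TensorFlow's tested GPU build configurations:
# TensorFlow version -> (CUDA major.minor, cuDNN major.minor). Verify before use.
TESTED_BUILDS = {
    "1.13.1": ("10.0", "7.4"),
    "2.3.0": ("10.1", "7.6"),
}


def major_minor(version):
    """Reduce a version string to major.minor, e.g. '7.6.5' -> '7.6'."""
    return ".".join(version.split(".")[:2])


def check_gpu_stack(tf_version, cuda_version, cudnn_version):
    """Return a list of mismatches against the tested configuration (empty = OK)."""
    if tf_version not in TESTED_BUILDS:
        return [f"no tested configuration recorded for TensorFlow {tf_version}"]
    want_cuda, want_cudnn = TESTED_BUILDS[tf_version]
    problems = []
    if major_minor(cuda_version) != want_cuda:
        problems.append(f"CUDA {cuda_version}, expected {want_cuda}")
    if major_minor(cudnn_version) != want_cudnn:
        problems.append(f"cuDNN {cudnn_version}, expected {want_cudnn}")
    return problems


print(check_gpu_stack("2.3.0", "10.1", "8.0.2"))   # setup from the edit
print(check_gpu_stack("2.3.0", "10.1", "7.6.5"))   # setup from the answer
```

Only major.minor is compared, since the tested-configuration table lists cuDNN and CUDA at that granularity.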

Many thanks to someone for helping me all along. If you're curious, you can also take a quick look at my follow-up question here ;-)
