I'd like to use a promising neural network I found on Towards Data Science for my case study.
My data has the following shapes:
- X_train: (1200, 18, 15)
- y_train: (1200, 18, 1)
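Here is the network. Besides GRU, Flatten, and Dense layers, it uses a Bidirectional wrapper, average pooling, a Conv1D feature extractor, and dropout: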
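```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (AveragePooling1D, Bidirectional, Conv1D,
                                     Dense, Dropout, Flatten, GRU)


def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5, optimizer='Adam',
               learning_rate=0.001, activation='relu', loss='mse'):
    model = Sequential()
    # Bidirectional GRU keeps the time axis: (18, 15) -> (18, 2 * layer1)
    model.add(Bidirectional(GRU(layer1, return_sequences=True),
                            input_shape=(X_train.shape[1], X_train.shape[2])))
    # Halves the time axis: (18, 64) -> (9, 64)
    model.add(AveragePooling1D(2))
    model.add(Conv1D(layer2, 3, activation=activation, padding='same', name='extractor'))
    # Flatten merges time steps and channels, so everything after it is per sample
    model.add(Flatten())
    model.add(Dense(layer3, activation=activation))
    model.add(Dropout(dropout_rate))
    # One output unit: predictions have shape (batch, 1)
    model.add(Dense(1))
    # Note: learning_rate is accepted but never used; the optimizer is passed by name
    model.compile(optimizer=optimizer, loss=loss)
    return model


twds_model = twds_model()
print(twds_model.summary())
```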
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
bidirectional_4 (Bidirection (None, 18, 64)            9216
_________________________________________________________________
average_pooling1d_1 (Average (None, 9, 64)             0
_________________________________________________________________
extractor (Conv1D)           (None, 9, 32)             6176
_________________________________________________________________
flatten_1 (Flatten)          (None, 288)               0
_________________________________________________________________
dense_3 (Dense)              (None, 16)                4624
_________________________________________________________________
dropout_4 (Dropout)          (None, 16)                0
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 17
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None
```
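The summary already points at the problem: Flatten merges the time axis away, so the final Dense(1) emits one value per sample, shape (None, 1), whereas y_train holds one value per time step, (18, 1) per sample. As a minimal sketch, the shapes and parameter counts in the summary can be reproduced by hand. `gru_params` and `trace` are hypothetical helper names, and the GRU formula assumes the TF 1.x Keras default `reset_after=False`, which is what the 9216 parameters imply (TF 2.x defaults to `reset_after=True`, which adds a second bias per gate).

```python
# Hypothetical helpers that reproduce the model summary by hand.
# Assumes the TF 1.x Keras GRU default reset_after=False (the 9216 in the summary implies it).

def gru_params(units, input_dim):
    # 3 gates, each with an input kernel, a recurrent kernel and one bias vector
    return 3 * (units * (input_dim + units) + units)


def trace(timesteps=18, features=15, layer1=32, layer2=32, layer3=16, kernel=3, pool=2):
    """Return (layer, output shape without batch axis, parameter count) per layer."""
    rows = []
    # Bidirectional GRU, return_sequences=True: forward and backward outputs are concatenated
    rows.append(("bidirectional", (timesteps, 2 * layer1), 2 * gru_params(layer1, features)))
    # AveragePooling1D(pool): stride defaults to the pool size, padding to 'valid'
    t = (timesteps - pool) // pool + 1
    rows.append(("average_pooling1d", (t, 2 * layer1), 0))
    # Conv1D with padding='same' keeps the time length; kernel x in_channels x filters + bias
    rows.append(("extractor", (t, layer2), kernel * 2 * layer1 * layer2 + layer2))
    # Flatten merges time and channels: from here on there is no time axis left
    rows.append(("flatten", (t * layer2,), 0))
    rows.append(("dense", (layer3,), t * layer2 * layer3 + layer3))
    rows.append(("dropout", (layer3,), 0))
    rows.append(("dense_out", (1,), layer3 + 1))
    return rows


rows = trace()
for name, shape, params in rows:
    print(name, (None,) + shape, params)
print("Total params:", sum(params for _, _, params in rows))
```

The last row, `(None, 1)`, is what gets compared against targets of shape `(None, 18, 1)`.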
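Unfortunately, I'm stuck in a catch-22 of shape errors: the model's output shape does not match the shape of the targets. This is the error raised with the setup above: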
```
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
     [[{{node loss_2/dense_4_loss/sub}}]]
     [[{{node loss_2/mul}}]]
```
```
Train on 10420 samples, validate on 1697 samples
Epoch 1/8
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    878 initial_epoch=initial_epoch,
    879 steps_per_epoch=steps_per_epoch,
--> 880 validation_steps=validation_steps)
    881
    882 def evaluate(self,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
    327
    328 # Get outputs.
--> 329 batch_outs = f(ins_batch)
    330 if not isinstance(batch_outs, list):
    331 batch_outs = [batch_outs]

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3074
   3075 fetched = self._callable_fn(*array_vals,
-> 3076 run_metadata=self.run_metadata)
   3077 self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3078 return nest.pack_sequence_as(self._outputs_structure,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1437 ret = tf_session.TF_SessionRunCallable(
   1438 self._session._session, self._handle, args, status,
-> 1439 run_metadata_ptr)
   1440 if run_metadata:
   1441 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526 None, None,
    527 compat.as_text(c_api.TF_Message(self.status.status)),
--> 528 c_api.TF_GetCode(self.status.status))
    529 # Delete the underlying status object from memory otherwise it stays alive
    530 # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
     [[{{node loss_2/dense_4_loss/sub}}]]
     [[{{node loss_2/mul}}]]
```
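The failing node is the loss subtraction (`dense_4_loss/sub`): MSE first computes y_pred - y_true, and TensorFlow broadcasts elementwise ops with the same rules as NumPy, aligning shapes from the trailing axis. (144, 1) is treated as (1, 144, 1) against (144, 18, 1), and 144 vs. 18 cannot be reconciled. The NumPy sketch below reproduces this on zero-filled dummy arrays. Keeping only the last time step as the target is a hypothetical choice for illustration, not something the question specifies.

```python
import numpy as np

batch = 144
y_pred = np.zeros((batch, 1))      # what the final Dense(1) produces for one batch
y_true = np.zeros((batch, 18, 1))  # one batch of y_train

# Broadcasting aligns trailing axes: (1, 144, 1) vs. (144, 18, 1) -> 144 != 18
try:
    _ = y_pred - y_true
    broadcast_ok = True
except ValueError as err:
    broadcast_ok = False
    print("Incompatible:", err)

# Hypothetical target-side fix: keep one value per sample (here the last time step)
y_last = y_true[:, -1, :]
print(y_last.shape)             # (144, 1)
print((y_pred - y_last).shape)  # (144, 1)
```

The alternative is to change the model side instead, so that it emits one value per time step; which side to change depends on what the case study is meant to predict.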
For completeness, here is the error I expected after reshaping y_train to (1200*18, 1):
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-47-2a6d0761b794> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train_flat, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
    774 steps=steps_per_epoch,
    775 validation_split=validation_split,
--> 776 shuffle=shuffle)
    777
    778 # Prepare validation data.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle)
   2434 # Check that all arrays have the same length.
   2435 if not self._distribution_strategy:
-> 2436 training_utils.check_array_lengths(x, y, sample_weights)
   2437 if self._is_graph_network and not self.run_eagerly:
   2438 # Additional checks to avoid users mistakenly using improper loss fns.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_utils.py in check_array_lengths(inputs, targets, weights)
    454 'the same number of samples as target arrays. '
    455 'Found ' + str(list(set_x)[0]) + ' input samples '
--> 456 'and ' + str(list(set_y)[0]) + ' target samples.')
    457 if len(set_w) > 1:
    458 raise ValueError('All sample_weight arrays should have '

ValueError: Input arrays should have the same number of samples as target arrays. Found 12117 input samples and 218106 target samples
```
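The numbers in this error are consistent with the first log: 10420 + 1697 = 12117 samples, and 12117 × 18 = 218106 (so the logs suggest 12117 samples rather than the 1200 quoted at the top). Folding the time axis into the sample axis multiplies the number of target rows by 18, so X and y no longer pair up. Whatever layout is chosen has to keep one target row per input sample. The sketch shows two such layouts on dummy data; which one applies depends on whether one value per sample or all 18 steps should be predicted, and the Dense(18) pairing in the comment is an assumed change to the model, not part of the original code.

```python
import numpy as np

n_samples, timesteps = 12117, 18        # 10420 train + 1697 validation, from the logs
y_train = np.zeros((n_samples, timesteps, 1))

# Folding time into the sample axis: 18 target rows per input row
y_train_flat = y_train.reshape(-1, 1)
print(y_train_flat.shape)               # (218106, 1)

# Layouts that keep one target row per input sample:
y_last = y_train[:, -1, :]              # (12117, 1): one value per sample, matches Dense(1)
y_seq = y_train.reshape(n_samples, -1)  # (12117, 18): all steps, would need e.g. Dense(18)
print(y_last.shape, y_seq.shape)
```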
Versions used:
```
Package Version
---------------------- ---------------------
nsorflow-gpu-ensorflow-gpu 1.13.1
-rotobuf 3.11.3
-umpy 1.18.1
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.2.5
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.19.0
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.0.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator 1.13.0
tensorflow-gpu 1.13.1
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
tf-estimator-nightly 1.14.0.dev2019060501
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
```
Many thanks in advance for any hints that get the code running!
EDIT
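After updating TensorFlow and Keras to the latest versions, I now get the error below instead. It persists even after completely removing and reinstalling TensorFlow, CUDA 10.1, and cuDNN 8.0.2, and it occurs both with my original code and with someone's example code.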
```
UnknownError: Fail to find the dnn implementation.
     [[{{node CudnnRNN}}]]
     [[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]

Function call stack:
train_function -> train_function -> train_function
```
```
None
Epoch 1/4
---------------------------------------------------------------------------
UnknownError                              Traceback (most recent call last)
<ipython-input-1-64eb8afffe02> in <module>
     27 print(twds_model.summary())
     28
---> 29 twds_model.fit(X_train, y_train, epochs=4)

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106 def _method_wrapper(self, *args, **kwargs):
    107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
    109
    110 # Running inside `run_distribute_coordinator` already.

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096 batch_size=batch_size):
   1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
   1099 if data_handler.should_sync:
   1100 context.async_wait()

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778 else:
    779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
    781
    782 new_tracing_count = self._get_tracing_count()

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    838 # Lifting succeeded, so variables are initialized and we can run the
    839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
    841 else:
    842 canon_args, canon_kwds = \

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2827 with self._lock:
   2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
   2830
   2831 @property

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1846 resource_variable_ops.BaseResourceVariable))],
   1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
   1849
   1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1922 # No tape is watching; skip to running the function.
   1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
   1925 forward_backward = self._select_forward_and_backward_functions(
   1926 args,

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    548 inputs=args,
    549 attrs=attrs,
--> 550 ctx=ctx)
    551 else:
    552 outputs = execute.execute_with_cancellation(

~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58 ctx.ensure_initialized()
     59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
     61 except core._NotOkStatusException as e:
     62 if name is not None:

UnknownError: Fail to find the dnn implementation.
     [[{{node CudnnRNN}}]]
     [[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]

Function call stack:
train_function -> train_function -> train_function
```
The corresponding package versions:
```
Package Version
------------------------ ----------------
nsorflow-gpu-ensorflow-gpu 2.3.0
-rotobuf 3.11.3
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.4.3
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.18.5
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.2.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 2.3.0
tensorboard-plugin-wit 1.7.0
tensorflow-gpu 2.3.0
tensorflow-gpu-estimator 2.3.0
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
```
Answer:
OK, here is what worked for me:
- TensorFlow 2.3.0
- Keras 2.4.2
- CUDA 10.1
- cuDNN 7.6.5
plus the following code snippet from this GitHub issue:
```python
import os

# Set the environment variables before TensorFlow is imported so they reliably take effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # Set to -1 if CPU should be used CPU = -1 , GPU = 0

import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
cpus = tf.config.experimental.list_physical_devices('CPU')

if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
elif cpus:
    try:
        # No GPU visible: just report the CPU devices
        logical_cpus = tf.config.experimental.list_logical_devices('CPU')
        print(len(cpus), "Physical CPU,", len(logical_cpus), "Logical CPU")
    except RuntimeError as e:
        print(e)
```
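Two things in this answer plausibly matter. First, the versions: TensorFlow's tested build configurations list tensorflow_gpu-2.3.0 with CUDA 10.1 and cuDNN 7.6, which matches the working setup, while the cuDNN 8.0.2 from the EDIT does not (cuDNN 8.0 is paired with CUDA 11.0 starting with TF 2.4.0). Second, "Fail to find the dnn implementation" is also commonly reported when cuDNN cannot allocate GPU memory at startup, which `set_memory_growth` works around. The sketch below covers only the first point: `matches_tested_config` is a hypothetical helper that compares a version triple against a small hand-copied excerpt of that table, and the excerpt itself is an assumption to double-check against the official page.

```python
# Hand-copied excerpt (assumption: verify against TensorFlow's official
# "tested build configurations" table) mapping a TF release to the CUDA and
# cuDNN major.minor versions it was built and tested with.
TESTED_GPU_CONFIGS = {
    "2.3": {"cuda": "10.1", "cudnn": "7.6"},
    "2.4": {"cuda": "11.0", "cudnn": "8.0"},
}


def major_minor(version: str) -> str:
    """Reduce '7.6.5' to '7.6' so patch releases compare equal."""
    return ".".join(version.split(".")[:2])


def matches_tested_config(tf_version: str, cuda_version: str, cudnn_version: str):
    """Hypothetical helper: True/False for a TF release in the excerpt, else None."""
    expected = TESTED_GPU_CONFIGS.get(major_minor(tf_version))
    if expected is None:
        return None  # not in the excerpt: look it up in the official table
    return (major_minor(cuda_version) == expected["cuda"]
            and major_minor(cudnn_version) == expected["cudnn"])


print(matches_tested_config("2.3.0", "10.1", "8.0.2"))  # False: setup from the EDIT
print(matches_tested_config("2.3.0", "10.1", "7.6.5"))  # True: setup from the answer
```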
Many thanks to someone for helping me all along. If you're curious, you can also take a quick look at my follow-up question here ;-).