如何使用PyTorch从本地目录导入MNIST数据集

我正在用PyTorch编写一个著名的MNIST手写数字数据库问题的代码。我从主网站下载了训练和测试数据集,包括带标签的数据集。数据集的格式是t10k-images-idx3-ubyte.gz,解压后是t10k-images-idx3-ubyte。我的数据集文件夹结构如下

MINST Data  train-images-idx3-ubyte.gz  train-labels-idx1-ubyte.gz  t10k-images-idx3-ubyte.gz  t10k-labels-idx1-ubyte.gz

现在,我编写了如下加载数据的代码

def load_dataset():    data_path = "/home/MNIST/Data/"    xy_trainPT = torchvision.datasets.ImageFolder(        root=data_path, transform=torchvision.transforms.ToTensor()    )    train_loader = torch.utils.data.DataLoader(        xy_trainPT, batch_size=64, num_workers=0, shuffle=True    )    return train_loader

我的代码显示支持的扩展名有:.jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

如何解决这个问题?我还想检查我的图像是否已经从数据集中加载(只需显示包含前5张图像的图形)?


回答:

阅读此文通过Python从.idx3-ubyte文件或GZIP中提取图像

更新

你可以使用以下格式导入数据

xy_trainPT = torchvision.datasets.MNIST(    root="~/Handwritten_Deep_L/",    train=True,    download=True,    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),)

现在,关于download=True,你的代码首先会检查根目录(你提供的路径)是否包含任何数据集。

如果没有,数据集将从网络下载。

如果,此路径已经包含一个数据集,那么你的代码将使用现有数据集运行,并且不会从互联网下载。

你可以测试一下,先提供一个不包含任何数据集的路径(数据将从网络下载),然后提供另一个已经包含数据集的路径,数据将不会被下载。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注