私は31のクラス(Officeデータセット)を持つ画像分類器に取り組んでいます。クラスごとに1つのフォルダーがあります。 PyTorchを使用して記述されたpythonスクリプトで、datasets.ImageFolder
を使用してデータセットをロードし、各画像にラベルを割り当ててトレーニングします。データをロードするためのコードスニペットは次のとおりです。
from torchvision import datasets, transforms
import torch
def load_training(root_path, dir, batch_size, kwargs):
transform = transforms.Compose(
[transforms.Resize([256, 256]),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
data = datasets.ImageFolder(root=root_path + dir, transform=transform)
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
return train_loader
コードは各フォルダーを取得し、そのフォルダー内のすべての画像に同じラベルを割り当てます。どのラベルがどの画像/画像フォルダーに割り当てられているかを見つける方法はありますか?
クラスImageFolderには属性class_to_idx
があり、これはクラスの名前をインデックス(ラベル)にマッピングする辞書です。したがって、data.classes
を使用してクラスにアクセスし、各クラスについてdata.class_to_idx
を使用してラベルを取得できます。
参照: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py