web-dev-qa-db-ja.com

Pytorch:画像ラベル

私は31のクラス(Officeデータセット)を持つ画像分類器に取り組んでいます。クラスごとに1つのフォルダーがあります。 PyTorchを使用して記述されたpythonスクリプトで、datasets.ImageFolderを使用してデータセットをロードし、各画像にラベルを割り当ててトレーニングします。データをロードするためのコードスニペットは次のとおりです。

from torchvision import datasets, transforms
import torch

def load_training(root_path, dir, batch_size, kwargs):
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    return train_loader

コードは各フォルダーを取得し、そのフォルダー内のすべての画像に同じラベルを割り当てます。どのラベルがどの画像/画像フォルダーに割り当てられているかを見つける方法はありますか?

6
tahsin314

クラスImageFolderには属性class_to_idxがあり、これはクラスの名前をインデックス(ラベル)にマッピングする辞書です。したがって、data.classesを使用してクラスにアクセスし、各クラスについてdata.class_to_idxを使用してラベルを取得できます。

参照: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py

7
Jan