web-dev-qa-db-ja.com

データローダーからサンプルのファイル名を取得するにはどうすればよいですか?

トレーニングしたたたみ込みニューラルネットワークのデータテストの結果をファイルに書き込む必要があります。データには音声データの収集が含まれます。ファイル形式は「ファイル名、予測」である必要がありますが、ファイル名の抽出に苦労しています。私はこのようにデータをロードします:

_import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
_

そして私は次のようにファイルに書き込もうとしています:

_f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()
_

os.listdir(TESTH_DATA_PATH + "/all")[i]の問題は、_test_loader_のロードされたファイルの順序と同期されないことです。私に何ができる?

5
Almog Levi

一般的に、 DataLoader は、内部にあるデータセットからバッチを提供するためのものです。

単一/マルチラベル分類の問題の場合に言及したAS @Barrielの場合、DataLoaderには画像ファイル名がなく、画像を表すテンソルとクラス/ラベルのみが含まれます。

ただし、オブジェクトをロードするときのDataLoaderコンストラクターは、データフレームを含め、小さなもの(データセットと共にターゲット/ラベルとファイル名をパックすることができます)をとることができます

このようにして、DataLoaderは何とかして必要なものを取得する場合があります。

1
prosti

まあ、それはあなたのDatasetの実装方法に依存します。たとえば、torchvision.datasets.MNIST(...)の場合、単一のサンプルのファイル名などがないため、ファイル名を取得できません(MNISTサンプルは 別の方法でロードされます )。

Dataset実装を示していないので、torchvision.datasets.ImageFolder(...)(または任意の torchvision.datasets.DatasetFolder(...) )を使用してこれを行う方法を説明します。

_f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()
_

__getitem__(self, index) 、具体的には here の間にファイルのパスが取得されていることがわかります。

独自のDatasetを実装した場合(およびshuffleと_batch_size > 1_をサポートしたい場合)、__getitem__(...)で_sample_fname_を返します呼び出して、次のようにします。

_for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]
_

これにより、shuffleを気にする必要がなくなります。また、_batch_size_が1より大きい場合は、ループの内容をより一般的なものに変更する必要があります。例:

_f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in Zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()
_
1
Berriel