You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
3.3 KiB
78 lines
3.3 KiB
from extra.datasets.kits19 import iterate, preprocess
|
|
from examples.mlperf.dataloader import batch_load_unet3d
|
|
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
|
|
from tinygrad.helpers import temp
|
|
from pathlib import Path
|
|
|
|
import nibabel as nib
|
|
import numpy as np
|
|
import os
|
|
import random
|
|
import tempfile
|
|
import unittest
|
|
|
|
class ExternalTestDatasets(unittest.TestCase):
  """Compare the tinygrad KiTS19 data loaders against the reference `PytTrain`/`PytVal` datasets.

  Each test writes synthetic NIfTI volumes into the system temp directory, runs
  them through `preprocess`, saves the results as `.npy` files, and then checks
  that the tinygrad loader and the reference dataset yield equal arrays.

  NOTE(review): nothing written to the temp directory is removed afterwards, so
  files from earlier runs (e.g. with a larger `num_samples`) are still picked up
  by the globs below -- confirm whether cleanup is wanted.
  """

  def _set_seed(self):
    """Seed NumPy's global RNG and the stdlib `random` module with 42."""
    np.random.seed(42)
    random.seed(42)

  def _create_samples(self, val, num_samples=2):
    """Create `num_samples` synthetic cases and their preprocessed `.npy` files.

    Args:
      val: when truthy the preprocessed files go to `<tmp>/val`, otherwise `<tmp>/train`.
      num_samples: number of `case_000{i}` directories to generate.

    Returns:
      A tuple `(preproc_pth, image_paths, label_paths)` where `preproc_pth` is the
      directory holding the preprocessed files and the two lists contain the
      `*_x.npy` / `*_y.npy` paths found in it, each sorted by path so that index
      `k` of both lists refers to the same case.
    """
    self._set_seed()

    # Draw the image first and the label second so the RNG stream is consumed in a fixed order.
    img = np.random.rand(190, 392, 392).astype(np.float32)
    lbl = np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
    img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))

    dataset = "val" if val else "train"
    tmp_dir = Path(tempfile.gettempdir())
    preproc_pth = tmp_dir / dataset
    os.makedirs(preproc_pth, exist_ok=True)

    for i in range(num_samples):
      case_name = f"case_000{i}"
      case_dir = tmp_dir / case_name
      os.makedirs(case_dir, exist_ok=True)

      # `temp` is presumably resolving names relative to the system temp directory
      # (the directories above are created there) -- TODO confirm.
      nib.save(img, temp(f"{case_name}/imaging.nii.gz"))
      nib.save(lbl, temp(f"{case_name}/segmentation.nii.gz"))

      preproc_img, preproc_lbl = preprocess(case_dir)
      np.save(temp(f"{dataset}/{case_name}_x.npy"), preproc_img, allow_pickle=False)
      np.save(temp(f"{dataset}/{case_name}_y.npy"), preproc_lbl, allow_pickle=False)

    # Directory listing order is arbitrary; sort so image/label paths pair up by index.
    return preproc_pth, sorted(preproc_pth.glob("*_x.npy")), sorted(preproc_pth.glob("*_y.npy"))

  def _create_kits19_ref_dataloader(self, preproc_img_pths, preproc_lbl_pths, val):
    """Return an iterator over the reference dataset (`PytVal` if `val`, else `PytTrain`).

    Args:
      preproc_img_pths: paths of the preprocessed image `.npy` files.
      preproc_lbl_pths: paths of the preprocessed label `.npy` files.
      val: selects the validation dataset when truthy, the training dataset otherwise.
    """
    if val:
      dataset = PytVal(preproc_img_pths, preproc_lbl_pths)
    else:
      dataset = PytTrain(preproc_img_pths, preproc_lbl_pths, patch_size=(128, 128, 128), oversampling=0.4)

    return iter(dataset)

  def _create_kits19_tinygrad_dataloader(self, preproc_pth, val, batch_size=1, shuffle=False, seed=42, use_old_dataloader=False):
    """Return an iterator over the tinygrad KiTS19 loader.

    Args:
      preproc_pth: directory containing the preprocessed `.npy` files.
      val: forwarded to the loader as `val`.
      batch_size: forwarded as `bs` (old loader) or `batch_size` (new loader).
      shuffle: forwarded to the loader as `shuffle`.
      seed: forwarded to `batch_load_unet3d`; not used by the old loader.
      use_old_dataloader: when truthy use `iterate`, otherwise `batch_load_unet3d`.
    """
    if use_old_dataloader:
      # Sorted so the case order does not depend on directory listing order.
      case_dirs = sorted(Path(tempfile.gettempdir()).glob("case_*"))
      dataset = iterate(case_dirs, preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
    else:
      dataset = batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed)

    return iter(dataset)

  def test_kits19_training_set(self):
    """Training samples from `batch_load_unet3d` must equal those from `PytTrain`."""
    preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(False)
    ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, False)
    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, False)

    num_compared = 0
    for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
      # Re-seed before the next pair is drawn.
      # NOTE(review): `zip` has already fetched the current pair when this runs, so the
      # first pair relies on the RNG state left by `_create_samples` -- confirm intended.
      self._set_seed()

      # `[:, 0]` presumably drops a singleton axis from the tinygrad output -- TODO confirm.
      np.testing.assert_equal(tinygrad_sample[0][:, 0].numpy(), ref_sample[0])
      np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])
      num_compared += 1

    # `zip` stops silently on an empty loader; make sure something was actually checked.
    self.assertGreater(num_compared, 0, "no training samples were compared")

  def test_kits19_validation_set(self):
    """Validation samples from the old `iterate` loader must equal those from `PytVal`."""
    preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
    ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, True, use_old_dataloader=True)

    num_compared = 0
    for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
      # Only the image is indexed with `[:, 0]` here; the label is compared as returned.
      np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
      np.testing.assert_equal(tinygrad_sample[1], ref_sample[1])
      num_compared += 1

    # `zip` stops silently on an empty loader; make sure something was actually checked.
    self.assertGreater(num_compared, 0, "no validation samples were compared")
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|
|
|