You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.2 KiB
111 lines
3.2 KiB
import ast, pathlib, unittest
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from tinygrad import Tensor
|
|
from tinygrad.helpers import getenv, CI
|
|
from extra.models.efficientnet import EfficientNet
|
|
from extra.models.vit import ViT
|
|
from extra.models.resnet import ResNet50
|
|
|
|
def _load_labels():
|
|
labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
|
|
return ast.literal_eval(labels_filename.read_text())
|
|
|
|
_LABELS = _load_labels()
|
|
|
|
def preprocess(img, new=False):
|
|
# preprocess image
|
|
aspect_ratio = img.size[0] / img.size[1]
|
|
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
|
|
|
img = np.array(img)
|
|
y0, x0 =(np.asarray(img.shape)[:2] - 224) // 2
|
|
img = img[y0: y0 + 224, x0: x0 + 224]
|
|
|
|
# low level preprocess
|
|
if new:
|
|
img = img.astype(np.float32)
|
|
img -= [127.0, 127.0, 127.0]
|
|
img /= [128.0, 128.0, 128.0]
|
|
img = img[None]
|
|
else:
|
|
img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
|
|
img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
|
|
img /= 255.0
|
|
img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
|
|
img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
|
|
return img
|
|
|
|
def _infer(model: EfficientNet, img):
|
|
with Tensor.train(False):
|
|
out = model.forward(Tensor(img)).argmax(axis=-1)
|
|
return out.tolist()
|
|
|
|
chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg'))
|
|
car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg'))
|
|
|
|
class TestEfficientNet(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.model = EfficientNet(number=getenv("NUM"))
|
|
cls.model.load_from_pretrained()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
del cls.model
|
|
|
|
@unittest.skipIf(CI, "covered by test_chicken_car")
|
|
def test_chicken(self):
|
|
labels = _infer(self.model, chicken_img)
|
|
self.assertEqual(_LABELS[labels[0]], "hen")
|
|
|
|
@unittest.skipIf(CI, "covered by test_chicken_car")
|
|
def test_car(self):
|
|
labels = _infer(self.model, car_img)
|
|
self.assertEqual(_LABELS[labels[0]], "sports car, sport car")
|
|
|
|
def test_chicken_car(self):
|
|
labels = _infer(self.model, np.concat([chicken_img, car_img], axis=0))
|
|
self.assertEqual(_LABELS[labels[0]], "hen")
|
|
self.assertEqual(_LABELS[labels[1]], "sports car, sport car")
|
|
|
|
class TestViT(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.model = ViT()
|
|
cls.model.load_from_pretrained()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
del cls.model
|
|
|
|
def test_chicken(self):
|
|
labels = _infer(self.model, chicken_img)
|
|
self.assertEqual(_LABELS[labels[0]], "cock")
|
|
|
|
def test_car(self):
|
|
labels = _infer(self.model, car_img)
|
|
self.assertEqual(_LABELS[labels[0]], "racer, race car, racing car")
|
|
|
|
class TestResNet(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.model = ResNet50()
|
|
cls.model.load_from_pretrained()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
del cls.model
|
|
|
|
def test_chicken(self):
|
|
labels = _infer(self.model, chicken_img)
|
|
self.assertEqual(_LABELS[labels[0]], "hen")
|
|
|
|
def test_car(self):
|
|
labels = _infer(self.model, car_img)
|
|
self.assertEqual(_LABELS[labels[0]], "sports car, sport car")
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|
|
|