import tinygrad.nn as nn
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch, get_child

# allow monkeypatching in layer implementations
BatchNorm = nn.BatchNorm2d
Conv2d = nn.Conv2d
Linear = nn.Linear
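
# For example, a caller could swap in an alternative implementation before
# building a model (MyCustomConv2d is a hypothetical stand-in):
#   import resnet
#   resnet.Conv2d = MyCustomConv2d  # every block below now uses it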

class BasicBlock:
  expansion = 1

  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
    assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
    self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = BatchNorm(planes)
    self.conv2 = Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    self.bn2 = BatchNorm(planes)
    # 1x1 projection shortcut, used when the residual and main paths differ in shape
    self.downsample = []
    if stride != 1 or in_planes != self.expansion*planes:
      self.downsample = [
        Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
        BatchNorm(self.expansion*planes)
      ]

  def __call__(self, x):
    out = self.bn1(self.conv1(x)).relu()
    out = self.bn2(self.conv2(out))
    out = out + x.sequential(self.downsample)
    out = out.relu()
    return out

class Bottleneck:
  # NOTE: stride_in_1x1=False, this is the v1.5 variant
  expansion = 4

  def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
    # grouped-convolution width; ResNeXt variants scale this with groups and base_width
    width = int(planes * (base_width / 64.0)) * groups
    # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
    self.conv1 = Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
    self.bn1 = BatchNorm(width)
    self.conv2 = Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
    self.bn2 = BatchNorm(width)
    self.conv3 = Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
    self.bn3 = BatchNorm(self.expansion*planes)
    self.downsample = []
    if stride != 1 or in_planes != self.expansion*planes:
      self.downsample = [
        Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
        BatchNorm(self.expansion*planes)
      ]

  def __call__(self, x):
    out = self.bn1(self.conv1(x)).relu()
    out = self.bn2(self.conv2(out)).relu()
    out = self.bn3(self.conv3(out))
    out = out + x.sequential(self.downsample)
    out = out.relu()
    return out
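
# Width example: ResNeXt50_32X4D below uses groups=32, base_width=4, so in the
# first stage (planes=64) the bottleneck width is int(64 * 4/64) * 32 = 128
# channels, split across 32 groups in conv2.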

class ResNet:
  def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
    self.num = num
    self.block = {
      18: BasicBlock,
      34: BasicBlock,
      50: Bottleneck,
      101: Bottleneck,
      152: Bottleneck
    }[num]

    self.num_blocks = {
      18: [2,2,2,2],
      34: [3,4,6,3],
      50: [3,4,6,3],
      101: [3,4,23,3],
      152: [3,8,36,3]
    }[num]

    self.in_planes = 64
    self.groups = groups
    self.base_width = width_per_group
    # stem: 7x7 stride-2 convolution, followed by a stride-2 max pool in forward
    self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
    self.bn1 = BatchNorm(64)
    self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
    self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
    self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
    self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
    # with num_classes=None the model is a feature extractor with no classifier head
    self.fc = Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None

  def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
    # only the first block of a stage downsamples; the rest run at stride 1
    strides = [stride] + [1] * (num_blocks-1)
    layers = []
    for stride in strides:
      if block == Bottleneck:
        layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
      else:
        layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
      self.in_planes = planes * block.expansion
    return layers
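
  # For instance, ResNet-50's layer3 is _make_layer(Bottleneck, 256, 6, stride=2, ...):
  # the first block downsamples with stride 2 and projects 512 -> 1024 channels via
  # the shortcut; the remaining five blocks run at stride 1 with in_planes=1024.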

  def forward(self, x):
    # without a classifier head, collect and return the four stage outputs
    is_feature_only = self.fc is None
    if is_feature_only: features = []
    out = self.bn1(self.conv1(x)).relu()
    out = out.pad([1,1,1,1]).max_pool2d((3,3), 2)
    out = out.sequential(self.layer1)
    if is_feature_only: features.append(out)
    out = out.sequential(self.layer2)
    if is_feature_only: features.append(out)
    out = out.sequential(self.layer3)
    if is_feature_only: features.append(out)
    out = out.sequential(self.layer4)
    if is_feature_only: features.append(out)
    if not is_feature_only:
      # global average pool, then classify in float32
      out = out.mean([2,3])
      out = self.fc(out.cast(dtypes.float32))
      return out
    return features

  def __call__(self, x:Tensor) -> Tensor:
    return self.forward(x)
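
  # Backbone usage sketch (shapes assume a ResNet-50 and a 224x224 input):
  #   backbone = ResNet(50, num_classes=None)
  #   c2, c3, c4, c5 = backbone(Tensor.rand(1, 3, 224, 224))
  #   # c2: (1,256,56,56)  c3: (1,512,28,28)  c4: (1,1024,14,14)  c5: (1,2048,7,7)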

  def load_from_pretrained(self):
    # torchvision ImageNet checkpoints, keyed by (depth, groups, width_per_group)
    model_urls = {
      (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
      (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
      (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
      (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
      (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
      (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }

    self.url = model_urls[(self.num, self.groups, self.base_width)]
    for k, dat in torch_load(fetch(self.url)).items():
      try:
        obj: Tensor = get_child(self, k)
      except AttributeError as e:
        # a feature-only model has no fc layer, so fc weights are simply dropped
        if 'fc.' in k and self.fc is None:
          continue
        raise e

      if 'fc.' in k and obj.shape != dat.shape:
        print("skipping fully connected layer")
        continue # Skip FC if transfer learning

      if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
      obj.assign(dat.to(obj.device).reshape(obj.shape))
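
  # Transfer-learning sketch: with a different head size, the pretrained fc
  # weights are skipped and the new head stays randomly initialised
  # (the 10-class count here is illustrative):
  #   model = ResNet(50, num_classes=10)
  #   model.load_from_pretrained()  # prints "skipping fully connected layer"
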
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
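
# Note: each constructor's (depth, groups, width_per_group) triple matches a key
# in model_urls above, e.g. ResNeXt50_32X4D() resolves to the resnext50_32x4d weights.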

if __name__ == "__main__":
  model = ResNet18()
  model.load_from_pretrained()
  from tinygrad import Context, GlobalCounters, TinyJit
  jmodel = TinyJit(model)
  # warm up: the first jitted calls trace and compile the kernels, so reset the
  # counters afterwards to measure steady-state runs only
  jmodel(Tensor.rand(1, 3, 224, 224)).realize()
  GlobalCounters.reset()
  jmodel(Tensor.rand(1, 3, 224, 224)).realize()
  for i in range(10): jmodel(Tensor.rand(1, 3, 224, 224))