openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

41 lines
3.0 KiB

import torch
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
inverse_type_map = {v:k for k,v in type_map.items()}
def as_strided(x, arg):
if any(i < 0 for i in arg[1]):
return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0])
return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2])
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
# TODO: torch.tensor should work here
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)), BinaryOps.SUB: lambda x,y: torch.logical_xor(x, y) if y.dtype is torch.bool else torch.sub(x, y),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
MovementOps.AS_STRIDED: as_strided
}}
class RawTorchBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf if buf is not None else torch.empty([size], dtype=inverse_type_map[dtype]))
@classmethod
def fromCPU(cls, x):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))