openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

124 lines
6.9 KiB

import os, json, pathlib, zipfile, pickle
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.view import strides_for_shape
from tinygrad.ops import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
t, json_len, metadata = safe_load_metadata(fn)
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
headers, offset = {}, 0
if metadata: headers['__metadata__'] = metadata
for k,v in tensors.items():
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
offset += v.nbytes()
j = json.dumps(headers, separators=(',', ':'))
j += "\x20"*((8-len(j)%8)%8)
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict, strict=True, verbose=True):
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
model_state_dict = get_state_dict(model)
if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
if k not in state_dict and not strict:
if DEBUG >= 1: print(f"WARNING: not loading {k}")
continue
v.assign(state_dict[k].to(v.device)).realize()
# torch support!
def torch_load(fn:str):
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
offsets: Dict[str, int] = {}
lens: Dict[str, int] = {}
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
# convert bfloat16 -> float16 using LLVM for Llama 2
# upstream LLaMA also does this conversion:
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
if storage[1] == dtypes.bfloat16:
ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
#ret = ret.to("LLVM").half().to(Device.DEFAULT)
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
return ret.reshape(size)
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass
class TorchPickle(pickle.Unpickler):
def find_class(self, module, name):
module_root = module.split(".")[0]
if module_root not in whitelist:
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
return Dummy
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return pid
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
myzip = zipfile.ZipFile(fn, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
if n.startswith(f'{base_name}/data/'):
with myzip.open(n) as myfile:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
else:
with open(fn, "rb") as f:
pkl = TorchPickle(f)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
f.seek(rwd)
return TorchPickle(f).load()