openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

338 lines
16 KiB

import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape
class TensorIO(io.RawIOBase, BinaryIO):
  """
  Read-only, seekable binary file object backed by a 1-D uint8 Tensor.

  Each read only materializes the requested byte range (`.data()` on a slice of the tensor), so stdlib readers such as
  `zipfile`, `tarfile`, `pickle` or an `io.BufferedReader` can consume tensor-backed bytes incrementally.
  """
  def __init__(self, t: "Tensor"):
    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._position, self._tensor = 0, t

  def readable(self) -> bool: return True

  def read(self, size: int = -1) -> bytes:
    # io.RawIOBase.read only returns None if readinto returns None, which readinto below never does
    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None")
    return buf

  def readinto(self, buffer: Any) -> int:
    """Copy up to len(buffer) bytes from the current position into `buffer`; returns the number copied (0 at EOF)."""
    end = min(len(self._tensor), self._position + len(buffer))
    if end <= self._position: return 0  # at EOF or empty buffer: no need to build and realize an empty slice
    data = self._tensor[self._position:end].data()
    buffer[:len(data)] = data
    self._position += len(data)
    return len(data)

  def seekable(self) -> bool: return True

  def seek(self, offset: int, whence: int = 0) -> int:
    """Move to `offset` relative to the start/current position/end (`whence` 0/1/2); the result is clamped to [0, len]."""
    if whence not in (io.SEEK_SET, io.SEEK_CUR, io.SEEK_END):
      # reject like io.BytesIO does, instead of an IndexError or silently treating whence=-1 as SEEK_END
      raise ValueError(f"invalid whence ({whence}, should be {io.SEEK_SET}, {io.SEEK_CUR} or {io.SEEK_END})")
    base = {io.SEEK_SET: 0, io.SEEK_CUR: self._position, io.SEEK_END: len(self._tensor)}[whence]
    self._position = min(len(self._tensor), max(0, base + offset))
    return self._position

  # required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
# safetensors dtype tag -> tinygrad dtype, plus the reverse mapping used when writing headers
safe_dtypes = {
  "BOOL": dtypes.bool,
  "I8": dtypes.int8, "U8": dtypes.uint8,
  "I16": dtypes.int16, "U16": dtypes.uint16,
  "I32": dtypes.int, "U32": dtypes.uint,
  "I64": dtypes.int64, "U64": dtypes.uint64,
  "F16": dtypes.float16, "BF16": dtypes.bfloat16, "F32": dtypes.float32, "F64": dtypes.float64,
}
inverse_safe_dtypes = {dtype: tag for tag, dtype in safe_dtypes.items()}
def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
  """
  Decorator that lets a loader taking a Tensor also be called with a filename.

  A Tensor argument is passed through unchanged; anything else is wrapped as `Tensor(pathlib.Path(fn))`.
  """
  @functools.wraps(func)
  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T:
    if isinstance(fn, Tensor): return func(fn)
    return func(Tensor(pathlib.Path(fn)))
  return wrapper
@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
  Loads a .safetensor file, returning the source tensor, data start position, and metadata.

  The first 8 bytes are read as a little-endian integer giving the length of the JSON header that follows;
  the tensor data begins right after that header.
  """
  header_len = int.from_bytes(t[0:8].data(), "little")
  data_start = 8 + header_len
  header = json.loads(t[8:data_start].data().tobytes())
  return t, data_start, header
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
  """
  Loads a .safetensor file, returning the `state_dict`.

  Each entry is a slice of the file's data section, bitcast to its declared dtype and reshaped to its declared shape.
  The `__metadata__` header entry is skipped.
  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, data_start, metadata = safe_load_metadata(fn)
  data = t[data_start:]
  state_dict: dict[str, Tensor] = {}
  for name, info in metadata.items():
    if name == "__metadata__": continue  # free-form metadata, not a tensor entry
    start, end = info['data_offsets'][0], info['data_offsets'][1]
    state_dict[name] = data[start:end].bitcast(safe_dtypes[info['dtype']]).reshape(info['shape'])
  return state_dict
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
  """
  Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  Any existing file at `fn` is deleted first. The JSON header is padded with spaces to a multiple of 8 bytes,
  and tensors are laid out back to back in `tensors` iteration order.
  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers: dict[str, Any] = {'__metadata__': metadata} if metadata else {}
  offset = 0
  for name, tensor in tensors.items():
    nbytes = tensor.nbytes()
    headers[name] = {'dtype': inverse_safe_dtypes[tensor.dtype], 'shape': list(tensor.shape), 'data_offsets':[offset, offset+nbytes]}
    offset += nbytes
  header_json = json.dumps(headers, separators=(',', ':'))
  header_json += "\x20" * (round_up(len(header_json), 8) - len(header_json))
  pathlib.Path(fn).unlink(missing_ok=True)
  out = Tensor.empty(8+len(header_json)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  # 8-byte header length, then the header itself, then each tensor written through a view from safe_load
  out[0:8].bitcast(dtypes.int64).assign([len(header_json)])
  out[8:8+len(header_json)].assign(list(header_json.encode('utf-8')))
  for name, view in safe_load(out).items(): view.assign(tensors[name])
# state dict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
  Returns a `state_dict` of the object, with optional prefix.

  Walks namedtuples, OrderedDicts, object attributes (`__dict__`), lists/tuples and dicts recursively, joining
  keys/indices with '.'; any other leaf that is not a `tensor_type` contributes nothing.
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'): obj}
  # normalize containers to a plain dict, then recurse
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  if isinstance(obj, (list, tuple)): children = [(str(i), x) for i, x in enumerate(obj)]
  elif isinstance(obj, dict): children = [(str(k), v) for k, v in obj.items()]
  else: return {}
  state_dict: dict[str, Tensor] = {}
  for name, child in children: state_dict.update(get_state_dict(child, f"{prefix}{name}.", tensor_type))
  return state_dict
def get_parameters(obj) -> list[Tensor]:
  """
  Returns every Tensor found in `obj`, in the same order as `get_state_dict(obj)`.
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  state_dict = get_state_dict(obj)
  return [*state_dict.values()]
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
  """
  Loads a `state_dict` into a model.

  Every tensor found by `get_state_dict(model)` is replaced (`Tensor.replace`) by the entry with the same key.
  For a single-device model tensor the entry is moved to its device. For a multi-device (tuple) model tensor,
  an unsharded entry is sharded along the model tensor's axis, and an already multi-device entry is used as is.

  Args:
    model: object whose tensors are collected with `get_state_dict`.
    state_dict: mapping of key -> Tensor to load from.
    strict: when False, model keys missing from `state_dict` are skipped (with a warning at DEBUG>=1).
    verbose: show a progress bar and print load time and throughput.
    consume: delete each entry from `state_dict` after it is loaded.
    realize: realize each model tensor right after it is replaced.

  Raises:
    ValueError: a tensor's shape differs between the model and `state_dict`.
    KeyError: `strict` is True and a model key is missing from `state_dict`.
  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    # compare key sets rather than sizes, so renamed keys (same count, different names) are reported too
    if DEBUG >= 1 and (unused := state_dict.keys() - model_state_dict.keys()):
      print("WARNING: unused weights in state_dict", sorted(unused))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]
@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
  ```python
  tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```
  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
  Only regular-file members are returned; each value is a slice of `t` covering that member's data.
  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    regular_files = (member for member in tar if member.type == tarfile.REGTYPE)
    return {m.name: t[m.offset_data:m.offset_data+m.size] for m in regular_files}
# torch support!
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
  ```python
  torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```
  Loads a torch .pth file, returning the `state_dict`.

  The file is treated as a zip archive, a tar archive, or (otherwise) a plain pickle stream. Tensor data is
  not copied: each returned tensor is a slice of `t` at its storage offset (permuted tensors are the exception, see below).
  NOTE: this unpickles the file; whitelisted non-torch modules resolve to real classes, so only load trusted files.
  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  offsets: dict[Union[str, int], int] = {}  # storage key -> absolute byte offset of that storage's data in `t`
  lens: dict[Union[str, int], int] = {}     # storage key -> storage size in bytes (needed by the plain-pickle path)
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    # `storage` is a persistent-id tuple: [1] is the element dtype, [2] the storage key, [4] the element count
    lens[storage[2]] = storage[4] * storage[1].itemsize
    # data location not known yet (e.g. the first pass of the plain-pickle path): only record the size
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
    return ret.reshape(size)

  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  fobj = io.BufferedReader(TensorIO(t))
  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
    # closing the ZipFile does not close `fobj` (it was passed in), and the returned tensors only reference `t`
    with zipfile.ZipFile(fobj, 'r') as myzip:
      base_name = myzip.namelist()[0].split('/', 1)[0]
      for n in myzip.namelist():
        if n.startswith(f'{base_name}/data/'):
          with myzip.open(n) as myfile:
            # NOTE(review): assumes data members are stored uncompressed, so this is the raw data offset in the file
            offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
      with myzip.open(f'{base_name}/data.pkl') as myfile:
        return TorchPickle(myfile).load()
  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
    with tarfile.open(fileobj=fobj, mode="r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      storage_dtypes: dict[Union[str, int], Any] = {}  # storage key -> dtype, so each tensor uses its own storage's dtype
      f = unwrap(tar.extractfile('storages'))
      for _ in range(TorchPickle(f).load()): # num_storages
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
        offsets[key] = storages_offset + f.tell()
        storage_dtypes[key] = storage_type
        f.seek(sz*storage_type.itemsize, 1)
      f = unwrap(tar.extractfile('tensors'))
      for _ in range(TorchPickle(f).load()): # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        # an unknown storage id has no offset either, so it yields None just like _rebuild_tensor_v2 would
        storage_dtype = storage_dtypes.get(storage_id)
        deserialized_objects[str(key)] = None if storage_dtype is None else \
          _rebuild_tensor_v2((None, storage_dtype, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
    # plain pickle stream: three header pickles (presumably magic/protocol/sys info), the state_dict pickle, the list of
    # storage keys, then each storage as 8 header bytes followed by its data. The state_dict is unpickled twice: the first
    # pass only fills `lens` (offsets are still unknown), which lets us compute every storage's offset before the real pass.
    pkl = TorchPickle(fobj)
    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
    for i in ids:
      offsets[i] = base_offset + 8
      base_offset += 8 + lens[i]
    fobj.seek(rwd)
    return TorchPickle(fobj).load()
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  """
  Converts ggml tensor data to a tinygrad tensor.
  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16 or 24), int16 (id: 17 or 25), int32 (id: 18 or 26),
  int64 (id: 27), float64 (id: 28), bfloat16 (id: 30)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)

  `t` holds the raw bytes starting at the tensor's data (uint8 when called from gguf_load), and `n` is the number of
  elements to decode. Native types are bitcast in place; quantized types are dequantized using float32 scales.
  Raises ValueError for any other type id.
  """
  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
  # native types. 16-18 follow the ggml revision linked above; current ggml numbers the plain integer/float types 24-28 and 30.
  # NOTE(review): in current ggml, 16-18 are IQ2_XXS/IQ2_XS/IQ3_XXS, which this still decodes as integers for backward compatibility.
  native = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32,
             24: dtypes.int8, 25: dtypes.int16, 26: dtypes.int32, 27: dtypes.int64, 28: dtypes.float64, 30: dtypes.bfloat16 }
  if (dtype := native.get(ggml_type)) is not None:
    return t[:dtype.itemsize * n].bitcast(dtype)
  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    # splits each byte into 8//b values of b bits. Along the last axis the output holds every byte's lowest b bits
    # first, then every byte's next b bits, and so on.
    # TODO: rewrite with arange?
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
  # map to (number of elements, number of bytes)
  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
    if ggml_type == 3:
      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
    if ggml_type == 14:
      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")
@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`.
  ```python
  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```
  NOTE: The provided tensor must be on a device that supports execution.

  Raises ValueError if the file does not start with `GGUF` or its version is not 2 or 3.
  """
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
  def read_str(): return str(reader.read(read_uint64()), "utf-8")  # uint64 byte length, then utf-8 bytes
  def read_arr():  # int32 element type, uint64 count, then the elements
    read_item, n = readers[read_int32()], read_uint64()
    return [ read_item() for _ in range(n) ]
  # metadata value type id -> reader. Id 0 is UINT8 in the GGUF spec, so it unpacks with "B" (an int);
  # "c" would have produced a 1-byte bytes object instead of a number.
  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in
    [ (0,"B",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()
  # per tensor: (name, dims, ggml type id, byte offset relative to the aligned start of the data section)
  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
  data_start = round_up(pos, alignment)
  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
  return kv_data, state_dict