import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape

class TensorIO(io.RawIOBase, BinaryIO):
  """
  Read-only, seekable binary file object backed by a 1-D uint8 Tensor.

  Lets stdlib parsers (tarfile, zipfile, pickle, io.BufferedReader) consume bytes held in a Tensor.
  Bytes are fetched on demand in `readinto` by slicing the tensor and calling `.data()`.
  All write operations raise io.UnsupportedOperation.
  """
  def __init__(self, t: Tensor):
    # only 1-D uint8 tensors are accepted, so one tensor element is exactly one byte
    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._position, self._tensor = 0, t

  def readable(self) -> bool: return True
  def read(self, size: int = -1) -> bytes:
    # RawIOBase.read is implemented on top of readinto below
    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
    return buf
  def readinto(self, buffer: Any) -> int:
    # Copy up to len(buffer) bytes starting at the current position and advance by the number copied.
    # Returns that count; it is 0 once the position is at the end.
    # NOTE(review): this assumes slicing past the end of the tensor clamps like Python slicing.
    data = self._tensor[self._position:self._position+len(buffer)].data()
    buffer[:len(data)] = data
    self._position += len(data)
    return len(data)

  def seekable(self) -> bool: return True
  def seek(self, offset: int, whence: int = 0) -> int:
    # whence 0/1/2 selects absolute / relative-to-current / relative-to-end (the io.SEEK_SET/CUR/END values);
    # the resulting position is clamped to [0, len(tensor)]
    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
    return self._position

  # required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")

# safetensors header dtype strings -> tinygrad dtypes (used by safe_load); the inverse map is used by safe_save to write headers
safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
  """
  Decorator that lets `func(t: Tensor)` also be called with a `str` or `pathlib.Path`.

  A non-Tensor argument is wrapped as `Tensor(pathlib.Path(fn))` before being passed on; a Tensor is passed through unchanged.
  """
  @functools.wraps(func)
  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
  return wrapper

@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
  Loads a .safetensor file, returning the source tensor, data start position, and metadata.
  """
  # layout: 8-byte little-endian header length N, then N bytes of JSON header; tensor data begins right after (at 8 + N)
  data_start = int.from_bytes(t[0:8].data(), "little") + 8
  return t, data_start, json.loads(t[8:data_start].data().tobytes())

def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
  """
  Loads a .safetensor file, returning the `state_dict`.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, data_start, metadata = safe_load_metadata(fn)
  data = t[data_start:]
  # data_offsets are relative to the start of the data section; each slice is bitcast to its header dtype and reshaped.
  # The "__metadata__" entry holds the optional user metadata (see safe_save), not a tensor, so it is skipped.
  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
          for k, v in metadata.items() if k != "__metadata__" }

def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
  """
  Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  # build the JSON header; tensors are laid out back-to-back in dict iteration order
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
  j = json.dumps(headers, separators=(',', ':'))
  # pad the header with spaces to a multiple of 8 bytes
  j += "\x20"*(round_up(len(j),8)-len(j))
  # remove any existing file, then allocate a disk-backed uint8 tensor sized for length prefix + header + data
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  # json.dumps defaults to ASCII-only output, so len(j) is also the UTF-8 byte length written here
  t[0:8].bitcast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(list(j.encode('utf-8')))
  # re-read the just-written header to get correctly typed/shaped views into the file, then copy each tensor in
  for k,v in safe_load(t).items(): v.assign(tensors[k])

# state dict

def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
  Returns a `state_dict` of the object, with optional prefix.

  Recursively collects every `tensor_type` instance reachable from `obj`.
  The dispatch order is:
  - namedtuples via `_asdict()`;
  - OrderedDicts as plain dicts;
  - other objects with `__dict__` via their attributes;
  - lists/tuples by index;
  - dicts by key.

  Keys are the dot-joined path to each tensor. Values of any other type contribute nothing.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  # leaf: the accumulated prefix ends with '.', which is stripped to form the key
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict

def get_parameters(obj) -> list[Tensor]:
  """
  Returns the tensors found by `get_state_dict(obj)`, in the same order.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  return list(get_state_dict(obj).values())

def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
  """
  Loads a `state_dict` into a model.

  Every tensor found by `get_state_dict(model)` is replaced in place (`Tensor.replace`) by the entry with the same key in
  `state_dict`. The source is first moved to the model tensor's device, or sharded across it when that device is a tuple.

  Options:
  - `strict=False` skips model keys missing from `state_dict`. With `strict=True` a missing key raises KeyError.
  - `verbose` enables the timing summary and progress bar.
  - `consume=True` deletes each entry from `state_dict` once loaded.
  - `realize=True` realizes each tensor right after loading.

  Raises ValueError on a shape mismatch.

  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      # multi-device (tuple device) model tensors: use an already-sharded source as-is, otherwise shard it on the model's axis
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]

@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
  ```python
  tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  # only regular-file members are returned; each value is the slice of `t` covering that member's data bytes
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}

# torch support!

@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
  ```python
  torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Loads a torch .pth file, returning the `state_dict`.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  # storage key -> byte offset of that storage's data within `t`
  offsets: dict[Union[str, int], int] = {}
  # storage key -> storage size in bytes, recorded by _rebuild_tensor_v2
  lens: dict[Union[str, int], int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    # Stand-in for torch's _rebuild_tensor_v2 (intercepted in TorchPickle.find_class).
    # It returns a tinygrad tensor over the storage's bytes in `t`, or None if the storage's offset is not yet known.
    # NOTE(review): `storage` is indexed as (…, dtype, key, …, element count?) — storage[1] is a tinygrad dtype and
    # storage[2] keys `offsets`; storage[4] looks like the element count. Confirm against the torch pickle format.
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    # size-1 dims are ignored; if the remaining strides are not in descending order, the data is read in the
    # stride-sorted (contiguous) shape and then permuted back into the requested order
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  # stand-in for torch's Parameter: unpickling stores the wrapped tensor in `.tensor`
  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  # persistent ids pre-registered here are resolved by TorchPickle.persistent_load
  deserialized_objects: dict[str, Any] = {}
  # names resolved for any module whose root package is "torch": storage classes map to tinygrad dtypes
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
  # returned for classes from non-whitelisted modules so they are never imported
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    # unknown persistent ids are returned unchanged (e.g. the storage tuple passed to _rebuild_tensor_v2)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  fobj = io.BufferedReader(TensorIO(t))
  # rewind fobj to the start and pass `v` through: seek(0, 0) returns 0 (falsy), so the expression evaluates to `v`
  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
    # zip format: record where each "<base>/data/<key>" member's bytes start in the file, then unpickle data.pkl
    # NOTE(review): relies on the private ZipExtFile._orig_compress_start; that is the raw data offset only if members
    # are stored uncompressed — presumably true for torch checkpoints, confirm.
    myzip = zipfile.ZipFile(fobj, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
    # tar format: storages live in the "storages" member, preceded by a pickled storage count
    with tarfile.open(fileobj=fobj, mode="r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for i in range(TorchPickle(f).load()): # num_storages
        # NOTE(review): the visible source is truncated inside the next line. Everything between "struct.unpack('" and
        # "Tensor:" is missing: the rest of this tar branch, possibly further torch_load branches, and the header of the
        # ggml_data_to_tensor function whose docstring and body follow (the body uses parameters `t`, `n`, `ggml_type`).
        # The missing text appears to run from a "<" to the next ">", as if stripped as markup.
        # Restore this span from version control before relying on this file.
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack(' Tensor:
  """
  Converts ggml tensor data to a tinygrad tensor.

  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)

  `t` holds the raw bytes starting at the tensor's data, `n` is the number of elements to decode, and `ggml_type` is
  the ggml type id. Quantized types are dequantized to float32. Raises ValueError for any other type id.
  """
  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356

  # native types
  # reinterpret the first n elements' worth of bytes as the target dtype
  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
    return t[:dtype.itemsize * n].bitcast(dtype)

  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    # Split each byte into 8//b fields of b bits (integer-divide by 2**(i*b), then mask).
    # The transpose puts the lowest field of every byte in the last axis first, then the next field, and so on.
    # TODO: rewrite with arange?
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)

  # map to (number of elements, number of bytes)
  # i.e. elements per quantized block and bytes per block; `blocks` has one row per block
  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
    # Q4_0: bytes 0-1 are an f16 scale d, the rest are 4-bit quants q; value = (q - 8) * d
    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
    # Q4_1: f16 scale d (bytes 0-1), f16 offset m (bytes 2-3), then 4-bit quants q; value = q * d + m
    if ggml_type == 3:
      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
    # Q8_0: f16 scale (bytes 0-1) times 32 int8 quants
    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
    # Q6_K block layout:
    #   bytes 0-127:   low 4 bits of each quant
    #   bytes 128-191: high 2 bits of each quant
    #   bytes 192-207: 16 int8 scales, each applied to 16 consecutive elements
    #   last 2 bytes:  f16 super-scale d
    # value = d * scale * (q6 - 32)
    if ggml_type == 14:
      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")

@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`.

  ```python
  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```

  NOTE: The provided tensor must be on a device that supports execution.

  Raises ValueError if the magic is not b"GGUF" or the version is not 2 or 3.
  """
  # buffered (1 MB) reader over the tensor's bytes; header fields are parsed sequentially from it
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
  # string: uint64 byte length followed by UTF-8 bytes
  def read_str(): return str(reader.read(read_uint64()), "utf-8")
  # array: int32 element type id, uint64 count, then the elements
  # (the local `reader` here is the element reader and shadows the outer BufferedReader only inside this function)
  def read_arr():
    reader, n = readers[read_int32()], read_uint64()
    return [ reader() for _ in range(n) ]

  # GGUF value type id -> reader; scalar types are little-endian struct formats of the given byte size
  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
    [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]

  # header: 4-byte magic, int32 version, int64 tensor count, int64 key/value count, then the key/value pairs
  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()

  # per-tensor info: (name, dims [uint32 count then uint64 each], ggml type id, uint64 offset into the data section)
  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  # the data section starts at the end of the header, rounded up to "general.alignment" (default 32)
  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
  data_start = round_up(pos, alignment)

  # dims are reversed before reshape (presumably GGUF lists the fastest-varying dimension first — confirm with the spec)
  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))

  return kv_data, state_dict