import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape

class TensorIO(io.RawIOBase, BinaryIO):
  def __init__(self, t: Tensor):
    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._position, self._tensor = 0, t

  def readable(self) -> bool: return True
  def read(self, size: int = -1) -> bytes:
    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
    return buf
  def readinto(self, buffer: Any) -> int:
    data = self._tensor[self._position:self._position+len(buffer)].data()
    buffer[:len(data)] = data
    self._position += len(data)
    return len(data)

  def seekable(self) -> bool: return True
  def seek(self, offset: int, whence: int = 0) -> int:
    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
    return self._position

  # required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
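
# A minimal usage sketch (hypothetical "archive.tar"): TensorIO makes a 1d uint8
# Tensor usable anywhere a read-only, seekable binary file object is expected,
# e.g. the tarfile/zipfile readers further down in this file.
#   fobj = TensorIO(Tensor(pathlib.Path("archive.tar")))
#   header = fobj.read(512)  # first 512-byte tar header block
#   fobj.seek(0)             # rewind, like any seekable stream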

safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str, pathlib.Path]], T]:
  @functools.wraps(func)
  def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
  return wrapper
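
# Sketch of what the decorator provides: a function written against Tensor also
# accepts a path, so these two calls are equivalent (assuming "weights.safetensors" exists):
#   safe_load_metadata("weights.safetensors")
#   safe_load_metadata(Tensor(pathlib.Path("weights.safetensors")))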

@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
  Loads a .safetensor file, returning the source tensor, data start position, and metadata.
  """
  data_start = int.from_bytes(t[0:8].data(), "little") + 8
  return t, data_start, json.loads(t[8:data_start].data().tobytes())
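
# safetensors layout, as parsed above: bytes 0-7 are a little-endian uint64 header
# length N, bytes 8..8+N are a JSON header, and raw tensor bytes follow. For example,
# a file beginning b'\x40\x00\x00\x00\x00\x00\x00\x00{...}' has a 64-byte JSON header
# and data_start == 72.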

def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
  """
  Loads a .safetensor file, returning the `state_dict`.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, data_start, metadata = safe_load_metadata(fn)
  data = t[data_start:]
  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
           for k, v in metadata.items() if k != "__metadata__" }

def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
  """
  Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
  j = json.dumps(headers, separators=(',', ':'))
  j += "\x20"*(round_up(len(j),8)-len(j))
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:8].bitcast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(list(j.encode('utf-8')))
  for k,v in safe_load(t).items(): v.assign(tensors[k])
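
# The file written above mirrors the layout safe_load_metadata expects: an 8-byte
# little-endian header length, the space-padded JSON header, then raw tensor bytes.
# A hypothetical round trip:
#   nn.state.safe_save({'t': Tensor([1, 2, 3])}, "test.safetensor")
#   assert nn.state.safe_load("test.safetensor")['t'].tolist() == [1, 2, 3]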

# state dict

def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
  Returns a `state_dict` of the object, with optional prefix.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict
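
# The recursion builds dotted key paths: object attributes, dict keys, and
# list/tuple indices each append a path segment. A sketch of the resulting keys:
#   get_state_dict({"layers": [nn.Linear(4, 5)]})
#   # -> {'layers.0.weight': <Tensor>, 'layers.0.bias': <Tensor>}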

def get_parameters(obj) -> list[Tensor]:
  """
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  return list(get_state_dict(obj).values())

def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
  """
  Loads a `state_dict` into a model.

  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]
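
# Usage sketch for the keyword arguments (hypothetical net/partial_state_dict): with
# strict=True, a model weight missing from state_dict raises a KeyError at the shape
# check; strict=False skips it (with a DEBUG warning). consume=True drops each entry
# from state_dict once applied, keeping peak memory lower during large loads.
#   nn.state.load_state_dict(net, partial_state_dict, strict=False, consume=True)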

@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
  ```python
  tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}

# torch support!

@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
  ```python
  torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Loads a torch .pth file, returning the `state_dict`.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  offsets: dict[Union[str, int], int] = {}
  lens: dict[Union[str, int], int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
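    # a worked sketch: size=(3, 4) with stride=(1, 3) describes a transposed contiguous
    # buffer. argsort([1, 3]) == [0, 1], so permute_indexes == [1, 0]; below, the raw
    # data is reshaped to intermediate_shape (4, 3) and permuted back to (3, 4).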
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  fobj = io.BufferedReader(TensorIO(t))
  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
    myzip = zipfile.ZipFile(fobj, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
    with tarfile.open(fileobj=fobj, mode="r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for i in range(TorchPickle(f).load()): # num_storages
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
        offsets[key] = storages_offset + f.tell()
        f.seek(sz*storage_type.itemsize, 1)
      f = unwrap(tar.extractfile('tensors'))
      for _ in range(TorchPickle(f).load()): # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
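    # legacy (pre-zipfile) torch format, as handled here: three pickled values (magic,
    # protocol, sys_info), then the pickled object itself, then the list of storage
    # keys, then each storage's bytes prefixed by an 8-byte size. The first pass below
    # records lens via _rebuild_tensor_v2; after fobj.seek(rwd), the second pass
    # re-reads the object with offsets populated.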
    pkl = TorchPickle(fobj)
    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
    for i in ids:
      offsets[i] = base_offset + 8
      base_offset += 8 + lens[i]
    fobj.seek(rwd)
    return TorchPickle(fobj).load()

def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  """
  Converts ggml tensor data to a tinygrad tensor.

  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
  """
  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356

  # native types
  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
    return t[:dtype.itemsize * n].bitcast(dtype)

  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    # TODO: rewrite with arange?
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
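
  # q_to_uint8 sketch: each byte is split into 8//b b-bit fields, low bits first.
  # e.g. b=4 on the byte 0x2F yields the nibbles [0xf, 0x2]; the transpose/flatten
  # then groups all low nibbles of a block before all high nibbles, matching the
  # ggml element order.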

  # map to (number of elements, number of bytes)
  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
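    # per-block layouts decoded below (matching the byte counts above): Q4_0 = f16
    # scale + 16 bytes of nibbles (18); Q4_1 = f16 scale + f16 min + 16 bytes of
    # nibbles (20); Q8_0 = f16 scale + 32 int8s (34); Q6_K = 128 bytes low nibbles +
    # 64 bytes upper 2 bits + 16 int8 scales + f16 super-block scale (210)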
    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
    if ggml_type == 3:
      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
    if ggml_type == 14:
      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")

@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`.

  ```python
  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```

  NOTE: The provided tensor must be on a device that supports execution.
  """
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
  def read_str(): return str(reader.read(read_uint64()), "utf-8")
  def read_arr():
    reader, n = readers[read_int32()], read_uint64()
    return [ reader() for _ in range(n) ]

  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
    [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
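
  # GGUF value-type ids, per the GGUF spec: 0=uint8, 1=int8, 2=uint16, 3=int16,
  # 4=uint32, 5=int32, 6=float32, 7=bool, 8=string, 9=array, 10=uint64, 11=int64,
  # 12=float64. Strings and arrays are length-prefixed, which is why read_str and
  # read_arr recurse through the readers table above.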

  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()

  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
  data_start = round_up(pos, alignment)

  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))

  return kv_data, state_dict