import os, json, pathlib, zipfile, pickle
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.view import strides_for_shape
from tinygrad.ops import Device

safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
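
# safetensors layout, as the functions below assume it: the first 8 bytes are a little-endian
# uint64 giving the length of a JSON header; the header maps each tensor name to its dtype, shape,
# and data_offsets (a byte range relative to the end of the header), plus an optional
# "__metadata__" entry; everything after the header is the raw tensor data.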
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
  return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t, json_len, metadata = safe_load_metadata(fn)
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}

def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
  j = json.dumps(headers, separators=(',', ':'))
  j += "\x20"*((8-len(j)%8)%8)
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:1].cast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
  for k,v in safe_load(t).items(): v.assign(tensors[k])
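
# Round-trip example (a minimal sketch; "model.safetensors" and the tensor name are made up):
#   safe_save({"weight": Tensor.rand(3, 4)}, "model.safetensors")
#   loaded = safe_load("model.safetensors")   # {"weight": <lazy disk-backed Tensor>}
#   assert loaded["weight"].shape == (3, 4)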

# state dict

from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
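
# Key naming example (a minimal sketch; Net and its attribute names are made up):
#   class Net:
#     def __init__(self): self.l1, self.blocks = Linear(4, 8), [Linear(8, 8), Linear(8, 2)]
#   get_state_dict(Net())  # -> {"l1.weight": ..., "l1.bias": ..., "blocks.0.weight": ..., "blocks.1.bias": ..., ...}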

def load_state_dict(model, state_dict, strict=True, verbose=True):
  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      v.assign(state_dict[k].to(v.device)).realize()
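
# Typical flow (a minimal sketch; `net` and the filenames are made up):
#   load_state_dict(net, safe_load("net.safetensors"))   # safetensors checkpoint
#   load_state_dict(net, torch_load("net.pt"))           # PyTorch checkpoint, via torch_load below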

# torch support!

def torch_load(fn:str):
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

  offsets: Dict[str, int] = {}
  lens: Dict[str, int] = {}
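  # NOTE: `storage` below is the persistent-id tuple that TorchPickle.persistent_load returns
  # unchanged, expected to look like ('storage', <dtype via intercept>, <key>, <location>, <numel>):
  # storage[1] is the dtype, storage[2] keys offsets/lens, storage[4] is the element count.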
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
    # convert bfloat16 -> float16 using LLVM for Llama 2
    # upstream LLaMA also does this conversion:
    # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
    # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
    if storage[1] == dtypes.bfloat16:
      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
      #ret = ret.to("LLVM").half().to(Device.DEFAULT)

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return pid
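
  # Two checkpoint formats are handled below. Modern PyTorch files are zip archives (hence the
  # 'PK' magic 0x50 0x4b) containing <base>/data.pkl plus one raw file per storage under
  # <base>/data/. Legacy files are a few pickles followed by the raw storages, each prefixed by an
  # 8-byte size field (which is what the offset math below assumes); the data pickle is read twice,
  # first to collect storage lengths into lens, then again once the byte offsets are known.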
  if tuple(t[0:2].numpy()) == (0x50, 0x4b):
    myzip = zipfile.ZipFile(fn, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  else:
    with open(fn, "rb") as f:
      pkl = TorchPickle(f)
      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
      for i in ids:
        offsets[i] = base_offset + 8
        base_offset += 8 + lens[i]
      f.seek(rwd)
      return TorchPickle(f).load()
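
# Example usage (a minimal sketch; "model.pt" and `net` are made up):
#   ckpt = torch_load("model.pt")   # dict (or nested structure) of lazy disk-backed Tensors
#   load_state_dict(net, ckpt)      # works when the checkpoint keys match get_state_dict(net)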