You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
19 KiB
354 lines
19 KiB
from __future__ import annotations
|
|
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
|
|
import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
|
|
from dataclasses import dataclass
|
|
from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
|
|
|
|
T = TypeVar("T")
U = TypeVar("U")

# NOTE: it returns int 1 if x is empty regardless of the type of x
def prod(x:Iterable[T]) -> Union[T,int]:
  """Return the product of every item of `x`, folded left-to-right from the integer 1 (i.e. `1 * x[0] * x[1] * ...`)."""
  acc: Union[T,int] = 1
  for item in x: acc = acc * item
  return acc
|
|
|
|
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = (platform.system() == "Darwin")  # True when running on macOS
CI = bool(os.getenv("CI", ""))         # any non-empty CI environment variable counts

# fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
if sys.platform == "win32":
  os.system("")  # running an empty shell command is the workaround described at the link above
|
|
|
|
def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
|
|
def argfix(*x):
|
|
if x and x[0].__class__ in (tuple, list):
|
|
if len(x) != 1: raise ValueError(f"bad arg {x}")
|
|
return tuple(x[0])
|
|
return x
|
|
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
|
def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
|
|
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
|
def colored(st, color:Optional[str], background=False):
  """Wrap `st` in an ANSI color escape (replaces the termcolor library with a few lines).

  The base code is 30 + index of the color name. An all-uppercase name adds 60 (the 9x "bright" range), and
  `background` adds 10 per unit. Returns `st` unchanged when `color` is None.
  """
  if color is None: return st
  names = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
  code = 10*background + 60*(color.upper() == color) + 30 + names.index(color.lower())
  return f"\u001b[{code}m{st}\u001b[0m"
def colorize_float(x: float):
  """Format a ratio like `   1.23x`: green below 0.75, red above 1.15, yellow otherwise."""
  if x < 0.75: color = 'green'
  elif x > 1.15: color = 'red'
  else: color = 'yellow'
  return colored(f"{x:7.2f}x", color)
def time_to_str(t:float, w=8) -> str:
  """Format `t` seconds in s, ms or us, picking the first unit in which the value exceeds 10; `w` is the field width."""
  for scale, unit in ((1, "s "), (1e3, "ms")):
    if t > 10/scale: return f"{t * scale:{w}.2f}{unit}"
  return f"{t * 1e6:{w}.2f}us"
def ansistrip(s:str):
  """Remove ANSI escape sequences of the form ESC[K and ESC[...m from `s`."""
  pattern = re.compile('\x1b\\[(K|.*?m)')
  return pattern.sub('', s)
def ansilen(s:str):
  """Length of `s` as displayed, i.e. ignoring ANSI escape sequences."""
  visible = ansistrip(s)
  return len(visible)
|
|
def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]:
  """Repeat an int `cnt` times as a tuple; any other sequence is converted to a tuple as-is (`cnt` is ignored)."""
  if isinstance(x, int): return tuple([x]*cnt)
  return tuple(x)
def flatten(l:Iterable[Iterable[T]]):
  """Flatten exactly one level of nesting into a list."""
  out: list = []
  for sub in l: out.extend(sub)
  return out
def fully_flatten(l):
  """Recursively flatten anything indexable with a length (strings are treated as leaves) into a flat list.

  Objects with `shape == ()` (0-d array-likes) contribute their scalar `l[()]`.
  """
  is_seq = hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str)
  if not is_seq: return [l]
  if hasattr(l, "shape") and l.shape == (): return [l[()]]
  return [leaf for item in l for leaf in fully_flatten(item)]
def fromimport(mod, frm):
  """Equivalent of `from mod import frm`, returning the imported attribute."""
  module = __import__(mod, fromlist=[frm])
  return getattr(module, frm)
|
|
def strip_parens(fst:str):
  """Strip one pair of parentheses that encloses the whole of `fst`; otherwise return `fst` unchanged.

  Heuristic: "(a+b)" -> "a+b", while "(a)+(b)" is kept because inside the outer pair a ')' appears before any '('.
  Empty and one-character strings are returned unchanged (previously "" raised IndexError).
  """
  # startswith/endswith are safe on "" where fst[0] / fst[-1] raised IndexError
  if not (fst.startswith('(') and fst.endswith(')')): return fst
  inner = fst[1:-1]
  return inner if inner.find('(') <= inner.find(')') else fst
|
|
def ceildiv(num, amt):
  """Ceiling division computed as -(num // -amt); a float result is converted to int."""
  ret = -(num // -amt)
  return int(ret) if isinstance(ret, float) else ret
def round_up(num:int, amt:int) -> int:
  """Round `num` up to the next multiple of `amt`."""
  blocks = (num + amt - 1) // amt
  return blocks * amt
# cstyle div and mod
def cdiv(x:int, y:int) -> int:
  """C-style (truncating toward zero) integer division; division by zero yields 0."""
  if y == 0: return 0
  q = abs(x) // abs(y)
  return q * -1 if x*y < 0 else q * 1
def cmod(x:int, y:int) -> int:
  """C-style remainder: the sign follows `x`, consistent with cdiv."""
  return x - cdiv(x, y)*y
def lo32(x:Any) -> Any:
  """Low 32 bits of `x`. Any is sint."""
  return x & 0xFFFFFFFF
def hi32(x:Any) -> Any:
  """`x` shifted right by 32 bits. Any is sint."""
  return x >> 32
def data64(data:Any) -> tuple[Any, Any]:
  """Split a 64-bit value into (high 32 bits, low 32 bits). Any is sint."""
  return hi32(data), lo32(data)
def data64_le(data:Any) -> tuple[Any, Any]:
  """Split a 64-bit value into (low 32 bits, high 32 bits). Any is sint."""
  return lo32(data), hi32(data)
def getbits(value: int, start: int, end: int):
  """Extract bits start..end of `value` (both inclusive) as an unsigned int."""
  width = end - start + 1
  return (value >> start) & ((1 << width) - 1)
def i2u(bits: int, value: int):
  """Reinterpret a negative `value` as a `bits`-wide two's-complement unsigned int; non-negative values pass through."""
  return (1 << bits) + value if value < 0 else value
|
|
def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
  """Merge several dicts into one, requiring every key that appears more than once to map to the same value.

  `ds` may be any iterable, including a one-shot generator; it is consumed exactly once.
  Raises AssertionError when the same key maps to different values. Keys and values must be hashable.
  """
  # materialize first: the body walks `ds` twice, which silently returned {} for generators
  ds = list(ds)
  kvs = set([(k,v) for d in ds for k,v in d.items()])
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for d in ds for k,v in d.items()}
|
|
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
  """Split `itr` into (items where fxn is truthy, items where it is falsy), preserving order."""
  matched: list = []
  rest: list = []
  for s in itr:
    if fxn(s): matched.append(s)
    else: rest.append(s)
  return matched, rest
def unwrap(x:Optional[T]) -> T:
  """Return `x`, asserting that it is not None (narrows Optional[T] to T for type checkers)."""
  assert x is not None
  return x
def get_single_element(x:list[T]) -> T:
  """Return the only element of `x`; AssertionError unless len(x) == 1."""
  assert len(x) == 1, f"list {x} must only have 1 element"
  only = x[0]
  return only
|
|
def get_child(obj, key):
  """Follow a dotted path such as "layers.0.weight" from `obj`.

  Numeric parts index with int(part); otherwise dicts are indexed by the part and any other object uses getattr.
  """
  def step(cur, part):
    if part.isnumeric(): return cur[int(part)]
    return cur[part] if isinstance(cur, dict) else getattr(cur, part)
  return functools.reduce(step, key.split('.'), obj)
def word_wrap(x, wrap=80):
  """Insert a newline every `wrap` characters.

  Stops at the first chunk whose leading `wrap` characters already contain a newline, and returns that remainder as-is.
  """
  if len(x) <= wrap or '\n' in x[0:wrap]: return x
  head, tail = x[0:wrap], x[wrap:]
  return head + "\n" + word_wrap(tail, wrap)
def pluralize(st:str, cnt:int):
  """Format as "<cnt> <st>", appending 's' unless cnt == 1."""
  suffix = '' if cnt == 1 else 's'
  return f"{cnt} {st}" + suffix
|
|
|
|
# for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
def polyN(x:T, p:list[float]) -> T:
  # Horner's scheme starting from the float 0.0, so an empty `p` yields 0.0
  acc: Any = 0.0
  for c in p: acc = acc*x + c
  return acc

@functools.lru_cache(maxsize=None)
def to_function_name(s:str):
  """Make `s` identifier-safe. ANSI codes are stripped, and letters, digits and '_' are kept.

  Every other character becomes its code point as uppercase hex, zero-padded to at least 2 digits. Results are memoized.
  """
  allowed = string.ascii_letters + string.digits + '_'
  return ''.join(c if c in allowed else f'{ord(c):02X}' for c in ansistrip(s))

@functools.lru_cache(maxsize=None)
def getenv(key:str, default=0):
  """Read environment variable `key` and convert it with type(default); an unset variable gives type(default)(default).

  Memoized per (key, default), so later changes to os.environ are not observed for a pair already read.
  """
  raw = os.getenv(key, default)
  return type(default)(raw)
|
|
def temp(x:str, append_user:bool=False) -> str:
  """Return the POSIX-style path of `x` inside the system temp dir, optionally suffixed with ".<username>"."""
  fname = f"{x}.{getpass.getuser()}" if append_user else x
  return pathlib.Path(tempfile.gettempdir()).joinpath(fname).as_posix()
|
|
|
|
class Context(contextlib.ContextDecorator):
  """Temporarily override ContextVar values, e.g. `with Context(DEBUG=2): ...`, or use it as a function decorator.

  On entry, the value of every registered ContextVar is snapshotted and then the keyword overrides are applied.
  On exit, all snapshotted values are restored. An unknown name raises KeyError, and no partial overrides are left behind.
  """
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
    try:
      for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
    except BaseException:
      # __exit__ is not invoked when __enter__ raises, so roll back any overrides applied before the failure
      self.__exit__()
      raise
    return self
  def __exit__(self, *args):
    for k,v in self.old_context.items(): ContextVar._cache[k].value = v

class ContextVar:
  """A named setting whose initial value is getenv(key, default_value), i.e. the environment variable converted to type(default).

  Every instance registers itself in ContextVar._cache under its key, which is where Context looks it up.
  Creating a second instance with the same key raises RuntimeError.
  Truthiness and the comparisons >=, >, <, <= all delegate to `.value`.
  """
  _cache: ClassVar[dict[str, ContextVar]] = {}
  value: int
  key: str
  def __init__(self, key, default_value):
    if key in ContextVar._cache: raise RuntimeError(f"attempt to recreate ContextVar {key}")
    ContextVar._cache[key] = self
    self.value, self.key = getenv(key, default_value), key
  def __bool__(self): return bool(self.value)
  def __ge__(self, x): return self.value >= x
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x
  # without __le__, `var <= x` (and the reflected `x >= var`) raised TypeError
  def __le__(self, x): return self.value <= x
|
|
|
|
# Each setting below starts from getenv(<name>, <default>) and is registered by name for Context overrides.
# Creation order matches the original grouped tuple assignments.
DEBUG = ContextVar("DEBUG", 0)
IMAGE = ContextVar("IMAGE", 0)
BEAM = ContextVar("BEAM", 0)
NOOPT = ContextVar("NOOPT", 0)
# JIT defaults to 2 on Darwin hosts whose processor string mentions Intel/i386, and to 1 everywhere else
JIT = ContextVar("JIT",
                 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
WINO = ContextVar("WINO", 0)
CAPTURING = ContextVar("CAPTURING", 1)
TRACEMETA = ContextVar("TRACEMETA", 1)
USE_TC = ContextVar("TC", 1)  # NOTE: backed by the environment variable "TC", not "USE_TC"
TC_SELECT = ContextVar("TC_SELECT", -1)
TC_OPT = ContextVar("TC_OPT", 0)
AMX = ContextVar("AMX", 0)
TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
TC_SEARCH_OVER_SHAPE = ContextVar("TC_SEARCH_OVER_SHAPE", 1)
FUSE_ARANGE = ContextVar("FUSE_ARANGE", 0)
FUSE_CONV_BW = ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP = ContextVar("SPLIT_REDUCEOP", 1)
NO_MEMORY_PLANNER = ContextVar("NO_MEMORY_PLANNER", 0)
RING = ContextVar("RING", 1)
PICKLE_BUFFERS = ContextVar("PICKLE_BUFFERS", 1)
PROFILE = ContextVar("PROFILE", getenv("VIZ"))  # defaults to the VIZ environment setting
LRU = ContextVar("LRU", 1)
CACHELEVEL = ContextVar("CACHELEVEL", 2)
IGNORE_BEAM_CACHE = ContextVar("IGNORE_BEAM_CACHE", 0)
DEVECTORIZE = ContextVar("DEVECTORIZE", 1)
DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
DONT_GROUP_REDUCES = ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE = ContextVar("QUANTIZE", 0)
VALIDATE_WITH_CPU = ContextVar("VALIDATE_WITH_CPU", 0)
|
|
|
|
@dataclass(frozen=True)
|
|
class Metadata:
|
|
name: str
|
|
caller: str
|
|
backward: bool = False
|
|
def __hash__(self): return hash(self.name)
|
|
def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
|
|
def __str__(self): return self.name + (" bw" if self.backward else "")
|
|
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
|
|
|
|
# **************** global state Counters ****************

class GlobalCounters:
  """Process-wide counters kept as class attributes (the class is never instantiated)."""
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
  @staticmethod
  def reset():
    """Zero every counter except mem_used."""
    for attr, zero in (("global_ops", 0), ("global_mem", 0), ("time_sum_s", 0.0), ("kernel_count", 0)):
      setattr(GlobalCounters, attr, zero)
|
|
|
|
# **************** timer and profiler ****************

class Timing(contextlib.ContextDecorator):
  """Measure the wall-clock time of a `with` body (or of a decorated function) with time.perf_counter_ns.

  prefix: text printed before the elapsed time.
  on_exit: optional callable that receives the elapsed nanoseconds and returns a str, which is appended to the printed
    line. It is called only when `enabled`.
  enabled: when False the time is still recorded in `self.et` (nanoseconds) but nothing is printed.
  """
  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
  def __enter__(self):
    self.st = time.perf_counter_ns()
    return self  # previously None, so `with Timing() as t:` could not read t.et afterwards
  def __exit__(self, *exc):
    self.et = time.perf_counter_ns() - self.st
    if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
|
|
|
|
def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
|
|
class Profiling(contextlib.ContextDecorator):
|
|
def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
|
|
self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
|
|
def __enter__(self):
|
|
import cProfile
|
|
self.pr = cProfile.Profile()
|
|
if self.enabled: self.pr.enable()
|
|
def __exit__(self, *exc):
|
|
if self.enabled:
|
|
self.pr.disable()
|
|
if self.fn: self.pr.dump_stats(self.fn)
|
|
import pstats
|
|
stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
|
|
for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
|
|
(_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
|
|
scallers = sorted(callers.items(), key=lambda x: -x[1][2])
|
|
print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
|
|
colored(_format_fcn(fcn).ljust(50), "yellow"),
|
|
colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')
|
|
|
|
# *** universal database cache ***

# cache root: $XDG_CACHE_HOME if set, otherwise ~/Library/Caches on macOS and ~/.cache elsewhere
cache_dir: str = os.path.join(
  getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")),
  "tinygrad")
# sqlite database file; the CACHEDB environment variable overrides the default <cache_dir>/cache.db
CACHEDB: str = getenv("CACHEDB",
                      os.path.abspath(os.path.join(cache_dir, "cache.db")))
|
|
|
|
VERSION = 19  # appended to every cache table name, so bumping it starts a fresh set of tables
_db_connection = None
def db_connection():
  """Return the module-wide sqlite3 connection to CACHEDB, creating its directory and opening it on the first call.

  The connection waits up to 60 seconds on a locked database and uses IMMEDIATE transactions. WAL journal mode is
  requested on a best-effort basis. With DEBUG >= 7, every executed SQL statement is printed.
  """
  global _db_connection
  if _db_connection is None:
    # dirname(abspath(...)) also handles a bare filename (CACHEDB=cache.db) and '/'-separated paths on Windows, where
    # CACHEDB.rsplit(os.sep, 1)[0] returned the file name itself and created a directory with the database's name
    os.makedirs(os.path.dirname(os.path.abspath(CACHEDB)), exist_ok=True)
    _db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
    # another connection has set it already or is in the process of setting it
    # that connection will lock the database
    with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
    if DEBUG >= 7: _db_connection.set_trace_callback(print)
  return _db_connection
|
|
|
|
def diskcache_clear():
  """Drop every table in the cache database and VACUUM it."""
  # forget which tables diskcache_put has already created; otherwise it would skip CREATE TABLE and then fail
  # writing into a dropped table
  _db_tables.clear()
  cur = db_connection().cursor()
  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
|
|
|
|
def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
  """Look up a pickled value in the cache table "<table>_<VERSION>".

  A plain str/int key is shorthand for {"key": key}; a dict key must match on all of its columns.
  Returns the unpickled value, or None when caching is off (CACHELEVEL < 1), when the table does not exist, or when no
  row matches.
  NOTE(review): values are unpickled directly from CACHEDB, so the cache file must be trusted. Table and column names
  are interpolated into the SQL text and must come from trusted code.
  """
  if CACHELEVEL < 1: return None
  if isinstance(key, (str, int)): key = {"key": key}
  cur = db_connection().cursor()
  try:
    conditions = ' AND '.join([f'{x}=?' for x in key.keys()])
    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {conditions}", tuple(key.values()))
  except sqlite3.OperationalError:
    return None # table doesn't exist
  row = res.fetchone()
  return pickle.loads(row[0]) if row is not None else None
|
|
|
|
_db_tables = set()  # tables already created in this process, so diskcache_put skips the CREATE for them
def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
  """Store `val` under `key` in the cache table "<table>_<VERSION>" and return `val` unchanged.

  A plain str/int key is shorthand for {"key": key}. The first put for a table in this process creates the table if it
  is missing: one typed column per key, a blob `val` column, and the key columns as primary key. REPLACE then
  overwrites any row with the same key. With prepickled=True, `val` is stored as given instead of being pickled.
  When CACHELEVEL < 1 it only returns `val`. Raises KeyError for a key value whose type has no column mapping.
  """
  if CACHELEVEL < 1: return val
  if isinstance(key, (str, int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
  if table not in _db_tables:
    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
    column_defs = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({column_defs}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
    _db_tables.add(table)
  columns = ', '.join(key.keys())
  placeholders = ', '.join(['?'] * len(key))
  params = tuple(key.values()) + (val if prepickled else pickle.dumps(val), )
  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({columns}, val) VALUES ({placeholders}, ?)", params)
  conn.commit()
  cur.close()
  return val
|
|
|
|
def diskcache(func):
  """Decorator that persists `func`'s results in the disk cache.

  Results are stored in table "cache_<func.__name__>", keyed by the sha256 of the pickled (args, kwargs), so all
  arguments must be picklable. A stored result of None is indistinguishable from a miss and is therefore recomputed on
  every call. functools.wraps keeps the wrapped function's name, docstring and `__wrapped__`.
  """
  @functools.wraps(func)
  def wrapper(*args, **kwargs) -> Any:
    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
    if (ret:=diskcache_get(table, key)) is not None: return ret
    return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper
|
|
|
|
# *** process replay ***

# enabled when either RUN_PROCESS_REPLAY or CAPTURE_PROCESS_REPLAY is set (holds the first truthy value)
CAPTURE_PROCESS_REPLAY = (getenv("RUN_PROCESS_REPLAY")
                          or getenv("CAPTURE_PROCESS_REPLAY"))
|
|
|
|
# *** http support ***

def _ensure_downloads_dir() -> pathlib.Path:
  """Return the directory used for downloads.

  On a tinybox (/etc/tinybox-release exists) this is /raid/downloads, created with sudo (owner tiny:root, mode 775)
  when missing. Elsewhere it is <cache_dir>/downloads, which is not created here.
  """
  # if we are on a tinybox, use the raid array
  if not pathlib.Path("/etc/tinybox-release").is_file(): return pathlib.Path(cache_dir) / "downloads"
  downloads_dir = pathlib.Path("/raid/downloads")
  if not downloads_dir.exists():
    # try creating dir with sudo
    for cmd in (["sudo", "mkdir", "-p", downloads_dir],
                ["sudo", "chown", "tiny:root", downloads_dir],
                ["sudo", "chmod", "775", downloads_dir]):
      subprocess.run(cmd, check=True)
  return downloads_dir
|
|
|
|
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
          allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  """Download `url` to a local file and return its path, reusing an existing file when caching is allowed.

  url: values starting with "/" or "." are treated as local paths and returned as-is, without downloading.
  name: a pathlib.Path, or a str containing '/', is used as the full destination path. Otherwise it is a file name
    inside the downloads dir; the default is the md5 hex digest of the url.
  subdir: optional sub-directory of the downloads dir.
  gunzip: decompress the response with gzip while writing; ".gunzip" is appended to the derived file name.
  allow_caching: when False, the file is always re-downloaded. The default is computed from DISABLE_HTTP_CACHE once,
    when this function is defined.
  Raises RuntimeError when fewer bytes than the advertised content-length arrive. In that case nothing is written at
  the destination, and the partial temp file is removed.
  """
  if url.startswith(("/", ".")): return pathlib.Path(url)
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
  else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
  if not fp.is_file() or not allow_caching:
    (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
    with urllib.request.urlopen(url, timeout=10) as r:
      assert r.status == 200, r.status
      length = int(r.headers.get('content-length', 0)) if not gunzip else None
      readfile = gzip.GzipFile(fileobj=r) if gunzip else r
      progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
      tmp: Optional[pathlib.Path] = None
      try:
        # download into a temp file in the destination directory, so the final replace() is a same-filesystem move
        with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
          tmp = pathlib.Path(f.name)
          while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
        # verify BEFORE publishing: previously a truncated file was already renamed to fp and later served from cache
        if length and (file_size:=tmp.stat().st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
        # replace() (unlike rename()) also overwrites an existing fp on Windows, e.g. when allow_caching=False
        tmp.replace(fp)
      except BaseException:
        if tmp is not None: tmp.unlink(missing_ok=True)  # don't leak the partial download
        raise
      finally:
        progress_bar.update(close=True)
  return fp
|
|
|
|
# *** Exec helpers

def cpu_time_execution(cb, enable):
  """Call `cb()`. When `enable` is set, return the elapsed wall time in seconds (perf_counter); otherwise return None."""
  if not enable:
    cb()
    return None
  start = time.perf_counter()
  cb()
  return time.perf_counter() - start
|
|
|
|
def cpu_objdump(lib, objdump_tool='objdump'):
  """Write the bytes `lib` to a temp file and print the decoded stdout of `<objdump_tool> -d <file>`.

  Raises subprocess.CalledProcessError when the tool exits with a non-zero status.
  """
  with tempfile.NamedTemporaryFile(delete=True) as f:
    # write through the already-open handle and flush so the subprocess sees the bytes, instead of opening the
    # file a second time by name (which the tempfile docs say is not possible on every platform)
    f.write(lib)
    f.flush()
    print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))
|
|
|
|
def capstone_flatdump(lib: bytes):
  """Disassemble raw machine code `lib` with capstone for the host architecture, printing one instruction per line.

  Addresses start at 0. Only x86_64/AMD64 and aarch64/arm64 hosts are supported; others raise NotImplementedError.
  """
  import capstone
  machine = platform.machine()
  if machine in ('x86_64', 'AMD64'): cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
  elif machine in ('aarch64', 'arm64'): cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
  else: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
  cs.skipdata = True  # capstone SKIPDATA mode, presumably to keep decoding past bytes that aren't instructions
  for instr in cs.disasm(lib, 0):
    print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
  sys.stdout.flush()
|
|
|
|
# *** ctypes helpers

# TODO: make this work with read only memoryviews (if possible)
def from_mv(mv:memoryview, to_type=ctypes.c_char):
  """View the (writable) memory behind `mv` as a ctypes array of len(mv) `to_type` elements, without copying."""
  base_addr = ctypes.addressof(to_type.from_buffer(mv))
  return ctypes.cast(base_addr, ctypes.POINTER(to_type * len(mv))).contents
def to_mv(ptr:int, sz:int) -> memoryview:
  """Wrap `sz` bytes at raw address `ptr` in a byte-format ('B') memoryview, without copying."""
  byte_array = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents
  return memoryview(byte_array).cast("B")
def mv_address(mv):
  """Return the raw address of the first byte of `mv`'s (writable) buffer."""
  first_byte = ctypes.c_char.from_buffer(mv)
  return ctypes.addressof(first_byte)
def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
  """Build an array of POINTER(to_type), one per entry of `options`, each pointing at a fresh NUL-terminated copy."""
  elem_ptr = ctypes.POINTER(to_type)
  ptrs = [ctypes.cast(ctypes.create_string_buffer(o), elem_ptr) for o in options]
  return (elem_ptr * len(options))(*ptrs)
@functools.lru_cache(maxsize=None)
def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
  """Create a packed (_pack_ = 1) ctypes.Structure subclass with `fields`; memoized per fields tuple."""
  class CStruct(ctypes.Structure):
    _pack_ = 1
    _fields_ = fields
  return CStruct
def init_c_var(ctypes_var, creat_cb):
  """Call creat_cb(ctypes_var) for its side effect, then return ctypes_var."""
  creat_cb(ctypes_var)
  return ctypes_var
def flat_mv(mv:memoryview):
  """Reinterpret `mv` as a flat 1-D byte ('B') view covering all of its bytes; empty views are returned unchanged."""
  if len(mv) == 0: return mv
  return mv.cast("B", shape=(mv.nbytes,))
|
|
|
|
# *** tqdm

class tqdm(Generic[T]):
  """Minimal stand-in for the `tqdm` progress bar, drawn on stderr.

  It can wrap an iterable (`for x in tqdm(it)`), be driven manually with update(n), or be used as a context manager,
  where exiting draws the final line. `total` defaults to len(iterable) when the iterable has __len__, otherwise 0;
  with a total of 0, no percentage, bar or remaining-time estimate is drawn. `unit_scale` renders counts with SI
  suffixes. `rate` throttles redraws to roughly `rate` per second once updates arrive faster than that.
  """
  def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
               unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
    self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
    # st: start time; i: number of update() calls, starting at -1 so the update(0) below brings it to 0;
    # n: progress so far; skip: draw only on every `skip`-th call; t: total (given, else len(iterable), else 0)
    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
    self.set_description(desc)
    self.update(0)  # draw the initial bar
  def __iter__(self) -> Iterator[T]:
    assert self.iterable is not None, "need an iterable to iterate"
    for item in self.iterable:
      yield item
      self.update(1)  # counted after the consumer has finished with `item`
    self.update(close=True)
  def __enter__(self): return self
  def __exit__(self, *_): self.update(close=True)
  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
  def update(self, n:int=0, close:bool=False):
    """Advance progress by `n` and redraw if due; close=True always redraws and ends the line with a newline."""
    self.n, self.i = self.n+n, self.i+1
    # throttle: draw only when closing or on every `skip`-th call
    if self.disable or (not close and self.i % self.skip != 0): return
    prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
    # when calls arrive faster than `rate` per second, raise `skip` so that redraws stay near `rate` per second
    if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
    # seconds -> "[H:]MM:SS": the hours field is omitted when 0; minutes and seconds are always zero-padded
    def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
    # number -> at most 4 characters plus an SI suffix (k, M, G, ...); 0 -> '0.00'
    def SI(x):
      return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
    # "n/total" when the total is known, otherwise "n<unit>"; SI-scaled when unit_scale
    prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
    # remaining time = projected total time (elapsed/prog) minus elapsed; "?" before any progress
    est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
    # throughput in units per second; "?" before any progress
    it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
    suf = f'{prog_text} [{HMS(elapsed)}{est_text}, {it_text}{self.unit}/s]'
    # columns left for the bar after the description, percentage, separators and suffix (at least 1)
    sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
    # full blocks for whole cells, plus one eighth-block glyph for the fractional cell, padded with spaces to sz
    bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
    # '\r' redraws in place; the slice keeps the line within the terminal width (+1 for the leading '\r')
    print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
  @classmethod
  def write(cls, s:str): print(f"\r\033[K{s}", flush=True, file=sys.stderr)  # print a line, clearing the bar's line first

class trange(tqdm):
  """tqdm over range(n), with the total set to n."""
  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
|
|
|
# *** universal support for code object pickling
|
|
|
|
def _reconstruct_code(*args): return types.CodeType(*args)
|
|
def _serialize_code(code:types.CodeType):
|
|
args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up
|
|
return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
|
|
copyreg.pickle(types.CodeType, _serialize_code)
|
|
|
|
def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,)
|
|
copyreg.pickle(types.ModuleType, _serialize_module)
|
|
|