from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.features.image import fix_schedule_for_images


def run_schedule(schedule: List[ScheduleItem], disable_logging=False):
    """Run every item of ``schedule`` in order, realizing each item's output.

    Items are popped from the front of the schedule list one at a time.
    LoadOps items go through ``LOAD_OPS_DISPATCHER``; every other item is
    executed with the output device's ``exec_ast``.  After an item runs, the
    ``op`` attribute of its output (and of each of the output's views) is
    deleted, and the realized buffer's type and dtype are asserted.

    Args:
        schedule: the schedule items to run.
        disable_logging: when true, ``log_schedule_item`` is not called.
    """
    # HACK: images can be not usable due to shape
    if IMAGE >= 2:
        schedule = fix_schedule_for_images(schedule)

    # NOTE: pop items rather than for-looping over the schedule;
    # a for loop is slow because nothing frees
    while schedule:
        si = schedule.pop(0)
        if not disable_logging:
            log_schedule_item(si)
        assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
        if DEBUG >= 3:
            print_tree(si.ast)

        if si.ast.op in LoadOps:
            # confirm the LoadOps are contiguous and in order
            for i, s in enumerate(si.ast.src):
                assert (
                    isinstance(s, LazyOp)
                    and s.op == BufferOps.MEM
                    and s.arg.idx == i+1
                    and s.arg.st.contiguous
                ), f"bad LoadOps src {i}: {s}"
            realize_fn = LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)]
            realize_fn(si.out, *si.inputs)
        else:
            runner = Device[si.out.device]
            si.out.realized = runner.exec_ast(
                si.ast,
                output=si.out,
                inputs=si.inputs,
                var_vals=si.var_vals,
                **si.out._device_extra_args(),
            )

        # the output is realized now, so drop the op from it and from its views
        del si.out.op
        for v in si.out.views:
            del v.op

        assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
        assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"


# *** zero op LoadOps ***

def _realize_empty(buffer: LazyBuffer) -> None:
    """Realize ``buffer`` as a new device buffer of ``prod(buffer.shape)`` elements.

    The buffer is only constructed here; this function writes nothing into it.
    Symbolic (non-int) shapes are rejected by assertion.
    """
    assert all_int(buffer.shape), "does not support symbolic shape"
    if DEBUG >= 2:
        print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
    buffer_cls = Device[buffer.device].buffer
    buffer.realized = buffer_cls(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())


def _realize_rand(buffer: LazyBuffer) -> None:
    """Realize ``buffer`` with random values seeded by ``buffer.op.arg``.

    Values are drawn as float32 with numpy's ``default_rng(...).random`` and
    converted to ``buffer.dtype.np`` before being handed to the device
    buffer's ``fromCPU``.  Symbolic (non-int) shapes are rejected by assertion.
    """
    assert all_int(buffer.shape), "does not support symbolic shape"
    if DEBUG >= 2:
        print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
    rng = np.random.default_rng(buffer.op.arg)
    # resolve the device constructor before sampling, as the single-expression form does
    from_cpu = Device[buffer.device].buffer.fromCPU
    samples = rng.random(size=prod(buffer.shape), dtype=np.float32)
    buffer.realized = from_cpu(samples.astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())


# *** one op LoadOps ***

def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
    """Realize ``buffer`` by copying the realized contents of ``src`` to its device.

    Three paths, tried in order:
      1. ``src`` is a ``RawDiskBuffer`` and the target buffer class is a
         ``RawBufferMapped`` subclass: allocate the target and ``readinto`` it.
      2. both sides are ``RawBufferTransfer`` and env var ``P2P`` >= 1: use the
         target class's ``transfer``.
      3. otherwise: round-trip through ``toCPU`` / ``fromCPU``.

    Sizes must match and both shape trackers must be contiguous (asserted).
    """
    assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
    assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
    if DEBUG >= 2:
        print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")

    # every path below needs the target device's buffer class
    dst_cls = Device[buffer.device].buffer

    # TODO: make this generic
    if isinstance(src.realized, RawDiskBuffer) and issubclass(dst_cls, RawBufferMapped):
        assert all_int(buffer.shape), "does not support symbolic shape"
        buffer.realized = dst_cls(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
        mapped = cast(RawBufferMapped, buffer.realized)
        src.realized.readinto(mapped._buffer())
    elif (
        isinstance(src.realized, RawBufferTransfer)
        and issubclass(dst_cls, RawBufferTransfer)
        and getenv("P2P", 0) >= 1
    ):
        transfer_cls = cast(RawBufferTransfer, dst_cls)
        buffer.realized = transfer_cls.transfer(src.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
    else:
        # TODO: schedule this as FROM to go to CPU, and a FROM to go to device
        buffer.realized = dst_cls.fromCPU(src.realized.toCPU(), **buffer._device_extra_args())


# *** n op LoadOps ***

def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
    """Realize ``buffer`` by calling ``buffer.op.arg(buffer, *inputs)`` and storing its result."""
    if DEBUG >= 2:
        print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
    custom_fn = buffer.op.arg
    buffer.realized = custom_fn(buffer, *inputs)


# maps each LoadOps member to the function that realizes it
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
    LoadOps.EMPTY: _realize_empty,
    LoadOps.RAND: _realize_rand,
    LoadOps.FROM: _realize_from,
    LoadOps.CUSTOM: _realize_custom,
}