from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
from tinygrad.helpers import strip_parens
import math


def sign_extend(val:UOp, sext_am:int):
    """Build the UOp expression that widens a `sext_am`-bit value to a signed int.

    When `val >> (sext_am - 1)` is greater than zero, the bits from `sext_am`
    upwards are filled with ones (0xffffffff << sext_am); otherwise nothing is
    added. The fill is OR-ed with `val` bitcast to uint32 and the result is
    bitcast to `dtypes.int`.
    """
    high_fill = UOp.where((val >> (sext_am - 1)) > 0,
                          UOp.const(dtypes.uint32, 0xffffffff) << sext_am,
                          UOp.const(dtypes.uint32, 0))
    return (high_fill | val.bitcast(dtypes.uint32)).bitcast(dtypes.int)


# store for char: buf[idx/4] <- (var << (idx%4)*8))
def packed_store(bidx:UOp, var:UOp):
    """Rewrite a store of a 1- or 2-byte value as a masked update of a uint32 word.

    `bidx` is read as an index node whose `src[0]` is the buffer and `src[1]`
    the element index. The word at `index // elements_per_word` is loaded as
    uint32, the target lane is cleared with a mask, and the shifted value is
    OR-ed in before the word is stored back.
    """
    elem_bytes = var.dtype.itemsize
    per_word = 4 // elem_bytes
    lane_bits = 0xFF if elem_bytes == 1 else 0xFFFF
    # bit offset of the element inside its 32-bit word
    shift_am = (bidx.src[1].cast(dtypes.uint32) % UOp.const(dtypes.uint32, per_word)) \
        * UOp.const(dtypes.uint32, 8 * elem_bytes)
    new_v = (var & lane_bits).cast(dtypes.uint32) << shift_am
    # every bit set except the lane being written
    mask = ((lane_bits << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
    # the same word address is used for both the read and the write
    word = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1] // per_word))
    buf = UOp.load(word, dtype=dtypes.uint32)
    return UOp.store(word, (buf & mask) | new_v.cast(dtypes.uint32))


# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
    """Rewrite a load of a 1- or 2-byte `dtype` as a uint32 load plus shift and mask.

    `root` is the load being replaced; its `arg` is carried over to the new
    load. When `var` is given, the new index node also carries `bidx.src[2]`
    and `var` is passed as an extra load source; otherwise the remaining
    sources of `root` are forwarded unchanged. Results for `dtypes.char` and
    `dtypes.short` go through `sign_extend` before the final cast.
    """
    elem_bytes = dtype.itemsize
    div_idx = bidx.src[1] // (4 // elem_bytes)
    shift_am = (bidx.src[1].cast(dtypes.uint32) % UOp.const(dtypes.uint32, 4 // elem_bytes)) \
        * UOp.const(dtypes.uint32, 8 * elem_bytes)
    if var is None:
        word_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
        load = UOp.load(word_idx, *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
    else:
        word_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]))
        load = UOp.load(word_idx, var, dtype=dtypes.uint32, arg=root.arg)
    lane_bits = 0xFF if elem_bytes == 1 else 0xFFFF
    val = (load.cast(dtypes.uint32) >> shift_am) & lane_bits
    if dtype in [dtypes.char, dtypes.short]:
        return sign_extend(val, 8 * elem_bytes).cast(dtype)
    return val.cast(dtype)


def is_packed(dt:DType) -> bool:
    """Return True for dtypes narrower than 4 bytes whose base is not half."""
    if dt.itemsize >= 4:
        return False
    return dt.base != dtypes.half
# NOTE(review): everything from wgsl_matcher through render_kernel is collapsed onto three physical lines, so each
# inline `#` comment swallows the code that follows it. Text also appears to be missing — TODO restore from the
# upstream file before changing any code here:
#   * the `<<` rule in wgsl_matcher stops after `a.bitcast(dtypes.uint32)<` and resumes inside an f-string that
#     belongs to a different (ctx-taking) rule; the end of wgsl_matcher is not visible;
#   * no `class` header or pattern-list opening is visible for the `self`/`ctx`-taking code that follows;
#   * several WGSL strings look like they lost a `<...>` argument (`bitcast>(`, `bitcast(bits)`, `vec3,`,
#     `"atomic"`, and `'var' if ... else 'var'` whose two branches are identical as shown) — presumably
#     extraction damage; confirm.
# What the visible wgsl_matcher rules do:
#   * CMPLT/XOR whose first operand is bool is recomputed on int casts of both operands and cast back to bool;
#   * a LOAD / STORE whose dtype satisfies is_packed is replaced by packed_load / packed_store, otherwise left alone;
#   * a MUL by where(g, c1, c2) becomes g.where(c1, a) when c1 is NaN and c2 == 1.0 (see the existing TODO).
wgsl_matcher = PatternMatcher([ (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"), lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)), (UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None), (UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))), lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None), (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None), # TODO: why is this needed, and only for this MUL order (UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))), lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None), (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"), (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)), lambda ctx,x: f"bitcast>({ctx[x.src[0]]})[0]"), (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"), (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2({ctx[x.src[0]]},0))" \ if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"), (UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"), (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)), 
# Rendering rules and helper methods visible on the remaining lines:
#   * a store to a packed buffer renders as atomicAnd with the mask followed by atomicAdd with the value (operands are
#     picked out of the `(load & mask) | var` expression built by packed_store); any other store is `dst = value;`;
#   * INDEX renders as `buf[idx]`, calling strip_parens on the index when `idx.arg is Ops.ADD`
#     (NOTE(review): this tests `arg`, not `op` — possibly meant idx.op; confirm);
#   * `a != a` renders as a min/max comparison against 1.0 / -1.0 (existing comment: "fix nan check");
#   * render_cast -> "<type_map[dt]>(val)"; render_dtype always returns "var" and ignores both arguments;
#   * render_load wraps the expression in atomicLoad(&...) for packed dtypes; buf_map returns "atomic" for packed
#     dtypes and type_map[dt.base] otherwise;
#   * render_kernel takes the workgroup size from SPECIAL uops whose arg name starts with 'l' (falling back to [1]),
#     rewrites `kernel` in place so that lines containing "var" are hoisted ahead of the buffer bindings, prepends
#     "enable f16;" when any uop has a half base dtype, numbers buffer bindings from 1 (binding 0 is INFINITY), and
#     returns the assembled program text. The `prefix` parameter is not used in the visible code.
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\ # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1] f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \ else f"{ctx[b]} = {ctx[v]};"), (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"), # fix nan check: 'a != a -> is_nan()' (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"), ]) + base_rewrite def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})" def render_dtype(self, dt:DType, mutable=True) -> str: return "var" def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x def buf_map(self, dt:DType) -> str: return "atomic" if is_packed(dt) else self.type_map[dt.base] def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str: local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])] if not local_size: local_size = [1] bind_it = iter(range(len(bufs))) external_local_bufs = [line.lstrip() for line in kernel if "var" in line] kernel[:] = [line for line in kernel if "var" not in line] prg = "enable f16;\n" if any(uop.dtype.base == dtypes.half for uop in uops) else "" prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast(bits); }\n" prg += "@group(0) @binding(0)\nvar INFINITY : f32;\n" prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + f"{'var' if isinstance(dtype, PtrDType) else 'var'}" + f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs]) prg += f"\n@compute @workgroup_size({','.join([str(x) 
for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3," return prg + "@builtin(local_invocation_id) lindex: vec3) {\n" + "\n".join(kernel) + "\n}"