You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
7.3 KiB
163 lines
7.3 KiB
from typing import Any, Callable
|
|
import itertools, inspect, functools, types
|
|
from tinygrad.helpers import partition, dedup, Context
|
|
from tinygrad.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
|
|
|
|
class UPatCompileError(Exception):
  """Raised when a pattern cannot be turned into compiled match code."""
|
|
|
|
# **** UPat compiled ****
|
|
|
|
def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
  """Build the acceptance condition for pattern `self` applied to `base` as a tree of UOps.

  The returned UOp is always an Ops.AND. Its sources are Ops.CUSTOM checks (format-string templates
  in `arg`, operands in `src`), Ops.ASSIGN nodes that bind a pattern name to `base`, Ops.RANGE nodes
  for "every source matches" checks, and Ops.OR nodes for alternatives.

  `base` is the UOp standing for the expression under test; `depth` only feeds the name of the
  iteration variable ("ituop<depth>") so nested repeat matches use distinct names.
  Raises RuntimeError("broken") if `self.src` has a shape none of the three cases below handle.
  """
  # a UPatAny accepts when any of its alternatives accepts: AND wrapping a single OR
  if isinstance(self, UPatAny):
    assert len(self.src) == 1
    alternatives = tuple(_get_clause(alt, base, depth) for alt in self.src[0])
    return UOp(Ops.AND, src=(UOp(Ops.OR, src=alternatives),))

  conds:list[UOp] = []

  # op: membership test against a bound tuple for several ops, inlined integer compare for one
  if self.op is not None:
    if len(self.op) > 1:
      allowed_ops = UOp(Ops.BIND, arg=tuple(int(o) for o in self.op))
      conds.append(UOp(Ops.CUSTOM, src=(base, allowed_ops), arg="{0}.op in {1}"))
    else:
      conds.append(UOp(Ops.CUSTOM, src=(base,), arg="{0}.op == "+str(self.op[0].value)))

  # arg: ints are inlined into the template, anything else is compared against a bound value
  if self.arg is not None:
    if isinstance(self.arg, int):
      conds.append(UOp(Ops.CUSTOM, src=(base,), arg="{0}.arg == "+str(int(self.arg))))
    else:
      conds.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.arg)), arg="{0}.arg == {1}"))

  # number of sources: exact when strict_length is set, otherwise a lower bound
  if self.strict_length or self.required_len > 0:
    comparator = " == " if self.strict_length else " >= "
    conds.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+comparator+str(self.required_len))))

  # a named pattern binds `base` to that name
  if self.name is not None:
    conds.append(UOp(Ops.ASSIGN, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))

  # dtype: either the dtype itself or its _scalar may match
  if self.dtype is not None:
    if len(self.dtype) > 1:
      allowed_dtypes = UOp(Ops.BIND, arg=tuple(self.dtype))
      conds.append(UOp(Ops.CUSTOM, src=(base, allowed_dtypes), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
    else:
      only_dtype = UOp(Ops.BIND, arg=self.dtype[0])
      conds.append(UOp(Ops.CUSTOM, src=(base, only_dtype), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))

  if self.src is not None:
    if len(self.src) == 1 and isinstance(self.src[0], tuple):
      # single match: the i-th sub-pattern is checked against the i-th source of base
      for idx, sub in enumerate(self.src[0]):
        conds.append(_get_clause(sub, base.gep(idx), depth))
    elif len(self.src) == 1 and isinstance(self.src[0], itertools.repeat):
      # repeat match: one sub-pattern must hold for every source of base
      loop_var = UOp(Ops.NOOP, arg=f"ituop{depth}")
      per_item = _get_clause(next(self.src[0]), loop_var, depth+1)
      conds.append(UOp(Ops.RANGE, src=(per_item, loop_var, base), arg="all([{0} for {1} in {2}.src])"))
    elif len(self.src) > 1 and all(isinstance(cand, tuple) for cand in self.src):
      # multi match (fork): any one of the candidate source orderings may match
      branches = []
      for ordering in self.src:
        branches.append(UOp(Ops.AND, src=tuple([_get_clause(sub, base.gep(idx), depth) for idx, sub in enumerate(ordering)])))
      conds.append(UOp(Ops.OR, src=tuple(branches)))
    else:
      raise RuntimeError("broken")

  return UOp(Ops.AND, src=tuple(conds))
|
|
|
|
# *** pattern matcher ***
|
|
|
|
def do_process_and(a:UOp) -> UOp|None:
  """Normalize one Ops.AND node of a clause tree; return the rewritten node or None if nothing changed.

  Steps, in order: flatten directly nested ANDs, merge several OR children into a single OR over the
  product of their branches, move ASSIGN children into the OR branches when an OR is present
  (otherwise turn repeated assignments to the same target into a CMPNE compare), then dedup.
  Raises UPatCompileError when the AND has four or more OR children.
  """
  changed = False
  terms:list[UOp] = []
  ors:list[UOp] = []

  # flatten nested ANDs and set the OR clauses aside
  for child in a.src:
    if child.op is Ops.AND:
      terms.extend(child.src)
      changed = True
    elif child.op is Ops.OR:
      ors.append(child)
    else:
      terms.append(child)

  # too big to compile
  if len(ors) >= 4: raise UPatCompileError("too big to compile")

  # keep at most one OR clause: several ORs become one OR over the product of their branches
  if len(ors) > 1:
    combos = itertools.product(*[o.src for o in ors])
    ors = [UOp(Ops.OR, src=tuple([UOp(Ops.AND, src=combo) for combo in combos]))]
    changed = True

  # pull the assigns out of the plain terms
  assigns, terms = partition(terms, lambda t: t.op is Ops.ASSIGN)
  if assigns and ors:
    # with an OR present, every branch of the OR receives the assigns
    assert len(ors) == 1 and all(branch.op is Ops.AND for branch in ors[0].src)
    extra = tuple(assigns)
    ors = [UOp(Ops.OR, src=tuple([branch.replace(src=branch.src+extra) for branch in ors[0].src]))]
    changed = True
  elif assigns:
    # the first assign to a target is kept; a later assign to the same target becomes a compare
    first_value: dict[UOp, UOp] = {}
    for asg in assigns:
      target = asg.src[0]
      if target in first_value:
        terms.append(UOp(Ops.CMPNE, src=(first_value[target], asg.src[1])))
        changed = True
      else:
        first_value[target] = asg.src[1]
    # put the surviving assigns back after the other terms
    terms.extend(UOp(Ops.ASSIGN, src=(target, value)) for target, value in first_value.items())

  # reassemble; dropping a duplicate also counts as a change
  merged = dedup(terms+ors)
  if len(merged) != len(terms)+len(ors): changed = True
  if not changed: return None
  return UOp(Ops.AND, src=tuple(merged))
|
|
|
|
# processor: applies do_process_and to every Ops.AND node (bound as "a")
pm_proc = PatternMatcher(
  [(UPat(Ops.AND, name="a"), do_process_and)],
  compiled=False,
)
|
|
|
|
# renderer
|
|
def wrap(ctx, x) -> UOp:
  """Store `x.arg` in the `ctx` dict under a fresh key "a<len(ctx)>" and return a NOOP whose arg is that key."""
  value = x.arg
  key = f"a{len(ctx)}"
  ctx[key] = value
  return UOp(Ops.NOOP, arg=key)
|
|
|
|
pm_renderer = PatternMatcher([
  # a BIND is replaced by a NOOP naming the slot in ctx that holds its value
  (UPat(Ops.BIND, name="x"), wrap),

  # CMPNE is actually equal: it renders as an identity check between its two sources
  (UPat(Ops.CMPNE, name="x"), lambda x: UOp(Ops.CUSTOM, src=x.src, arg="{0} is {1}")),

  # RANGE can't have OR inside it: its AND condition is joined into one parenthesized
  # expression, and the RANGE itself becomes a CUSTOM template
  (UPat(Ops.RANGE, src=(UPat(Ops.AND, src=UPat(Ops.NOOP), name="x"), UPat(), UPat()), name="r"),
   lambda r,x: r.replace(op=Ops.CUSTOM, src=(UOp(Ops.NOOP, arg=f"({' and '.join(y.arg for y in x.src)})"),)+r.src[1:])),

  # a CUSTOM over NOOP sources is formatted into a single NOOP string
  (UPat(Ops.CUSTOM, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg.format(*(y.arg for y in x.src)))),

  # a GEP of a NOOP becomes an index into that expression's .src
  (UPat(Ops.GEP, src=UPat(Ops.NOOP, name="x"), name="g"), lambda x,g: x.replace(arg=x.arg+f".src[{g.arg[0]}]")),
], compiled=False)
|
|
|
|
def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
  """Turn an Ops.AND clause into lines of Python source, indented by `depth` spaces.

  NOOP children contribute condition strings, ASSIGN children contribute keyword arguments for the
  `_fxn` call, and a single OR child contributes nested blocks rendered one level deeper.
  When `has_ctx` is true the `_fxn` call also receives `ctx=ctx`.
  Raises UPatCompileError for any other kind of child.
  """
  assert x.op is Ops.AND
  conditions: list[str] = []
  kwargs: list[str] = []
  nested: list[str] = []
  for term in x.src:
    if term.op is Ops.OR:
      # only one non-empty OR per AND is supported
      assert len(nested) == 0 and len(term.src) >= 1
      for branch in term.src:
        nested += _final_render(branch, has_ctx, depth+1)
    elif term.op is Ops.ASSIGN:
      assert term.src[0].op is Ops.DEFINE_VAR and term.src[1].op is Ops.NOOP
      kwargs.append(f"{term.src[0].arg}={term.src[1].arg}")
    elif term.op is Ops.NOOP:
      conditions.append(term.arg)
    else:
      raise UPatCompileError(f"can't compile this {term}")

  indent = ' '*depth

  # with an OR, this level is only a guard around the nested branches
  if nested:
    assert len(kwargs) == 0
    guard = ' and '.join(conditions)
    return [f"{indent}if {guard or 'True'}:"] + nested

  # without an OR, this level calls the function and returns its result when it is not None
  call_args = ', '.join((["ctx=ctx"] if has_ctx else []) + kwargs)
  conditions.append(f"(_ret:=_fxn({call_args})) is not None")
  return [f"{indent}if {' and '.join(conditions)}: return _ret"]
|
|
|
|
def _get_code(self:UPat, has_ctx:bool):
  """Generate the source of `compiled_match(uop, ctx)` for this pattern.

  Returns a `(source, dyn_lookup)` pair, where `dyn_lookup` maps the names generated during
  rendering to the values they stand for, or None when a UPatCompileError is raised along the way.
  """
  clause = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
  dyn_lookup: dict[str, Any] = {}
  try:
    # TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
    with Context(TRACK_MATCH_STATS=0):
      clause = graph_rewrite(clause, pm_proc, name="process UPat")
      rendered_tree = graph_rewrite(clause, pm_renderer, ctx=dyn_lookup, name="compile UPat")
    body = _final_render(rendered_tree, has_ctx)
  except UPatCompileError:
    # this pattern can't be compiled; the caller receives None
    return None
  header = [f"# match for {self.location}", "def compiled_match(uop, ctx):"]
  footer = [" return None"]
  return '\n'.join(header + body + footer), dyn_lookup
|
|
|
|
@functools.cache
def upat_compile(self:UPat, fxn) -> Callable|None:
  """Compile pattern `self` with callback `fxn` into a `compiled_match(uop, ctx)` function.

  Returns None when no code could be generated for the pattern. Results are memoized per
  (pattern, fxn) pair by functools.cache.
  """
  real_fxn = types.FunctionType(*deconstruct_function(fxn))
  # the callback only receives ctx if its signature has a parameter named 'ctx'
  wants_ctx = 'ctx' in inspect.signature(real_fxn).parameters
  generated = _get_code(self, wants_ctx)
  if generated is None: return None
  source, dyn_lookup = generated
  # globals for the generated code: the bound values plus the callback as _fxn
  globs = {**dyn_lookup, "_fxn": real_fxn}
  namespace: dict = {}
  exec(source, globs, namespace) # pylint: disable=W0122
  return namespace["compiled_match"]
|
|
|