You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
18 lines
712 B
18 lines
712 B
from dataclasses import replace
|
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo
|
|
from tinygrad.helpers import colored
|
|
from tinygrad.codegen.opt.kernel import axis_colors
|
|
|
|
def rename_sink(s:UOp):
  """Build a kernel name from the RANGE parents of `s` and store it in `s.arg`.

  The sink is only renamed when it has no arg at all, or when its arg's
  `name` is still "test"; in every other case None is returned.

  The generated name is "k" followed, for each RANGE found in `s.parents`,
  by a BLACK-colored "_" and the rendered first source of that range,
  colored via `axis_colors` indexed by the last element of the range's arg.

  Returns a copy of `s` (via `s.replace`) whose arg is either a fresh
  `KernelInfo(name=...)` or the existing arg with only `name` replaced.
  """
  info = s.arg
  # an arg whose name is anything other than "test" is left alone
  if info is not None and info.name != "test":
    return None

  # gather the RANGE parents, ordered by their arg with the last element dropped
  ranges = [parent for parent in s.parents if parent.op is Ops.RANGE]
  ranges.sort(key=lambda r: r.arg[0:-1])

  # joining on a leading '' puts a separator in front of every piece
  separator = colored('_', 'BLACK')
  pieces = [colored(r.src[0].render(), axis_colors[r.arg[-1]]) for r in ranges]
  name = "k" + separator.join([''] + pieces)

  # keep every other field of an existing arg; only the name changes
  if info is None:
    new_info = KernelInfo(name=name)
  else:
    new_info = replace(info, name=name)
  return s.replace(arg=new_info)
|
|
|
|
# Single-rule matcher: any SINK UOp (bound to the name "s") is passed to rename_sink.
_sink_rename_rule = (UPat(Ops.SINK, name="s"), rename_sink)
pm_postrange_opt = PatternMatcher([_sink_rename_rule])
|
|
|