You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
334 lines
19 KiB
334 lines
19 KiB
from __future__ import annotations
|
|
import math, itertools
|
|
from collections import defaultdict
|
|
from typing import cast, Final
|
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
|
|
from tinygrad.device import Buffer
|
|
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
|
|
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod
|
|
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
|
from tinygrad.codegen.simplify import pm_flatten_range
|
|
from tinygrad.renderer import Renderer
|
|
|
|
def _drop_tag(x:UOp):
  # untagged UOps yield None (no replacement); tagged ones are replaced by an identical UOp with tag=None
  if x.tag is None: return None
  return x.replace(tag=None)

# matches every UOp (GroupOp.All) and clears its tag
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), _drop_tag)])

# Sort priority of each AxisType: Scheduler.rngs orders ranges by this value first, then by the rest of the range arg.
# NOTE: LOCAL and GROUP_REDUCE have the same priority (2). the order here matters
axis_to_pos = {
  AxisType.LOOP: -1,
  AxisType.THREAD: 0,
  AxisType.GLOBAL: 0,
  AxisType.WARP: 1,
  AxisType.LOCAL: 2,
  AxisType.UPCAST: 3,
  AxisType.GROUP_REDUCE: 2,
  AxisType.REDUCE: 4,
  AxisType.UNROLL: 5,
}
|
|
|
|
class Scheduler:
  """Apply optimization opts (`Opt`) to a kernel AST by rewriting the RANGE UOps it contains.

  Nothing shape-like is cached: `rngs`, `full_shape`, `axis_types`, ... are recomputed from the RANGE UOps reachable
  from `self.ast`, and every opt is applied by substituting new RANGE expressions into `self.ast`.
  """
  def __init__(self, ast:UOp, opts:Renderer):
    self.ast, self.opts = ast, opts
    # resume from the KernelInfo attached to the AST (if any) so previously applied opts are kept
    self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
    self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []

  @property
  def rngs(self):
    # always in order by axistype (axis_to_pos of the last arg element), then by the remaining arg components.
    # ranges with vmax == 0 (a single iteration) are not treated as axes.
    return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
  @property
  def shape_len(self): return len(self.rngs)
  @property
  def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
  @property
  def axis_types(self): return [x.arg[-1] for x in self.rngs]
  @property
  def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)

  # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
  def shape_str(self) -> list[str]:
    """Name each axis as its axis letter plus a per-AxisType counter (in `rngs` order)."""
    ret: list[str] = []
    cnt: dict[AxisType, int] = {}
    for x in self.axis_types:
      cnt[x] = (cnt[x] + 1) if x in cnt else 0
      ret.append(f"{axis_letters[x]}{cnt[x]}")
    return ret
  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]:
    """Map axis names produced by `shape_str` back to axis indices. Raises ValueError for unknown names."""
    # compute the names once instead of re-sorting the ranges for every lookup
    names = self.shape_str()
    return tuple([names.index(x) for x in nms])

  def copy(self):
    """Return a new Scheduler over the same AST/renderer with its own copy of the applied opts list."""
    ret = Scheduler(self.ast, self.opts)
    ret.dont_use_locals = self.dont_use_locals
    ret.applied_opts = self.applied_opts[:]
    return ret

  # per-function-name counter shared by all Schedulers; repeated generated names get an "n<k>" suffix
  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  def get_optimized_ast(self, name_override:str|None=None):
    """Flatten ranges and return the AST with a KernelInfo (name, applied opts, dont_use_locals) and tag=1.

    When `name_override` is None a name is generated from the kernel type ("r" with a reduce, "E" otherwise) and the
    range sizes. NOTE: this also replaces `self.ast` with the range-flattened AST.
    """
    if name_override is not None: name = name_override
    else:
      kernel_type = "r" if self.reduceop is not None else "E"
      name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
      Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
      num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
      name += colored(num, 'BLACK')
    self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
    return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)

  def _globalizable_rngs(self) -> list[UOp]:
    """Return the LOOP ranges that may be turned into GLOBAL/THREAD axes.

    Candidates are the ranges reachable from `self.ast.src[0].src[2:]` (presumably the output store's ranges — TODO
    confirm), restricted to those that also appear in every LOCAL-address-space STORE or BUFFERIZE.
    """
    store_rngs = self.ast.src[0].src[2:]

    # filter any not in local stores
    local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
                        or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
    for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])

    # the AxisType is the last element of a range arg (as everywhere else in this class)
    return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []

  def convert_loop_to_global(self):
    """Retag every globalizable LOOP range as GLOBAL. Does nothing when the renderer has no local memory."""
    if not self.opts.has_local: return None

    globalizable_rngs = self._globalizable_rngs()
    # keep every leading arg component and only swap the trailing AxisType (as PADTO/SWAP do)
    rng = [x.replace(arg=(*x.arg[:-1], AxisType.GLOBAL)) if x in globalizable_rngs else x for x in self.rngs]

    self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))

  def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or x != AxisType.GLOBAL else "BLUE" for x in self.axis_types]
  def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])

  def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
    """Split `rng` into a range of size N/amount and a new range of size `amount` (or `input_new_rng` if given).

    With top=False the new range is the inner (low) part, `replaced * amount + new`; with top=True it is the outer part,
    `new * (N/amount) + replaced`. Returns (replaced_rng, new_rng). Raises KernelOptError if `amount` does not divide N.
    """
    if (old_sz:=rng.src[0].divides(amount)) is None:
      raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
    new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
    replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
    sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
    self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}")
    return replaced_rng, new_rng

  def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]

  # copied from kernel.py
  @property
  def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]
  @property
  def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]

  def real_axis(self, op:OptOps, axis:int|None):
    """Translate an Opt axis into an index into `rngs`; -1 means "no axis" (axis None or TC).

    UNROLL indexes into `unrollable_dims` and GROUP/GROUPTOP into the REDUCE axes. Out-of-range axes raise KernelOptError.
    """
    try:
      if axis is None or op is OptOps.TC: return -1
      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
      check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
      return axis
    except IndexError as e: raise KernelOptError from e

  def apply_opt(self, opt:Opt, append_opt:bool=True):
    """Apply a single Opt to `self.ast`, raising KernelOptError when it is not legal here.

    Returns the (replaced_rng, new_rng) pair for axis-splitting opts, the tensor-core axes for TC, otherwise None.
    When `append_opt` is True the opt is recorded in `applied_opts` after it succeeds.
    """
    if opt.op is OptOps.NOLOCALS:
      check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
      if append_opt: self.applied_opts.append(opt)
      self.dont_use_locals = True
      return

    if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
      check(self.opts.has_local, "locals needed for opt")

    rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)

    opt_to_at = {
      OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
      OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
      OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}

    ret = None
    if opt.op in opt_to_at:
      # arg == 0 means "the whole range"
      amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)

      # copied from kernel.py. prevents METAL compiler hangs
      if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
                                        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
        upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
        smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
        check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

      if opt.op is OptOps.UNROLL:
        check(amt <= 32, "don't unroll more than 32")
        check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
      if opt.op is OptOps.UPCAST:
        check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
      if opt.op is OptOps.LOCAL:
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
      if opt.op is OptOps.THREAD:
        check(self.opts is not None and self.opts.has_threads, "target does not support threads")
        check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
        check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
        check(rng in self._globalizable_rngs(), "can't apply range to this dim")
      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
        check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores")  # TODO: why is this wrong?
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
    elif opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first")  # TODO: remove the need for this by having warps
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
      except ValueError as e: raise KernelOptError(str(e)) from e
      check(ret is not None, "no tensor core available")
    elif opt.op is OptOps.PADTO:
      check(rng.src[0].op is Ops.CONST, "only pad const axes")
      check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted")  # TODO: why is this wrong?
      check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
        check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
      new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
      check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
      replaced_rng = UOp.range(new_sz, *rng.arg)
      replaces = {rng:replaced_rng}
      # indexes that use the padded range are masked to invalid past the original size
      valid = replaced_rng < rng.vmax+1
      for b in self.bufs:
        if rng in (i:=b.src[1].get_idx()).sparents:
          replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
      self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
    elif opt.op is OptOps.SWAP:
      rngs = self.rngs
      # a missing or non-integer axis is an illegal opt, not a crash
      try: altrng = rngs[opt.arg]
      except (IndexError, TypeError) as e: raise KernelOptError(f"invalid swap axis {opt.arg}") from e
      check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
      # tag the swapped ranges so the substitution doesn't re-match them, then clear the tags
      self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                      altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)})
      self.ast = graph_rewrite(self.ast, remove_tags)
    else:
      raise KernelOptError(f"unsupported opt {opt.op}")

    if append_opt: self.applied_opts.append(opt)
    return ret

  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
    """Try to lower the first REDUCE(ADD) over a MUL (optionally under a CAST) onto a tensor core.

    `axis` picks one (in1 range, in0 range, reduce range) combination; `tc_select` == -1 tries every tensor core of the
    renderer. Returns the final axes on success and None when no tensor core matches. With use_tensor_cores == 2 only the
    range splitting is done and no WMMA is built.
    NOTE(review): if PADTO succeeds on an earlier axis and fails on a later one, self.ast stays padded while the next
    tensor core is tried.
    """
    reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
    if not reduceops: raise KernelOptError("no reduce ops for TensorCore")
    reduceop = reduceops[0]
    if use_tensor_cores and reduceop.arg is Ops.ADD:
      mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
      if mul.op is not Ops.MUL: return None
      in0, in1 = mul.src
      try:
        tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
      except IndexError as e:
        raise KernelOptError(f"invalid tensor core choice {tc_select}") from e
      for tc in tensor_cores:
        if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
          # tensor cores have three ranges. X, Y, and REDUCE
          in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
          in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
          red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
          if DEBUG >= 3:
            print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
                  f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
          if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue

          # pick ranges
          # NOTE: why are in1 and in0 switched?
          axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
          if not (axis < len(axis_choices)): continue
          axes = list(axis_choices[axis])

          # do optimizations and save the ranges
          try:
            for i,a in enumerate(axes):
              idx = self.rngs.index(a)
              if (a.vmax+1) % tc.dims[i] != 0:
                if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
                # apply_opt should return the updated range?
                self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False)  # PADTO might fail
                axes[i] = self.rngs[idx]
          except KernelOptError: continue

          # we create the warp as a whole thing, in case some of these ranges are moved/removed later
          warp = UOp.range(tc.threads, -1, AxisType.WARP)
          ne: list[UOp] = []
          for opt in tc.opts:
            if opt[0] == "l":
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
              warp //= 2
            elif opt[0] == "u":
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
            else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
            ne.append(new_range)

          for _, amt in tc.get_reduce_axes():
            axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
            ne.append(new_range)

          if use_tensor_cores != 2:
            # fix the srcs
            reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
            tne = [x.replace(tag=1) for x in ne]
            ret = reduceop.substitute(dict(zip(ne, tne)))
            srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
            srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]

            # get reduce/upcast axes for the tensor cores
            tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
            base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
            tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])

            # axes to range number (was done in lowerer)
            tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
            tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])

            # construct the op
            # TODO: remove tc_upcast_axes from the arg
            # do the reduce_axes always disappear? i think they don't
            # they need to be moved into the WMMA srcs
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ())  #, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)

            # preserve extra reduces
            reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
            if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
            self.ast = self.ast.substitute({reduceop: tc_uop})
          return axes
    return None

  # helpers for hand_coded_optimizations
  @property
  def reduceop(self) -> UOp|None:
    # a freshly built REDUCE_AXIS (arg=(reduce op, ())) mirroring the first REDUCE in self.ast.parents, or None
    red = [x for x in self.ast.parents if x.op is Ops.REDUCE]
    if not len(red): return None
    return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
  @property
  def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]
  @property
  def output_shape(self):
    # full_shape with the reduce-like axes (REDUCE, UNROLL, GROUP_REDUCE) collapsed to 1
    return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
  @property
  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
  @property
  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
|
|
|
|
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
  """Create one Buffer on device `dname` per DEFINE_GLOBAL in `ast`, ordered by the DEFINE_GLOBAL arg.

  Each Buffer gets the pointer's size; ImageDTypes are kept as-is, other dtypes are reduced to their `.base`.
  """
  global_defs = sorted((u for u in ast.parents if u.op is Ops.DEFINE_GLOBAL), key=lambda u: u.arg)
  out: list[Buffer] = []
  for g in global_defs:
    buf_dtype = g.dtype if isinstance(g.dtype, ImageDType) else g.dtype.base
    out.append(Buffer(dname, g.ptrdtype.size, buf_dtype))
  return out
|
|
|
|
def apply_opts(ctx:Renderer, ast:UOp):
  """Optimize a kernel SINK and return it with KernelInfo attached, or None if `ast` is already tagged.

  The opts come from, in priority order: `ast.arg.opts_to_apply`, BEAM search (BEAM >= 1), or the hand-coded
  heuristics (unless NOOPT, or opts were already applied).
  """
  # get_optimized_ast returns its AST with tag=1, so a tagged SINK has already been processed
  if ast.tag is not None: return None
  sched = Scheduler(ast, ctx)
  sched.convert_loop_to_global()
  info = ast.arg
  if info is not None and info.opts_to_apply is not None:
    for opt in info.opts_to_apply: sched.apply_opt(opt)
  elif BEAM >= 1:
    from tinygrad.codegen.opt.search import beam_search
    sched = beam_search(sched, bufs_from_ast(ast, ctx.device), BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
  elif not NOOPT and (info is None or info.applied_opts == ()):
    from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
    # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
    single_src_loads = all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD)
    if single_src_loads: sched = hand_coded_optimizations(sched)
  # keep the name from KernelInfo unless it is the placeholder "test"
  keep_name = info is not None and info.name != "test"
  return sched.get_optimized_ast(name_override=info.name if keep_name else None)
|
|
|
|
# single rule: hand every SINK to apply_opts (which returns None for SINKs that are already tagged)
_postrange_rules = [(UPat(Ops.SINK, name="ast"), apply_opts)]
pm_postrange_opt = PatternMatcher(_postrange_rules)
|
|
|