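# GPU code generator: lowers an ASTKernel (an elementwise/reduce LazyOp tree) into a single
# OpenCL-style kernel string and wraps it in an ASTRunner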
import math, itertools
from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
from tinygrad.codegen.ast import ASTKernel, Token, Types
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, Variable, render_python
from tinygrad.shape import ShapeTracker, View
from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup

# div is different in cl than python
render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"

VALIDHACKS = getenv("VALIDHACKS", 0)        # TODO: remove the need for this
NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0)  # this is needed as a switch for the tests to pass

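# GPULanguage holds the backend-specific syntax: kernel/buffer qualifiers, the shared memory
# prefix, the barrier statement, the global/local id expressions, and the float4 constructor
# (None when the backend can't vectorize). As a rough illustration (not defined in this file),
# an OpenCL backend might pass something like kernel_prefix="__kernel", buffer_prefix="__global ",
# smem_prefix="__local ", gid=[f"get_global_id({i})" for i in range(3)] and float4="(float4)".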
class GPULanguage(NamedTuple):
  kernel_prefix : str = ""
  buffer_prefix : str = ""
  buffer_suffix : str = ""
  smem_prefix : str = ""
  barrier : str = ""
  gid : List[str] = []
  lid : List[str] = []
  extra_args : List[str] = []
  float4 : Optional[str] = None

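# convert a flat, float4-granular index expression into (x, y) coordinates for an image2d_t
# access (each image row holds base_shape[1] float4 texels); the validhacks path refactors
# the expression when valid.min == 0, folding unfactored -base_shape[1] terms from idx back
# into idy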
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
  idy = (idxy//(4*base_shape[1]))
  if validhacks and valid.min == 0:
    idx = (idxy//4) + (idy*-base_shape[1])
    # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
    if isinstance(idx, SumNode):
      unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
      assert len(unfactored) <= 1
      idx = Variable.sum(idx_nodes)
      unfactored = (Variable.sum(unfactored) // base_shape[1])
      idy += unfactored
      # ugh really...handtuned garbage
      if idx.min >= (base_shape[1]*3)//4:
        idx -= base_shape[1]
        idy += 1
  else:
    idx = (idxy//4)%base_shape[1]
  #print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
  return idx, idy

class GPUCodegen(ASTKernel):
  lang : GPULanguage = GPULanguage()

  # for renaming
  kernel_cnt : Final[Dict[str, int]] = defaultdict(lambda: -1)
  kernel_name_cache : Final[Dict[str, str]] = {}

  code_for_op : Final[Dict[Op, str]] = {
    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)",
    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
    BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
    BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
    BinaryOps.MAX: "max(A,B)", ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"
  }
  start_for_op : Final[Dict[Op, str]] = {ReduceOps.SUM: "0.0f", ReduceOps.MAX: "-INFINITY"}

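  # collapse four scalar Tokens into a single FLOAT4 Token: reuse the vector variable when
  # the four inputs are its .x/.y/.z/.w lanes, otherwise emit a float4 constructor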
  def group_float4(self, grp:List[Token]) -> Token:
    if all(g.tok.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.tok.split(".")[0] for g in grp]): return Token(grp[0].tok.split(".")[0], Types.FLOAT4)
    else: return Token(f"{self.lang.float4}({','.join(g.tok for g in grp)})", Types.FLOAT4)

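  # render the stores of `value` into buffer buf_index: write_imagef for image buffers,
  # a float4-pointer store when the offsets merge, a plain indexed store otherwise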
  def store(self, buf_index:int, value:List[Token]) -> None:
    assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
    assert len(self.sts[buf_index].views) == 1, "store has more than one view"

    # all stores can merge, since they have one view and are valid
    should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4()

    to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
    did_store = set()
    for o,v in to_store.items():
      if o in did_store: continue
      idxy, valid = self.sts[buf_index].expr_idxs(o)
      assert valid.min == 1, "store must always be valid"
      if should_upcast:
        for j in range(4): did_store.add(o+j)
        v = self.group_float4([to_store[o+j] for j in range(4)])
      if self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
        assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
        idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid)
        self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
      elif v.typ == Types.FLOAT4:
        self.kernel.append(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
      else:
        self.kernel.append(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).render(render_cl)}] = {v.tok};\n")

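  # render the loads for every offset of buffer buf_index: constants are folded, reads are
  # merged into float4 / read_imagef where the indexing allows it, possibly-invalid accesses
  # are guarded with a ternary, and results are cached in self.loaded_keys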
  def load(self, buf_index:int, idx_override:Optional[str]=None) -> List[Token]:
    # constant folding
    const = None
    if self.bufs[buf_index] is not None and self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
      if buf_index != 0: self.bufs_to_delete.add(buf_index)
      val = self.bufs[buf_index]._backing[0]
      assert not math.isnan(val)
      const = Token(f"({val}f)", Types.FLOAT)
    should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4()
    tokens = []
    test_idy = []
    for o in self.buftokens[buf_index].offsets():
      key = f"val{mnum(buf_index)}_{mnum(o)}"
      if (buf_index, o) not in self.loaded_keys:
        idxy, valid = self.sts[buf_index].expr_idxs(o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, o)
        if should_upcast:
          float4_index = Variable("FLOAT4_INDEX", 0, 3)
          idxy_test, valid_test = self.sts[buf_index].expr_idxs(float4_index+o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, float4_index+o)
          can_merge = idxy_test == float4_index or (isinstance(idxy_test, SumNode) and any(x == float4_index for x in idxy_test.nodes))  # float4_index must be in there without a multiply
          can_merge = can_merge and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in valid_test.render()  # float4_index must not be in after divide or in valid (TODO: don't check render)
        if const is not None:
          ldr = const
        elif self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
          assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
          idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS)
          ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
          test_idy.append(idy.render(render_cl))
        elif should_upcast and can_merge:
          ldr = Token(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
        else:
          ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
        invalid = self.group_float4([Token("0.0f", Types.FLOAT)]*4) if ldr.typ == Types.FLOAT4 else Token("0.0f", Types.FLOAT)
        ldr = ldr if valid.min == 1 or (VALIDHACKS and hasattr(self.bufs[buf_index]._buf, "IMAGE")) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid)
        if const is not None:
          self.loaded_keys[(buf_index,o)] = ldr
        else:
          self.kernel.append(f"{ldr.decltype()} {key} = {ldr.tok};\n")
          if should_upcast and can_merge:
            for j in range(4):
              self.loaded_keys[(buf_index,o+j)] = Token(key+f'.{"xyzw"[j]}', Types.FLOAT)
          else:
            self.loaded_keys[(buf_index,o)] = Token(key, Types.FLOAT)
      tokens.append(self.loaded_keys[(buf_index,o)])
    assert not VALIDHACKS or all_same(test_idy), f"idy changed! {test_idy}"
    return tokens

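  # recursively render a LazyOp tree into Tokens by substituting A/B into the code_for_op
  # templates; ReduceOps fold into the accumulators only when do_reduce is set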
  def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]:
    if not isinstance(x, LazyOp): return self.load(self.bufs.index(x), "mid" if x is None else None)  # hack for local
    if isinstance(x.op, ReduceOps) and not do_reduce: return acc
    values : List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
    code = GPUCodegen.code_for_op[x.op]  # TODO: replace this with a function
    if len(values) == 2:
      assert len(values[0]) == len(values[1]) and values[0][0].typ == values[1][0].typ, f"values mismatch {values}"
      return [Token(code.replace("A", a.tok).replace("B", b.tok), a.typ) for a,b in zip(values[0], values[1])]
    else:
      return [Token(code.replace("A", a.tok), a.typ) for a in values[0]]

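  # image buffers can only be read/written as float4, so make sure their single stride-1
  # axis is upcast by 4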
  def required_optimizations(self, early_only=False):
    for buf_index,buf in enumerate(self.bufs):
      upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes]
      if (not early_only or buf in self.earlybufs) and hasattr(buf._buf, "IMAGE") and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))):
        axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1]
        assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
        self.shift_to(axes[0], 4)
        self.upcast()
        assert self.buftokens[buf_index].can_float4()

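  # heuristic shape surgery before codegen: force the image upcasts, optionally group the
  # reduce for a two-stage (local memory) reduction, float4-upcast non-reduce axes, and
  # unroll small reduce loops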
  def hand_coded_optimizations(self):
    # if there's images in the earlybufs, we have to make an axis the 4 loading one
    self.required_optimizations(early_only=True)

    # simplify (sets first_reduce)
    self.simplify_ones()

    # are we grouping? (requires local shape support)
    if len(self.lang.lid) and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
      # TODO: use 1024 if it's allowed in a smarter way
      for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
        if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
          self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce)
          self.group_for_reduce.append(sz)
          break

    # are we upcasting in mid reduce?
    if hasattr(self.bufs[0]._buf, "IMAGE") and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2:
      axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1]
      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
      self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce))  # insert at the end of the grouped axis
      self.group_for_reduce.append(4)

    # now do everything required
    self.required_optimizations()

    # simplify (sets first_reduce)
    self.simplify_ones()

    # use more opencl indexing if the output buffer is an image and we have room
    if hasattr(self.bufs[0]._buf, "IMAGE") and self.first_reduce+len(self.group_for_reduce) < 3:
      base_shape = self.bufs[0]._base_shape
      if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
        if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
        self.simplify_ones()

    # no more opt if we are grouping
    if self.group_for_reduce: return

    # **** below this line need to be optional and benchmarked ****

    # potentially do more upcasts of non reduce axes based on a heuristic
    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
      xb_choices = []
      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]):  # consider all the non reduce axes, and a 3 or 4 reduce
        # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
        if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.buftokens[buf_index].axis) for buf_index in range(len(self.sts))):
          xb_choices.append((sum(st.strides[axis]>0 for st in self.sts), sum(st.strides[axis] for st in self.sts), axis, upcast_amount))
      if len(xb_choices):
        xb_choices = sorted(xb_choices)
        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
        self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
        self.upcast()
        self.simplify_ones()
      else:
        break

    # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
    if self.first_reduce < self.shape_len and self.full_shape[-1] <= 5 and (max([x.size() for i,x in enumerate(self.buftokens) if self.bufs[i] in self.earlybufs]) <= 4 or not any(r for _,_,r in self.buftokens[self.full_buf_index].axis)):
      self.upcast()

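  # declare and initialize the accumulator registers for a reduce: float4 vectors when the
  # output can be upcast, scalar floats otherwise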
  def get_accumulators(self, name="acc") -> List[Token]:
    assert self.reduceop is not None, "no accumulators if you aren't reducing"
    should_upcast = self.lang.float4 and self.buftokens[0].can_float4()
    accumulators = [Token(f"{name}{i//4}.{'xyzw'[i%4]}" if should_upcast else f"{name}{i}", self.buftokens[0].typ) for i in self.buftokens[0].offsets()]
    if should_upcast:
      self.kernel += [f"float4 {tok} = {self.group_float4([Token(GPUCodegen.start_for_op[self.reduceop.op], Types.FLOAT)]*4).tok};\n" for tok in dedup([x.tok.split('.')[0] for x in accumulators])]
    else:
      self.kernel += [f"float {x.tok} = {GPUCodegen.start_for_op[self.reduceop.op]};\n" for x in accumulators]
    return accumulators

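  # codegen lowers self.ast into a single kernel string (index setup, loads, reduce
  # accumulation, optional second-stage local reduce, stores) and wraps it in an ASTRunner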
  # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
  # group_for_reduce will have to be better first
  def codegen(self) -> ASTRunner:
    self.process()
    if DEBUG >= 4: self.printbufs("old:", DEBUG>=5)

    self.hand_coded_optimizations()

    # fancy colored shape printer
    if DEBUG >= 3: print(self.colorshape(), end="")

    # add a local buffer for multistage reduce
    if len(self.group_for_reduce):
      self.bufs.append(None)
      # TODO: the strides of this can be controlled
      st = ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.buftokens[0].axis]))
      buftoken = Token("temp", Types.FLOAT, ptr=True)
      # manual upcast of the local
      for _,_,r in self.buftokens[0].axis[::-1]:
        buftoken.array(st.shape[-1], st.views[-1].strides[-1], r)
        st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset)
      self.sts.append(st)
      self.buftokens.append(buftoken)

    self.output_shape : Tuple[int, ...] = self.sts[0].shape[:self.first_reduce] + tuple(self.group_for_reduce)
    assert self.full_shape[:len(self.output_shape)] == self.output_shape, f"output shape mismatch : {self.full_shape[:len(self.output_shape)]} != {self.output_shape}"
    if DEBUG >= 4:
      print("output shape", self.output_shape)
      self.printbufs("new:", DEBUG>=5)

    self.bufs_to_delete : Set[int] = set()
    self.loaded_keys : Dict[Tuple[int,int], Token] = {}
    self.prekernel : Set[str] = set()
    self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []

    if len(self.lang.gid) == 0:
      self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]
    else:
      # output_shape[-1] is get_global_id(0)
      self.kernel += [f"int idx{len(self.output_shape)-1-i} = {self.lang.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(len(self.lang.gid), len(self.output_shape))) if self.output_shape[-1-i] != 1]
      if len(self.output_shape) > len(self.lang.gid):
        # sometimes, there's more dimensions. compact all the dimensions into the first one
        # TODO: these compactions should be searchable (they sort of are with reshapes and permutes)
        final_dimension = len(self.output_shape)-len(self.lang.gid)
        for i in range(final_dimension-1, -1, -1):
          self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
        self.output_shape = (prod(self.output_shape[0:final_dimension+1]), ) + self.output_shape[final_dimension+1:]
        if DEBUG >= 4: print(f"replaced output shape with {self.output_shape}")

    # early ast
    accumulators : List[Token] = []
    if self.reduceop is not None:
      accumulators = self.get_accumulators()
      self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [accumulators[off] for off in self.buftokens[self.full_buf_index].acc_offsets()], do_reduce=True)]
      self.kernel += ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))

    # second stage reduce
    if self.group_for_reduce:
      self.kernel.append(self.lang.smem_prefix + f"float {self.buftokens[-1].tok}[{self.sts[-1].size()*self.buftokens[-1].size()}];\n")
      self.store(-1, accumulators)  # TODO: this is assuming the local size = global size. should use lidxs
      self.kernel.append(self.lang.barrier+"\n")

      # this is used to identify the thread doing the reducing (lidx == 0) and is repeated from store
      # must happen before the upcast
      lidx, lvalid = self.sts[-1].expr_idxs()
      assert lvalid.min == 1, "local buffer must always be valid"

      # if any group_for_reduce items aren't reduces, upcast them here
      for j in self.upcast_in_mid_reduce_axes:
        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
        self.upcast()
        if DEBUG >= 4: print("upcast", self.colorshape())  # NOTE: colorshape is wrong here

      self.kernel.append(f"if ({lidx.render(render_cl)} == 0) {{\n")  # lidx.max works here too

      # second stage reduce with a new set of accumulators. TODO: do we need acc_offsets here?
      accumulators = self.get_accumulators("output")
      self.kernel.append(f"for (int mid = 0; mid < {self.sts[-1].size()}; mid++) {{\n")
      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(LazyOp(self.reduceop.op, (None,), self.sts[0].shape), accumulators, do_reduce=True)]
      self.kernel.append("}\n")

    # late ast
    self.store(0, self.ast_parse(self.ast, accumulators))
    if self.group_for_reduce: self.kernel.append("}")
    if len(self.lang.gid) == 0: self.kernel += ["}"] * len(self.output_shape)
    self.kernel.append("\n}")

    # concat kernel into prg
    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
    prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
                   [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
                   [") {\n"] + self.kernel)

    # kernel function definition
    function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.full_shape])

    # painfully name the function
    if prg in GPUCodegen.kernel_name_cache: function_name = GPUCodegen.kernel_name_cache[prg]
    else:
      GPUCodegen.kernel_cnt[function_name] += 1
      if GPUCodegen.kernel_cnt[function_name]: function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
      GPUCodegen.kernel_name_cache[prg] = function_name

    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
                     list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
                     (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
                     op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs if x is not None))