openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

133 lines
5.3 KiB

from extra.models.resnet import ResNet50
from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.search import beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from extra.optimization.helpers import time_linearizer
def get_sched_resnet():
mdl = ResNet50()
optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
BS = getenv("BS", 64)
# run model twice to get only what changes, these are the kernels of the model
for _ in range(2):
out = mdl(Tensor.empty(BS, 3, 224, 224))
targets = [out]
if getenv("BACKWARD"):
optim.zero_grad()
out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
targets += [x for x in optim.schedule_step()]
sched = Tensor.schedule(*targets)
print(f"schedule length {len(sched)}")
return sched
def get_sched_bert():
mdl = get_mlperf_bert_model()
optim = nn.optim.LAMB(nn.state.get_parameters(mdl))
# fake data
BS = getenv("BS", 9)
input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
# run model twice to get only what changes, these are the kernels of the model
for _ in range(2):
lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
targets = [lm_logits, seq_relationship_logits]
if getenv("BACKWARD"):
optim.zero_grad()
loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
# ignore grad norm and loss scaler for now
loss.backward()
targets += [x for x in optim.schedule_step()]
sched = Tensor.schedule(*targets)
print(f"schedule length {len(sched)}")
return sched
if __name__ == "__main__":
if getenv("HALF", 1):
dtypes.default_float = dtypes.half
# the device we are optimizing for
device: Compiled = Device[Device.DEFAULT]
if getenv("BACKWARD"): Tensor.training = True
print(f"optimizing for {Device.DEFAULT}")
sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
sched = [x for x in sched if x.ast.op is Ops.SINK]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
# work with the schedule
total_tm = 0
running_gflops = 0
usage = {}
for i,si in enumerate(sched):
if DEBUG >= 3: print(si.ast)
rawbufs = bufs_from_lin(Kernel(si.ast))
# "linearize" the op into uops in different ways
lins: list[tuple[Kernel, str]] = []
# always try hand coded opt
lin = Kernel(si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append((lin, "HC"))
# maybe try tensor cores
lin = Kernel(si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append((lin, "TC"))
# try a beam search
if beam:=getenv("BEAM"):
lin = Kernel(si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append((lin, "BEAM"))
# try MCTS
if mcts:=getenv("MCTS"):
lin = Kernel(si.ast, opts=device.renderer)
lin = mcts_search(lin, rawbufs, mcts)
lins.append((lin, "MCTS"))
# benchmark the programs
choices = []
for lin, nm in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
ops = (prg:=lin.to_program()).estimates.ops
gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
choices.append((tm, gflops, lin, prg, nm))
sorted_choices = sorted(choices, key=lambda x: x[0])
if DEBUG >= 1: # print all kernels
for tm, gflops, lin, prg, nm in choices:
print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {colored(nm, 'green') if lin is sorted_choices[0][2] else nm}")
tm, gflops, lin, prg, nm = sorted_choices[0]
if getenv("SRC"):
print(si.ast)
print(lin.applied_opts)
print(lin.to_program().src)
total_tm += tm
running_gflops += gflops * tm
if (key := str([str(m) for m in si.metadata])) not in usage: usage[key] = (0, 0)
usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[repr(m) if TRACEMETA >= 2 else str(m) for m in si.metadata]}")
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
print("usage:")
for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")