openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

288 lines
11 KiB

# this can be constructed from a cl_cache or loaded from a thneed file
import time
import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.device import Device
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO
# tinygrad's GPU device; Thneed takes the OpenCL context (CL.ctx) and command queue (CL.queue) from it
CL = Device["GPU"]
# env var: >=1 makes Thneed.run print per-kernel runtimes, >=2 also prints per-kernel event
# timestamps, >=3 also prints kernel source (when the program object exposes it)
DEBUGCL = getenv("DEBUGCL", 0)
# env var: when nonzero, load() creates half-float images and save() writes image data as float16
FLOAT16 = getenv("FLOAT16", 0)
class Thneed:
  """A replayable list of OpenCL kernel launches together with the buffers they use.

  A Thneed is either constructed from a captured ``cl_cache`` -- a sequence of
  ``(program, args)`` pairs where ``args`` is ``[global_size, local_size, *kernel_args]``
  -- or populated from a ``.thneed`` file with :meth:`load`. :meth:`save` serializes it
  to that file format and :meth:`run` enqueues every launch on the GPU queue.
  """

  def __init__(self, cl_cache=None, inputs=None):
    """Record the launches and classify their buffers as weights, inputs or outputs.

    Args:
      cl_cache: sequence of ``(program, args)`` launches. It is copied and defaults to empty.
      inputs: mapping of input name -> buffer. It is copied, so the caller's mapping is
        never modified, and defaults to empty. A name whose buffer is not a source of
        the launch graph is dropped with a warning.
    """
    # Copy both arguments so instances never share (or mutate) the caller's containers;
    # self.inputs is mutated below and again in load().
    self.cl_cache = list(cl_cache) if cl_cache is not None else []
    self.inputs = dict(inputs) if inputs is not None else {}
    self.gobj = 0  # next global_id that save() hands out to a buffer/image

    # build graph
    # NOTE: if CLCACHE=1, this is wrong!
    nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
    for _, args in self.cl_cache:
      # args[0] / args[1] are the work sizes; the output is always the first kernel argument
      for a in args[3:]:
        nodes[a]['out_edges'].append(args[2])
        nodes[args[2]]['in_edges'].append(a)

    # Buffers that nothing writes into must have their contents saved (weights/inputs);
    # buffers that nothing reads from are the graph outputs.
    self.buffers_to_save = set()
    self.outputs = []
    for n, edges in nodes.items():
      if not edges['in_edges']:
        self.buffers_to_save.add(n)
      if not edges['out_edges']:
        self.outputs.append(n)

    # Inputs are supplied at runtime, so their contents are not saved. Names that do not
    # refer to a graph source are removed. Iterate over a snapshot because we delete.
    for k, n in list(self.inputs.items()):
      if n in self.buffers_to_save:
        self.buffers_to_save.remove(n)
      else:
        print(f"WARNING: {k} was not a used input, removing it")
        del self.inputs[k]

  def load(self, input_fn):
    """Append the launches, buffers, inputs and outputs stored in a ``.thneed`` file.

    Layout (as written by :meth:`save`):
      - a 4-byte native-endian length;
      - that many bytes of latin-1 JSON;
      - the raw bytes of every object with ``needs_load``, in ``objects`` order;
      - every program binary, in ``binaries`` order.

    Raises:
      ValueError: if a kernel argument has an encoded length this loader does not understand.
    """
    float32 = not FLOAT16
    mf = cl.mem_flags
    image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
    image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)

    with open(input_fn, "rb") as f:
      json_len = struct.unpack("I", f.read(4))[0]
      jdat = json.loads(f.read(json_len).decode('latin_1'))
      weights = f.read()

    # load in the buffers; the all-zero 8-byte id maps to a NULL argument
    bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
    bufs_loaded = {}
    ptr = 0
    for o in jdat['objects']:
      if o['needs_load']:
        nptr = ptr + o['size']
        o['data'] = weights[ptr:nptr]
        ptr = nptr

      if o['arg_type'] in ("image2d_t", "image1d_t"):
        tfmt = image_fmt_32 if o.get('float32') else image_fmt
        if o['arg_type'] == "image2d_t":
          if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
            # hack: use a image1d since we can back that with a buffer
            buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
          elif 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
            # buffer isn't supported in image2d, copy buffer into image
            arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
            cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']])
            buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
                           shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
          elif o['needs_load']:
            buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
                           shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
          else:
            buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
        else:  # image1d_t: always a view over an (unloaded) buffer
          assert not o['needs_load']
          assert not bufs_loaded[o['buffer_id']]
          buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
      elif 'data' in o:
        buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
      else:
        # zero out buffers
        buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])

      bufs[o['id']] = buf
      bufs_loaded[o['id']] = 'data' in o
      # if it's loaded, it's saved
      if 'data' in o:
        self.buffers_to_save.add(buf)

    # load binaries; they follow the object data in the trailing blob
    prgs = {}
    for o in jdat['binaries']:
      nptr = ptr + o['length']
      prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr])
      ptr = nptr

    # populate the cl_cache; each arg's encoded length tells us what it is
    for k in jdat['kernels']:
      kernel = prgs[k['name']]
      kernel_args = []
      for a, sz in zip(k['args'], k['args_size']):
        if len(a) == 0:
          aa = cl.LocalMemory(sz)  # local memory: only the size is stored
        elif len(a) == 4:
          aa = np.uint32(struct.unpack("I", a.encode('latin_1'))[0])
        elif len(a) == 2:
          aa = np.uint16(struct.unpack("H", a.encode('latin_1'))[0])
        elif len(a) == 8:
          aa = bufs[a]  # 8-byte object id
        else:
          # previously this silently reused the previous argument's value
          raise ValueError(f"thneed: unsupported kernel arg of length {len(a)} in kernel {k['name']}")
        kernel_args.append(aa)
      self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *kernel_args]))

    if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}")

    for k in jdat['inputs']:
      self.inputs[k['name']] = bufs[k['buffer_id']]
    for k in jdat['outputs']:
      self.outputs.append(bufs[k['buffer_id']])

  def save(self, output_fn):
    """Serialize the launches to ``output_fn`` in the format :meth:`load` reads.

    Writes a 4-byte native-endian JSON length, the latin-1 JSON, the contents of every
    buffer/image in ``buffers_to_save``, then every distinct program binary. Each
    buffer/image gets a ``global_id`` attribute (allocated from ``self.gobj``) the first
    time it is seen.

    Raises:
      Exception: if a kernel argument is an object that is neither a ``cl.Buffer`` nor a ``cl.Image``.
    """
    # this is the struct that will be saved
    jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}

    weights = []
    binaries = []
    saved_objs = set()
    saved_binaries = set()
    # Image -> strided-buffer copy kernel; compiled on first use, at most once per save
    # instead of once per image.
    from_image_strided = None

    for prg, args in self.cl_cache:
      # get binaries for saving
      if prg.name not in saved_binaries:
        binary = prg.clprogram.get_info(cl.program_info.BINARIES)
        assert len(binary) == 1
        jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
        binaries.append(binary[0])
        saved_binaries.add(prg.name)

      # get the args from the kernel, some need the data saved
      targs, args_size = [], []
      for a in args[2:]:
        if isinstance(a, cl.LocalMemory):
          targs.append("")
          args_size.append(a.size)
        elif isinstance(a, np.integer) and a.itemsize in (2, 4):
          # Scalar args (e.g. the np.uint16/np.uint32 values load() creates) are stored as
          # their raw native bytes, which is exactly what load() unpacks with "H"/"I".
          targs.append(a.tobytes().decode("latin_1"))
          args_size.append(a.itemsize)
        else:
          if getattr(a, "global_id", None) is None:
            setattr(a, "global_id", self.gobj)
            self.gobj += 1
          ptr = struct.pack("Q", a.global_id).decode("latin_1")
          if ptr not in saved_objs:
            if isinstance(a, cl.Buffer):
              needs_load = a in self.buffers_to_save
              jdat['objects'].append({
                "id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
              })
              if needs_load:
                data = np.empty(a.size//4, dtype=np.float32)
                cl.enqueue_copy(CL.queue, data, a, is_blocking=True)
                weights.append(data.tobytes())
            elif isinstance(a, cl.Image):
              assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
              needs_load = a in self.buffers_to_save
              # row pitch in bytes of the image's own format, padded to 64 (a multiple of 32 isn't enough)
              row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
              size = row_pitch * a.shape[1]
              # the copy kernel writes float4 (float32) pixels, so the staging buffer is
              # twice `size` when the image itself is float16
              buf = cl.Buffer(CL.ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
              # zero out the buffer
              cl.enqueue_copy(CL.queue, buf, b'\x00'*buf.size, is_blocking=True)
              if from_image_strided is None:
                from_image_strided = CLProgram(CL, "from_image_strided", compile_gpu("""
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
"""), bufs=2, vars=1)
              # last arg is the row pitch in pixels
              from_image_strided(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
              jdat['objects'].append({
                "id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
                "width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16,
              })
              if needs_load:
                data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
                cl.enqueue_copy(CL.queue, data, buf, is_blocking=True)
                if FLOAT16: data = data.astype(np.float16)
                weights.append(data.tobytes())
            else:
              raise Exception("unknown object", a)
            saved_objs.add(ptr)
          targs.append(ptr)
          args_size.append(8)

      # save the kernel itself
      jdat['kernels'].append({
        "name": prg.name,
        "work_dim": len(args[0]),
        "global_work_size": args[0],
        # TODO: C++ thneed requires a local_work_size, so we fill it with ones
        "local_work_size": [1 for _ in args[0]] if args[1] is None else args[1],
        "num_args": len(args)-2,
        "args": targs,
        "args_size": args_size
      })

    jdat['outputs'] = [{
      "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
      "size": x.size,
    } for x in self.outputs]

    # inputs are written in reverse of self.inputs' iteration order
    jdat['inputs'] = [{
      "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
      "size": v.size,
      "name": k
    } for k,v in self.inputs.items()][::-1]

    print(f"saving thneed to {output_fn}")
    with open(output_fn, "wb") as f:
      j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
      f.write(struct.pack("I", len(j)))
      f.write(j)
      f.write(b''.join(weights))
      f.write(b''.join(binaries))

  def run(self):
    """Enqueue every launch, wait for completion and report timing.

    Returns:
      The summed profiled kernel runtime in seconds when ``DEBUGCL >= 1``; otherwise the
      wall time in seconds from first submit to synchronize returning.
    """
    events = []
    st = time.monotonic()
    for prg, args in self.cl_cache:
      events.append(prg.clprg(CL.queue, *args))
    mt = time.monotonic()
    Device["GPU"].synchronize()
    et = time.monotonic() - st
    print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

    if DEBUGCL >= 2:
      for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
        print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
    if DEBUGCL >= 1:
      total_runtime = 0
      for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
        runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
        # guard: a zero profiled runtime used to raise ZeroDivisionError here
        gflops = getattr(prg, 'op_estimate', float('nan'))/runtime if runtime else float('nan')
        print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {gflops:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
        if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
          print(prg.prg)
        total_runtime += runtime
      print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
      return total_runtime/1e9
    return et