Tinygrad runner (#34261)

* squash

* dmonitoringmodeld: use cl transform (#34235)

* needs cleanup

* only if tici

* bump tinygrad

* check width

* base modelframe

* .

* need to be args

* more cleanup

* no _frame in base

* tici only

* it's DrivingModelFrame

* .6 is fair

---------

Co-authored-by: Comma Device <device@comma.ai>

* Update tinygrad

* tg upstream

* bump tg

* bump tg

* debug

* attr

* misc cleanup

* whitespace

* remove

* Add TODOs to make python proc for modelrunners

* whitespace

---------

Co-authored-by: ZwX1616 <zwx1616@gmail.com>
Co-authored-by: Comma Device <device@comma.ai>
Co-authored-by: Maxime Desroches <desroches.maxime@gmail.com>
Harald Schäfer authored 4 months ago, committed by GitHub
parent ff97a43c50
commit 17ca6389e1
39 changed files (lines changed in parentheses):
1. pyproject.toml (3)
2. release/release_files.py (2)
3. selfdrive/modeld/SConscript (46)
4. selfdrive/modeld/dmonitoringmodeld (6)
5. selfdrive/modeld/dmonitoringmodeld.py (55)
6. selfdrive/modeld/modeld (6)
7. selfdrive/modeld/modeld.py (71)
8. selfdrive/modeld/models/commonmodel.cc (26)
9. selfdrive/modeld/models/commonmodel.h (12)
10. selfdrive/modeld/models/commonmodel.pxd (4)
11. selfdrive/modeld/models/commonmodel_pyx.pyx (24)
12. selfdrive/modeld/runners/__init__.py (27)
13. selfdrive/modeld/runners/onnxmodel.py (71)
14. selfdrive/modeld/runners/run.h (4)
15. selfdrive/modeld/runners/runmodel.h (49)
16. selfdrive/modeld/runners/runmodel.pxd (14)
17. selfdrive/modeld/runners/runmodel_pyx.pxd (6)
18. selfdrive/modeld/runners/runmodel_pyx.pyx (37)
19. selfdrive/modeld/runners/snpemodel.cc (116)
20. selfdrive/modeld/runners/snpemodel.h (52)
21. selfdrive/modeld/runners/snpemodel.pxd (9)
22. selfdrive/modeld/runners/snpemodel_pyx.pyx (17)
23. selfdrive/modeld/runners/thneedmodel.cc (58)
24. selfdrive/modeld/runners/thneedmodel.h (17)
25. selfdrive/modeld/runners/thneedmodel.pxd (9)
26. selfdrive/modeld/runners/thneedmodel_pyx.pyx (14)
27. selfdrive/modeld/runners/tinygrad_helpers.py (8)
28. selfdrive/modeld/thneed/README (8)
29. selfdrive/modeld/thneed/__init__.py (0)
30. selfdrive/modeld/thneed/serialize.cc (154)
31. selfdrive/modeld/thneed/thneed.h (133)
32. selfdrive/modeld/thneed/thneed_common.cc (216)
33. selfdrive/modeld/thneed/thneed_pc.cc (32)
34. selfdrive/modeld/thneed/thneed_qcom2.cc (258)
35. selfdrive/test/test_onroad.py (11)
36. system/hardware/tici/tests/test_power_draw.py (2)
37. system/manager/process_config.py (2)
38. tinygrad_repo (2)
39. uv.lock (73)

@@ -42,8 +42,7 @@ dependencies = [
# modeld
"onnx >= 1.14.0",
"onnxruntime >=1.16.3; platform_system == 'Linux' and platform_machine == 'aarch64'",
"onnxruntime-gpu >=1.16.3; platform_system == 'Linux' and platform_machine == 'x86_64'",
"onnxruntime >=1.16.3",
# logging
"pyzmq",

@@ -54,7 +54,7 @@ whitelist = [
"tools/joystick/",
"tools/longitudinal_maneuvers/",
"tinygrad_repo/openpilot/compile2.py",
"tinygrad_repo/examples/openpilot/compile3.py",
"tinygrad_repo/extra/onnx.py",
"tinygrad_repo/extra/onnx_ops.py",
"tinygrad_repo/extra/thneed.py",

@@ -13,20 +13,6 @@ common_src = [
"transforms/transform.cc",
]
thneed_src_common = [
"thneed/thneed_common.cc",
"thneed/serialize.cc",
]
thneed_src_qcom = thneed_src_common + ["thneed/thneed_qcom2.cc"]
thneed_src_pc = thneed_src_common + ["thneed/thneed_pc.cc"]
thneed_src = thneed_src_qcom if arch == "larch64" else thneed_src_pc
# SNPE except on Mac and ARM Linux
snpe_lib = []
if arch != "Darwin" and arch != "aarch64":
common_src += ['runners/snpemodel.cc']
snpe_lib += ['SNPE']
# OpenCL is a framework on Mac
if arch == "Darwin":
@@ -45,34 +31,24 @@ snpe_rpath_pc = f"{Dir('#').abspath}/third_party/snpe/x86_64-linux-clang"
snpe_rpath = lenvCython['RPATH'] + [snpe_rpath_qcom if arch == "larch64" else snpe_rpath_pc]
cython_libs = envCython["LIBS"] + libs
snpemodel_lib = lenv.Library('snpemodel', ['runners/snpemodel.cc'])
commonmodel_lib = lenv.Library('commonmodel', common_src)
lenvCython.Program('runners/runmodel_pyx.so', 'runners/runmodel_pyx.pyx', LIBS=cython_libs, FRAMEWORKS=frameworks)
lenvCython.Program('runners/snpemodel_pyx.so', 'runners/snpemodel_pyx.pyx', LIBS=[snpemodel_lib, snpe_lib, *cython_libs], FRAMEWORKS=frameworks, RPATH=snpe_rpath)
lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LIBS=[commonmodel_lib, *cython_libs], FRAMEWORKS=frameworks)
tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath)]
tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath) if 'pycache' not in x]
# Get model metadata
fn = File("models/supercombo").abspath
cmd = f'python3 {Dir("#selfdrive/modeld").abspath}/get_model_metadata.py {fn}.onnx'
lenv.Command(fn + "_metadata.pkl", [fn + ".onnx"] + tinygrad_files, cmd)
# Build thneed model
if arch == "larch64" or GetOption('pc_thneed'):
tinygrad_opts = []
if not GetOption('pc_thneed'):
# use FLOAT16 on device for speed + don't cache the CL kernels for space
tinygrad_opts += ["FLOAT16=1", "PYOPENCL_NO_CACHE=1"]
cmd = f"cd {Dir('#').abspath}/tinygrad_repo && " + ' '.join(tinygrad_opts) + f" python3 openpilot/compile2.py {fn}.onnx {fn}.thneed"
lenv.Command(fn + ".thneed", [fn + ".onnx"] + tinygrad_files, cmd)
# Compile tinygrad model
pythonpath_string = 'PYTHONPATH="${PYTHONPATH}:' + env.Dir("#tinygrad_repo").abspath + '"'
if arch == 'larch64':
device_string = 'QCOM=1'
else:
device_string = 'CLANG=1 IMAGE=0'
fn_dm = File("models/dmonitoring_model").abspath
cmd = f"cd {Dir('#').abspath}/tinygrad_repo && " + ' '.join(tinygrad_opts) + f" python3 openpilot/compile2.py {fn_dm}.onnx {fn_dm}.thneed"
lenv.Command(fn_dm + ".thneed", [fn_dm + ".onnx"] + tinygrad_files, cmd)
for model_name in ['supercombo', 'dmonitoring_model']:
fn = File(f"models/{model_name}").abspath
cmd = f'{pythonpath_string} {device_string} python3 {Dir("#tinygrad_repo").abspath}/examples/openpilot/compile3.py {fn}.onnx {fn}_tinygrad.pkl'
lenv.Command(fn + "_tinygrad.pkl", [fn + ".onnx"] + tinygrad_files, cmd)
thneed_lib = env.SharedLibrary('thneed', thneed_src, LIBS=[gpucommon, common, 'OpenCL', 'dl'])
thneedmodel_lib = env.Library('thneedmodel', ['runners/thneedmodel.cc'])
lenvCython.Program('runners/thneedmodel_pyx.so', 'runners/thneedmodel_pyx.pyx', LIBS=envCython["LIBS"]+[thneedmodel_lib, thneed_lib, gpucommon, common, 'dl', 'OpenCL'])
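For reference, the new compile step above reduces to a standalone invocation. The following is a hedged sketch only, assuming an openpilot checkout at /data/openpilot (the real paths and env come from the SConscript variables; QCOM=1 is the on-device case):

import os
import subprocess

# Sketch of the tinygrad compile invocation the SConscript builds (paths are assumptions).
OPENPILOT_ROOT = "/data/openpilot"  # assumption: checkout location
TINYGRAD_REPO = os.path.join(OPENPILOT_ROOT, "tinygrad_repo")

env = os.environ.copy()
env["PYTHONPATH"] = env.get("PYTHONPATH", "") + ":" + TINYGRAD_REPO
env["QCOM"] = "1"  # on PC the SConscript sets CLANG=1 IMAGE=0 instead

for model_name in ["supercombo", "dmonitoring_model"]:
    fn = os.path.join(OPENPILOT_ROOT, "selfdrive/modeld/models", model_name)
    subprocess.check_call(
        ["python3", os.path.join(TINYGRAD_REPO, "examples/openpilot/compile3.py"),
         fn + ".onnx", fn + "_tinygrad.pkl"],
        env=env)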

@@ -1,10 +1,4 @@
#!/usr/bin/env bash
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
cd "$DIR/../../"
if [ -f "$DIR/libthneed.so" ]; then
export LD_PRELOAD="$DIR/libthneed.so"
fi
exec "$DIR/dmonitoringmodeld.py" "$@"

@@ -1,8 +1,17 @@
#!/usr/bin/env python3
import os
from openpilot.system.hardware import TICI
if TICI:
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from openpilot.selfdrive.modeld.runners.tinygrad_helpers import qcom_tensor_from_opencl_address
os.environ['QCOM'] = '1'
else:
from openpilot.selfdrive.modeld.runners.ort_helpers import make_onnx_cpu_runner
import gc
import math
import time
import pickle
import ctypes
import numpy as np
from pathlib import Path
@@ -13,21 +22,20 @@ from cereal.messaging import PubMaster, SubMaster
from msgq.visionipc import VisionIpcClient, VisionStreamType, VisionBuf
from openpilot.common.swaglog import cloudlog
from openpilot.common.realtime import set_realtime_priority
from openpilot.common.transformations.model import dmonitoringmodel_intrinsics
from openpilot.common.transformations.model import dmonitoringmodel_intrinsics, DM_INPUT_SIZE
from openpilot.common.transformations.camera import _ar_ox_fisheye, _os_fisheye
from openpilot.selfdrive.modeld.models.commonmodel_pyx import CLContext, MonitoringModelFrame
from openpilot.selfdrive.modeld.runners import ModelRunner, Runtime
from openpilot.selfdrive.modeld.parse_model_outputs import sigmoid
MODEL_WIDTH, MODEL_HEIGHT = DM_INPUT_SIZE
CALIB_LEN = 3
FEATURE_LEN = 512
OUTPUT_SIZE = 84 + FEATURE_LEN
PROCESS_NAME = "selfdrive.modeld.dmonitoringmodeld"
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
MODEL_PATHS = {
ModelRunner.THNEED: Path(__file__).parent / 'models/dmonitoring_model.thneed',
ModelRunner.ONNX: Path(__file__).parent / 'models/dmonitoring_model.onnx'}
MODEL_PATH = Path(__file__).parent / 'models/dmonitoring_model.onnx'
MODEL_PKL_PATH = Path(__file__).parent / 'models/dmonitoring_model_tinygrad.pkl'
class DriverStateResult(ctypes.Structure):
_fields_ = [
@@ -58,29 +66,42 @@ class DMonitoringModelResult(ctypes.Structure):
class ModelState:
inputs: dict[str, np.ndarray]
output: np.ndarray
model: ModelRunner
def __init__(self, cl_ctx):
assert ctypes.sizeof(DMonitoringModelResult) == OUTPUT_SIZE * ctypes.sizeof(ctypes.c_float)
self.frame = MonitoringModelFrame(cl_ctx)
self.output = np.zeros(OUTPUT_SIZE, dtype=np.float32)
self.inputs = {
'calib': np.zeros(CALIB_LEN, dtype=np.float32)}
self.numpy_inputs = {
'calib': np.zeros((1, CALIB_LEN), dtype=np.float32),
}
self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.GPU, False, cl_ctx)
self.model.addInput("input_img", None)
self.model.addInput("calib", self.inputs['calib'])
if TICI:
self.tensor_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()}
with open(MODEL_PKL_PATH, "rb") as f:
self.model_run = pickle.load(f)
else:
self.onnx_cpu_runner = make_onnx_cpu_runner(MODEL_PATH)
def run(self, buf:VisionBuf, calib:np.ndarray, transform:np.ndarray) -> tuple[np.ndarray, float]:
self.inputs['calib'][:] = calib
self.model.setInputBuffer("input_img", self.frame.prepare(buf, transform.flatten(), None).view(np.float32))
self.numpy_inputs['calib'][0,:] = calib
t1 = time.perf_counter()
self.model.execute()
input_img_cl = self.frame.prepare(buf, transform.flatten())
if TICI:
# The img tensors are backed by OpenCL memory, so they only need to be initialized once
if 'input_img' not in self.tensor_inputs:
self.tensor_inputs['input_img'] = qcom_tensor_from_opencl_address(input_img_cl.mem_address, (1, MODEL_WIDTH*MODEL_HEIGHT), dtype=dtypes.uint8)
else:
self.numpy_inputs['input_img'] = self.frame.buffer_from_cl(input_img_cl).reshape((1, MODEL_WIDTH*MODEL_HEIGHT))
if TICI:
output = self.model_run(**self.tensor_inputs).numpy().flatten()
else:
output = self.onnx_cpu_runner.run(None, self.numpy_inputs)[0].flatten()
t2 = time.perf_counter()
return self.output, t2 - t1
return output, t2 - t1
def fill_driver_state(msg, ds_result: DriverStateResult):
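The numpy_inputs/tensor_inputs split in ModelState above relies on tinygrad Tensors on the 'NPY' device wrapping the backing numpy array rather than copying it, so per-frame writes into numpy_inputs are visible on the next model call. A minimal sketch of that assumption:

import numpy as np
from tinygrad.tensor import Tensor

# Sketch: the Tensor is created once and reused; in-place numpy writes feed later
# runs (the behavior the code above depends on, not a documented API contract).
calib = np.zeros((1, 3), dtype=np.float32)
calib_tensor = Tensor(calib, device='NPY').realize()

calib[0, :] = [0.01, 0.0, -0.02]  # updated in place each frame; no new Tensor built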

@@ -1,10 +1,4 @@
#!/usr/bin/env bash
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
cd "$DIR/../../"
if [ -f "$DIR/libthneed.so" ]; then
export LD_PRELOAD="$DIR/libthneed.so"
fi
exec "$DIR/modeld.py" "$@"

@@ -1,5 +1,15 @@
#!/usr/bin/env python3
import os
from openpilot.system.hardware import TICI
if TICI:
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from openpilot.selfdrive.modeld.runners.tinygrad_helpers import qcom_tensor_from_opencl_address
os.environ['QCOM'] = '1'
else:
from openpilot.selfdrive.modeld.runners.ort_helpers import make_onnx_cpu_runner
import time
import pickle
import numpy as np
@@ -18,22 +28,19 @@ from openpilot.common.transformations.camera import DEVICE_CAMERAS
from openpilot.common.transformations.model import get_warp_matrix
from openpilot.system import sentry
from openpilot.selfdrive.controls.lib.desire_helper import DesireHelper
from openpilot.selfdrive.modeld.runners import ModelRunner, Runtime
from openpilot.selfdrive.modeld.parse_model_outputs import Parser
from openpilot.selfdrive.modeld.fill_model_msg import fill_model_msg, fill_pose_msg, PublishState
from openpilot.selfdrive.modeld.constants import ModelConstants
from openpilot.selfdrive.modeld.models.commonmodel_pyx import DrivingModelFrame, CLContext
PROCESS_NAME = "selfdrive.modeld.modeld"
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
MODEL_PATHS = {
ModelRunner.THNEED: Path(__file__).parent / 'models/supercombo.thneed',
ModelRunner.ONNX: Path(__file__).parent / 'models/supercombo.onnx'}
MODEL_PATH = Path(__file__).parent / 'models/supercombo.onnx'
MODEL_PKL_PATH = Path(__file__).parent / 'models/supercombo_tinygrad.pkl'
METADATA_PATH = Path(__file__).parent / 'models/supercombo_metadata.pkl'
class FrameMeta:
frame_id: int = 0
timestamp_sof: int = 0
@@ -44,40 +51,39 @@ class FrameMeta:
self.frame_id, self.timestamp_sof, self.timestamp_eof = vipc.frame_id, vipc.timestamp_sof, vipc.timestamp_eof
class ModelState:
frame: DrivingModelFrame
wide_frame: DrivingModelFrame
frames: dict[str, DrivingModelFrame]
inputs: dict[str, np.ndarray]
output: np.ndarray
prev_desire: np.ndarray # for tracking the rising edge of the pulse
model: ModelRunner
def __init__(self, context: CLContext):
self.frame = DrivingModelFrame(context)
self.wide_frame = DrivingModelFrame(context)
self.frames = {'input_imgs': DrivingModelFrame(context), 'big_input_imgs': DrivingModelFrame(context)}
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
self.full_features_20Hz = np.zeros((ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
self.desire_20Hz = np.zeros((ModelConstants.FULL_HISTORY_BUFFER_LEN + 1, ModelConstants.DESIRE_LEN), dtype=np.float32)
# img buffers are managed in openCL transform code
self.inputs = {
'desire': np.zeros(ModelConstants.DESIRE_LEN * (ModelConstants.HISTORY_BUFFER_LEN+1), dtype=np.float32),
'traffic_convention': np.zeros(ModelConstants.TRAFFIC_CONVENTION_LEN, dtype=np.float32),
'features_buffer': np.zeros(ModelConstants.HISTORY_BUFFER_LEN * ModelConstants.FEATURE_LEN, dtype=np.float32),
self.numpy_inputs = {
'desire': np.zeros((1, (ModelConstants.HISTORY_BUFFER_LEN+1), ModelConstants.DESIRE_LEN), dtype=np.float32),
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32),
'features_buffer': np.zeros((1, ModelConstants.HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
}
with open(METADATA_PATH, 'rb') as f:
model_metadata = pickle.load(f)
self.input_shapes = model_metadata['input_shapes']
self.output_slices = model_metadata['output_slices']
net_output_size = model_metadata['output_shapes']['outputs'][1]
self.output = np.zeros(net_output_size, dtype=np.float32)
self.parser = Parser()
self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.GPU, False, context)
self.model.addInput("input_imgs", None)
self.model.addInput("big_input_imgs", None)
for k,v in self.inputs.items():
self.model.addInput(k, v)
if TICI:
self.tensor_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()}
with open(MODEL_PKL_PATH, "rb") as f:
self.model_run = pickle.load(f)
else:
self.onnx_cpu_runner = make_onnx_cpu_runner(MODEL_PATH)
def slice_outputs(self, model_outputs: np.ndarray) -> dict[str, np.ndarray]:
parsed_model_outputs = {k: model_outputs[np.newaxis, v] for k,v in self.output_slices.items()}
@@ -94,24 +100,36 @@ class ModelState:
self.desire_20Hz[:-1] = self.desire_20Hz[1:]
self.desire_20Hz[-1] = new_desire
self.inputs['desire'][:] = self.desire_20Hz.reshape((25,4,-1)).max(axis=1).flatten()
self.numpy_inputs['desire'][:] = self.desire_20Hz.reshape((1,25,4,-1)).max(axis=2)
self.inputs['traffic_convention'][:] = inputs['traffic_convention']
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
imgs_cl = {'input_imgs': self.frames['input_imgs'].prepare(buf, transform.flatten()),
'big_input_imgs': self.frames['big_input_imgs'].prepare(wbuf, transform_wide.flatten())}
self.model.setInputBuffer("input_imgs", self.frame.prepare(buf, transform.flatten(), self.model.getCLBuffer("input_imgs")))
self.model.setInputBuffer("big_input_imgs", self.wide_frame.prepare(wbuf, transform_wide.flatten(), self.model.getCLBuffer("big_input_imgs")))
if TICI:
# The img tensors are backed by OpenCL memory, so they only need to be initialized once
for key in imgs_cl:
if key not in self.tensor_inputs:
self.tensor_inputs[key] = qcom_tensor_from_opencl_address(imgs_cl[key].mem_address, self.input_shapes[key], dtype=dtypes.uint8)
else:
for key in imgs_cl:
self.numpy_inputs[key] = self.frames[key].buffer_from_cl(imgs_cl[key]).reshape(self.input_shapes[key])
if prepare_only:
return None
self.model.execute()
if TICI:
self.output = self.model_run(**self.tensor_inputs).numpy().flatten()
else:
self.output = self.onnx_cpu_runner.run(None, self.numpy_inputs)[0].flatten()
outputs = self.parser.parse_outputs(self.slice_outputs(self.output))
self.full_features_20Hz[:-1] = self.full_features_20Hz[1:]
self.full_features_20Hz[-1] = outputs['hidden_state'][0, :]
idxs = np.arange(-4,-100,-4)[::-1]
self.inputs['features_buffer'][:] = self.full_features_20Hz[idxs].flatten()
self.numpy_inputs['features_buffer'][:] = self.full_features_20Hz[idxs]
return outputs
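slice_outputs above turns the flat network output into named views using the output_slices dict from the metadata pickle. A short sketch with invented slice bounds:

import numpy as np

# Sketch of metadata-driven output slicing (slice values here are hypothetical).
output_slices = {'plan': slice(0, 4955), 'hidden_state': slice(5983, 6495)}
net_output = np.zeros(6495, dtype=np.float32)

parsed = {k: net_output[np.newaxis, v] for k, v in output_slices.items()}
assert parsed['hidden_state'].shape == (1, 512)  # hidden state feeds features_buffer above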
@@ -281,7 +299,6 @@ def main(demo=False):
pm.send('modelV2', modelv2_send)
pm.send('drivingModelData', drivingdata_send)
pm.send('cameraOdometry', posenet_send)
last_vipc_frame_id = meta_main.frame_id

@@ -7,7 +7,7 @@
DrivingModelFrame::DrivingModelFrame(cl_device_id device_id, cl_context context) : ModelFrame(device_id, context) {
input_frames = std::make_unique<uint8_t[]>(buf_size);
//input_frames_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, buf_size, NULL, &err));
input_frames_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, buf_size, NULL, &err));
img_buffer_20hz_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, 5*frame_size_bytes, NULL, &err));
region.origin = 4 * frame_size_bytes;
region.size = frame_size_bytes;
@@ -17,7 +17,7 @@ DrivingModelFrame::DrivingModelFrame(cl_device_id device_id, cl_context context)
init_transform(device_id, context, MODEL_WIDTH, MODEL_HEIGHT);
}
uint8_t* DrivingModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection, cl_mem* output) {
cl_mem* DrivingModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection) {
run_transform(yuv_cl, MODEL_WIDTH, MODEL_HEIGHT, frame_width, frame_height, frame_stride, frame_uv_offset, projection);
for (int i = 0; i < 4; i++) {
@@ -25,19 +25,12 @@ uint8_t* DrivingModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_he
}
loadyuv_queue(&loadyuv, q, y_cl, u_cl, v_cl, last_img_cl);
if (output == NULL) {
CL_CHECK(clEnqueueReadBuffer(q, img_buffer_20hz_cl, CL_TRUE, 0, frame_size_bytes, &input_frames[0], 0, nullptr, nullptr));
CL_CHECK(clEnqueueReadBuffer(q, last_img_cl, CL_TRUE, 0, frame_size_bytes, &input_frames[MODEL_FRAME_SIZE], 0, nullptr, nullptr));
clFinish(q);
return &input_frames[0];
} else {
copy_queue(&loadyuv, q, img_buffer_20hz_cl, *output, 0, 0, frame_size_bytes);
copy_queue(&loadyuv, q, last_img_cl, *output, 0, frame_size_bytes, frame_size_bytes);
copy_queue(&loadyuv, q, img_buffer_20hz_cl, input_frames_cl, 0, 0, frame_size_bytes);
copy_queue(&loadyuv, q, last_img_cl, input_frames_cl, 0, frame_size_bytes, frame_size_bytes);
// NOTE: Since thneed is using a different command queue, this clFinish is needed to ensure the image is ready.
clFinish(q);
return NULL;
}
return &input_frames_cl;
}
DrivingModelFrame::~DrivingModelFrame() {
@@ -51,16 +44,15 @@ DrivingModelFrame::~DrivingModelFrame() {
MonitoringModelFrame::MonitoringModelFrame(cl_device_id device_id, cl_context context) : ModelFrame(device_id, context) {
input_frames = std::make_unique<uint8_t[]>(buf_size);
//input_frame_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, buf_size, NULL, &err));
input_frame_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, buf_size, NULL, &err));
init_transform(device_id, context, MODEL_WIDTH, MODEL_HEIGHT);
}
uint8_t* MonitoringModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection, cl_mem* output) {
cl_mem* MonitoringModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection) {
run_transform(yuv_cl, MODEL_WIDTH, MODEL_HEIGHT, frame_width, frame_height, frame_stride, frame_uv_offset, projection);
CL_CHECK(clEnqueueReadBuffer(q, y_cl, CL_TRUE, 0, MODEL_FRAME_SIZE * sizeof(uint8_t), input_frames.get(), 0, nullptr, nullptr));
clFinish(q);
//return &y_cl;
return input_frames.get();
return &y_cl;
}
MonitoringModelFrame::~MonitoringModelFrame() {

@@ -23,14 +23,12 @@ public:
q = CL_CHECK_ERR(clCreateCommandQueue(context, device_id, 0, &err));
}
virtual ~ModelFrame() {}
virtual uint8_t* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection, cl_mem* output) { return NULL; }
/*
virtual cl_mem* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection) { return NULL; }
uint8_t* buffer_from_cl(cl_mem *in_frames, int buffer_size) {
CL_CHECK(clEnqueueReadBuffer(q, *in_frames, CL_TRUE, 0, buffer_size, input_frames.get(), 0, nullptr, nullptr));
clFinish(q);
return &input_frames[0];
}
*/
int MODEL_WIDTH;
int MODEL_HEIGHT;
@@ -68,7 +66,7 @@ class DrivingModelFrame : public ModelFrame {
public:
DrivingModelFrame(cl_device_id device_id, cl_context context);
~DrivingModelFrame();
uint8_t* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection, cl_mem* output);
cl_mem* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection);
const int MODEL_WIDTH = 512;
const int MODEL_HEIGHT = 256;
@@ -78,7 +76,7 @@ public:
private:
LoadYUVState loadyuv;
cl_mem img_buffer_20hz_cl, last_img_cl;//, input_frames_cl;
cl_mem img_buffer_20hz_cl, last_img_cl, input_frames_cl;
cl_buffer_region region;
};
@@ -86,7 +84,7 @@ class MonitoringModelFrame : public ModelFrame {
public:
MonitoringModelFrame(cl_device_id device_id, cl_context context);
~MonitoringModelFrame();
uint8_t* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection, cl_mem* output);
cl_mem* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection);
const int MODEL_WIDTH = 1440;
const int MODEL_HEIGHT = 960;
@@ -94,5 +92,5 @@ public:
const int buf_size = MODEL_FRAME_SIZE;
private:
// cl_mem input_frame_cl;
cl_mem input_frame_cl;
};

@@ -14,8 +14,8 @@ cdef extern from "common/clutil.h":
cdef extern from "selfdrive/modeld/models/commonmodel.h":
cppclass ModelFrame:
int buf_size
# unsigned char * buffer_from_cl(cl_mem*, int);
unsigned char * prepare(cl_mem, int, int, int, int, mat3, cl_mem*)
unsigned char * buffer_from_cl(cl_mem*, int);
cl_mem * prepare(cl_mem, int, int, int, int, mat3)
cppclass DrivingModelFrame:
int buf_size

@@ -39,24 +39,17 @@ cdef class ModelFrame:
def __dealloc__(self):
del self.frame
def prepare(self, VisionBuf buf, float[:] projection, CLMem output):
def prepare(self, VisionBuf buf, float[:] projection):
cdef mat3 cprojection
memcpy(cprojection.v, &projection[0], 9*sizeof(float))
cdef unsigned char * data
if output is None:
data = self.frame.prepare(buf.buf.buf_cl, buf.width, buf.height, buf.stride, buf.uv_offset, cprojection, NULL)
else:
data = self.frame.prepare(buf.buf.buf_cl, buf.width, buf.height, buf.stride, buf.uv_offset, cprojection, output.mem)
if not data:
return None
cdef cl_mem * data
data = self.frame.prepare(buf.buf.buf_cl, buf.width, buf.height, buf.stride, buf.uv_offset, cprojection)
return CLMem.create(data)
return np.asarray(<cnp.uint8_t[:self.buf_size]> data)
# return CLMem.create(data)
# def buffer_from_cl(self, CLMem in_frames):
# cdef unsigned char * data2
# data2 = self.frame.buffer_from_cl(in_frames.mem, self.buf_size)
# return np.asarray(<cnp.uint8_t[:self.buf_size]> data2)
def buffer_from_cl(self, CLMem in_frames):
cdef unsigned char * data2
data2 = self.frame.buffer_from_cl(in_frames.mem, self.buf_size)
return np.asarray(<cnp.uint8_t[:self.buf_size]> data2)
cdef class DrivingModelFrame(ModelFrame):
@@ -74,3 +67,4 @@ cdef class MonitoringModelFrame(ModelFrame):
self._frame = new cppMonitoringModelFrame(context.device_id, context.context)
self.frame = <cppModelFrame*>(self._frame)
self.buf_size = self._frame.buf_size
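On the PC path, prepare() now returns a CLMem wrapper and buffer_from_cl() copies the frame back to host memory for onnxruntime. A hedged sketch of how the two calls pair up, mirroring dmonitoringmodeld.py above (the default dimensions are the monitoring model's, from commonmodel.h):

import numpy as np
from openpilot.selfdrive.modeld.models.commonmodel_pyx import MonitoringModelFrame

def prepare_cpu_input(frame: MonitoringModelFrame, buf, transform: np.ndarray,
                      width: int = 1440, height: int = 960) -> np.ndarray:
    # prepare() runs the CL transform and hands back a CLMem handle;
    # buffer_from_cl() reads it into a uint8 host array for the CPU runner.
    input_img_cl = frame.prepare(buf, transform.flatten())
    return frame.buffer_from_cl(input_img_cl).reshape((1, width * height))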

@@ -1,27 +0,0 @@
import os
from openpilot.system.hardware import TICI
from openpilot.selfdrive.modeld.runners.runmodel_pyx import RunModel, Runtime
assert Runtime
USE_THNEED = int(os.getenv('USE_THNEED', str(int(TICI))))
USE_SNPE = int(os.getenv('USE_SNPE', str(int(TICI))))
class ModelRunner(RunModel):
THNEED = 'THNEED'
SNPE = 'SNPE'
ONNX = 'ONNX'
def __new__(cls, paths, *args, **kwargs):
if ModelRunner.THNEED in paths and USE_THNEED:
from openpilot.selfdrive.modeld.runners.thneedmodel_pyx import ThneedModel as Runner
runner_type = ModelRunner.THNEED
elif ModelRunner.SNPE in paths and USE_SNPE:
from openpilot.selfdrive.modeld.runners.snpemodel_pyx import SNPEModel as Runner
runner_type = ModelRunner.SNPE
elif ModelRunner.ONNX in paths:
from openpilot.selfdrive.modeld.runners.onnxmodel import ONNXModel as Runner
runner_type = ModelRunner.ONNX
else:
raise Exception("Couldn't select a model runner, make sure to pass at least one valid model path")
return Runner(str(paths[runner_type]), *args, **kwargs)
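The removed ModelRunner dispatch (thneed/SNPE/ONNX selection) is replaced by the TICI check at the top of each daemon. A condensed sketch of the new shape, with paths as defined in modeld.py above:

import os
import pickle
from pathlib import Path
from openpilot.system.hardware import TICI

MODEL_PATH = Path(__file__).parent / 'models/supercombo.onnx'
MODEL_PKL_PATH = Path(__file__).parent / 'models/supercombo_tinygrad.pkl'

if TICI:
    os.environ['QCOM'] = '1'  # selects tinygrad's QCOM backend
    with open(MODEL_PKL_PATH, "rb") as f:
        model_run = pickle.load(f)  # callable: model_run(**tensor_inputs) -> Tensor
else:
    from openpilot.selfdrive.modeld.runners.ort_helpers import make_onnx_cpu_runner
    model_run = make_onnx_cpu_runner(MODEL_PATH)  # session: .run(None, numpy_inputs)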

@@ -1,71 +0,0 @@
import os
import onnx
import sys
import numpy as np
from typing import Any
from openpilot.selfdrive.modeld.runners.runmodel_pyx import RunModel
from openpilot.selfdrive.modeld.runners.ort_helpers import convert_fp16_to_fp32, ORT_TYPES_TO_NP_TYPES
def create_ort_session(path, fp16_to_fp32):
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
import onnxruntime as ort
print("Onnx available providers: ", ort.get_available_providers(), file=sys.stderr)
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
provider: str | tuple[str, dict[Any, Any]]
if 'OpenVINOExecutionProvider' in ort.get_available_providers() and 'ONNXCPU' not in os.environ:
provider = 'OpenVINOExecutionProvider'
elif 'CUDAExecutionProvider' in ort.get_available_providers() and 'ONNXCPU' not in os.environ:
options.intra_op_num_threads = 2
provider = ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'EXHAUSTIVE'})
else:
options.intra_op_num_threads = 2
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
provider = 'CPUExecutionProvider'
model_data = convert_fp16_to_fp32(onnx.load(path)) if fp16_to_fp32 else path
print("Onnx selected provider: ", [provider], file=sys.stderr)
ort_session = ort.InferenceSession(model_data, options, providers=[provider])
print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
return ort_session
class ONNXModel(RunModel):
def __init__(self, path, output, runtime, use_tf8, cl_context):
self.inputs = {}
self.output = output
self.session = create_ort_session(path, fp16_to_fp32=True)
self.input_names = [x.name for x in self.session.get_inputs()]
self.input_shapes = {x.name: [1, *x.shape[1:]] for x in self.session.get_inputs()}
self.input_dtypes = {x.name: ORT_TYPES_TO_NP_TYPES[x.type] for x in self.session.get_inputs()}
# run once to initialize CUDA provider
if "CUDAExecutionProvider" in self.session.get_providers():
self.session.run(None, {k: np.zeros(self.input_shapes[k], dtype=self.input_dtypes[k]) for k in self.input_names})
print("ready to run onnx model", self.input_shapes, file=sys.stderr)
def addInput(self, name, buffer):
assert name in self.input_names
self.inputs[name] = buffer
def setInputBuffer(self, name, buffer):
assert name in self.inputs
self.inputs[name] = buffer
def getCLBuffer(self, name):
return None
def execute(self):
inputs = {k: v.view(self.input_dtypes[k]) for k,v in self.inputs.items()}
inputs = {k: v.reshape(self.input_shapes[k]).astype(self.input_dtypes[k]) for k,v in inputs.items()}
outputs = self.session.run(None, inputs)
assert len(outputs) == 1, "Only single model outputs are supported"
self.output[:] = outputs[0]
return self.output

@@ -1,4 +0,0 @@
#pragma once
#include "selfdrive/modeld/runners/runmodel.h"
#include "selfdrive/modeld/runners/snpemodel.h"

@@ -1,49 +0,0 @@
#pragma once
#include <string>
#include <vector>
#include <memory>
#include <cassert>
#include "common/clutil.h"
#include "common/swaglog.h"
#define USE_CPU_RUNTIME 0
#define USE_GPU_RUNTIME 1
#define USE_DSP_RUNTIME 2
struct ModelInput {
const std::string name;
float *buffer;
int size;
ModelInput(const std::string _name, float *_buffer, int _size) : name(_name), buffer(_buffer), size(_size) {}
virtual void setBuffer(float *_buffer, int _size) {
assert(size == _size || size == 0);
buffer = _buffer;
size = _size;
}
};
class RunModel {
public:
std::vector<std::unique_ptr<ModelInput>> inputs;
virtual ~RunModel() {}
virtual void execute() {}
virtual void* getCLBuffer(const std::string name) { return nullptr; }
virtual void addInput(const std::string name, float *buffer, int size) {
inputs.push_back(std::unique_ptr<ModelInput>(new ModelInput(name, buffer, size)));
}
virtual void setInputBuffer(const std::string name, float *buffer, int size) {
for (auto &input : inputs) {
if (name == input->name) {
input->setBuffer(buffer, size);
return;
}
}
LOGE("Tried to update input `%s` but no input with this name exists", name.c_str());
assert(false);
}
};

@@ -1,14 +0,0 @@
# distutils: language = c++
from libcpp.string cimport string
cdef extern from "selfdrive/modeld/runners/runmodel.h":
cdef int USE_CPU_RUNTIME
cdef int USE_GPU_RUNTIME
cdef int USE_DSP_RUNTIME
cdef cppclass RunModel:
void addInput(string, float*, int)
void setInputBuffer(string, float*, int)
void * getCLBuffer(string)
void execute()

@@ -1,6 +0,0 @@
# distutils: language = c++
from .runmodel cimport RunModel as cppRunModel
cdef class RunModel:
cdef cppRunModel * model

@@ -1,37 +0,0 @@
# distutils: language = c++
# cython: c_string_encoding=ascii, language_level=3
from libcpp.string cimport string
from .runmodel cimport USE_CPU_RUNTIME, USE_GPU_RUNTIME, USE_DSP_RUNTIME
from selfdrive.modeld.models.commonmodel_pyx cimport CLMem
class Runtime:
CPU = USE_CPU_RUNTIME
GPU = USE_GPU_RUNTIME
DSP = USE_DSP_RUNTIME
cdef class RunModel:
def __dealloc__(self):
del self.model
def addInput(self, string name, float[:] buffer):
if buffer is not None:
self.model.addInput(name, &buffer[0], len(buffer))
else:
self.model.addInput(name, NULL, 0)
def setInputBuffer(self, string name, float[:] buffer):
if buffer is not None:
self.model.setInputBuffer(name, &buffer[0], len(buffer))
else:
self.model.setInputBuffer(name, NULL, 0)
def getCLBuffer(self, string name):
cdef void * cl_buf = self.model.getCLBuffer(name)
if not cl_buf:
return None
return CLMem.create(cl_buf)
def execute(self):
self.model.execute()

@@ -1,116 +0,0 @@
#pragma clang diagnostic ignored "-Wexceptions"
#include "selfdrive/modeld/runners/snpemodel.h"
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/util.h"
#include "common/timing.h"
void PrintErrorStringAndExit() {
std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
std::exit(EXIT_FAILURE);
}
SNPEModel::SNPEModel(const std::string path, float *_output, size_t _output_size, int runtime, bool _use_tf8, cl_context context) {
output = _output;
output_size = _output_size;
use_tf8 = _use_tf8;
#ifdef QCOM2
if (runtime == USE_GPU_RUNTIME) {
snpe_runtime = zdl::DlSystem::Runtime_t::GPU;
} else if (runtime == USE_DSP_RUNTIME) {
snpe_runtime = zdl::DlSystem::Runtime_t::DSP;
} else {
snpe_runtime = zdl::DlSystem::Runtime_t::CPU;
}
assert(zdl::SNPE::SNPEFactory::isRuntimeAvailable(snpe_runtime));
#endif
model_data = util::read_file(path);
assert(model_data.size() > 0);
// load model
std::unique_ptr<zdl::DlContainer::IDlContainer> container = zdl::DlContainer::IDlContainer::open((uint8_t*)model_data.data(), model_data.size());
if (!container) { PrintErrorStringAndExit(); }
LOGW("loaded model with size: %lu", model_data.size());
// create model runner
zdl::SNPE::SNPEBuilder snpe_builder(container.get());
while (!snpe) {
#ifdef QCOM2
snpe = snpe_builder.setOutputLayers({})
.setRuntimeProcessor(snpe_runtime)
.setUseUserSuppliedBuffers(true)
.setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::HIGH_PERFORMANCE)
.build();
#else
snpe = snpe_builder.setOutputLayers({})
.setUseUserSuppliedBuffers(true)
.setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::HIGH_PERFORMANCE)
.build();
#endif
if (!snpe) std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
}
// create output buffer
zdl::DlSystem::UserBufferEncodingFloat ub_encoding_float;
zdl::DlSystem::IUserBufferFactory &ub_factory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
const auto &output_tensor_names_opt = snpe->getOutputTensorNames();
if (!output_tensor_names_opt) throw std::runtime_error("Error obtaining output tensor names");
const auto &output_tensor_names = *output_tensor_names_opt;
assert(output_tensor_names.size() == 1);
const char *output_tensor_name = output_tensor_names.at(0);
const zdl::DlSystem::TensorShape &buffer_shape = snpe->getInputOutputBufferAttributes(output_tensor_name)->getDims();
if (output_size != 0) {
assert(output_size == buffer_shape[1]);
} else {
output_size = buffer_shape[1];
}
std::vector<size_t> output_strides = {output_size * sizeof(float), sizeof(float)};
output_buffer = ub_factory.createUserBuffer(output, output_size * sizeof(float), output_strides, &ub_encoding_float);
output_map.add(output_tensor_name, output_buffer.get());
}
void SNPEModel::addInput(const std::string name, float *buffer, int size) {
const int idx = inputs.size();
const auto &input_tensor_names_opt = snpe->getInputTensorNames();
if (!input_tensor_names_opt) throw std::runtime_error("Error obtaining input tensor names");
const auto &input_tensor_names = *input_tensor_names_opt;
const char *input_tensor_name = input_tensor_names.at(idx);
const bool input_tf8 = use_tf8 && strcmp(input_tensor_name, "input_img") == 0; // TODO: This is a terrible hack, get rid of this name check both here and in onnx_runner.py
LOGW("adding index %d: %s", idx, input_tensor_name);
zdl::DlSystem::UserBufferEncodingFloat ub_encoding_float;
zdl::DlSystem::UserBufferEncodingTf8 ub_encoding_tf8(0, 1./255); // network takes 0-1
zdl::DlSystem::IUserBufferFactory &ub_factory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
zdl::DlSystem::UserBufferEncoding *input_encoding = input_tf8 ? (zdl::DlSystem::UserBufferEncoding*)&ub_encoding_tf8 : (zdl::DlSystem::UserBufferEncoding*)&ub_encoding_float;
const auto &buffer_shape_opt = snpe->getInputDimensions(input_tensor_name);
const zdl::DlSystem::TensorShape &buffer_shape = *buffer_shape_opt;
size_t size_of_input = input_tf8 ? sizeof(uint8_t) : sizeof(float);
std::vector<size_t> strides(buffer_shape.rank());
strides[strides.size() - 1] = size_of_input;
size_t product = 1;
for (size_t i = 0; i < buffer_shape.rank(); i++) product *= buffer_shape[i];
size_t stride = strides[strides.size() - 1];
for (size_t i = buffer_shape.rank() - 1; i > 0; i--) {
stride *= buffer_shape[i];
strides[i-1] = stride;
}
auto input_buffer = ub_factory.createUserBuffer(buffer, product*size_of_input, strides, input_encoding);
input_map.add(input_tensor_name, input_buffer.get());
inputs.push_back(std::unique_ptr<SNPEModelInput>(new SNPEModelInput(name, buffer, size, std::move(input_buffer))));
}
void SNPEModel::execute() {
if (!snpe->execute(input_map, output_map)) {
PrintErrorStringAndExit();
}
}

@@ -1,52 +0,0 @@
#pragma once
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#include <memory>
#include <string>
#include <utility>
#include <DlContainer/IDlContainer.hpp>
#include <DlSystem/DlError.hpp>
#include <DlSystem/ITensor.hpp>
#include <DlSystem/ITensorFactory.hpp>
#include <DlSystem/IUserBuffer.hpp>
#include <DlSystem/IUserBufferFactory.hpp>
#include <SNPE/SNPE.hpp>
#include <SNPE/SNPEBuilder.hpp>
#include <SNPE/SNPEFactory.hpp>
#include "selfdrive/modeld/runners/runmodel.h"
struct SNPEModelInput : public ModelInput {
std::unique_ptr<zdl::DlSystem::IUserBuffer> snpe_buffer;
SNPEModelInput(const std::string _name, float *_buffer, int _size, std::unique_ptr<zdl::DlSystem::IUserBuffer> _snpe_buffer) : ModelInput(_name, _buffer, _size), snpe_buffer(std::move(_snpe_buffer)) {}
void setBuffer(float *_buffer, int _size) {
ModelInput::setBuffer(_buffer, _size);
assert(snpe_buffer->setBufferAddress(_buffer) == true);
}
};
class SNPEModel : public RunModel {
public:
SNPEModel(const std::string path, float *_output, size_t _output_size, int runtime, bool use_tf8 = false, cl_context context = NULL);
void addInput(const std::string name, float *buffer, int size);
void execute();
private:
std::string model_data;
#ifdef QCOM2
zdl::DlSystem::Runtime_t snpe_runtime;
#endif
// snpe model stuff
std::unique_ptr<zdl::SNPE::SNPE> snpe;
zdl::DlSystem::UserBufferMap input_map;
zdl::DlSystem::UserBufferMap output_map;
std::unique_ptr<zdl::DlSystem::IUserBuffer> output_buffer;
bool use_tf8;
float *output;
size_t output_size;
};

@@ -1,9 +0,0 @@
# distutils: language = c++
from libcpp.string cimport string
from msgq.visionipc.visionipc cimport cl_context
cdef extern from "selfdrive/modeld/runners/snpemodel.h":
cdef cppclass SNPEModel:
SNPEModel(string, float*, size_t, int, bool, cl_context)

@@ -1,17 +0,0 @@
# distutils: language = c++
# cython: c_string_encoding=ascii, language_level=3
import os
from libcpp cimport bool
from libcpp.string cimport string
from .snpemodel cimport SNPEModel as cppSNPEModel
from selfdrive.modeld.models.commonmodel_pyx cimport CLContext
from selfdrive.modeld.runners.runmodel_pyx cimport RunModel
from selfdrive.modeld.runners.runmodel cimport RunModel as cppRunModel
os.environ['ADSP_LIBRARY_PATH'] = "/data/pythonpath/third_party/snpe/dsp/"
cdef class SNPEModel(RunModel):
def __cinit__(self, string path, float[:] output, int runtime, bool use_tf8, CLContext context):
self.model = <cppRunModel *> new cppSNPEModel(path, &output[0], len(output), runtime, use_tf8, context.context)

@@ -1,58 +0,0 @@
#include "selfdrive/modeld/runners/thneedmodel.h"
#include <string>
#include "common/swaglog.h"
ThneedModel::ThneedModel(const std::string path, float *_output, size_t _output_size, int runtime, bool luse_tf8, cl_context context) {
thneed = new Thneed(true, context);
thneed->load(path.c_str());
thneed->clexec();
recorded = false;
output = _output;
}
void* ThneedModel::getCLBuffer(const std::string name) {
int index = -1;
for (int i = 0; i < inputs.size(); i++) {
if (name == inputs[i]->name) {
index = i;
break;
}
}
if (index == -1) {
LOGE("Tried to get CL buffer for input `%s` but no input with this name exists", name.c_str());
assert(false);
}
if (thneed->input_clmem.size() >= inputs.size()) {
return &thneed->input_clmem[inputs.size() - index - 1];
} else {
return nullptr;
}
}
void ThneedModel::execute() {
if (!recorded) {
thneed->record = true;
float *input_buffers[inputs.size()];
for (int i = 0; i < inputs.size(); i++) {
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer;
}
thneed->copy_inputs(input_buffers);
thneed->clexec();
thneed->copy_output(output);
thneed->stop();
recorded = true;
} else {
float *input_buffers[inputs.size()];
for (int i = 0; i < inputs.size(); i++) {
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer;
}
thneed->execute(input_buffers, output);
}
}

@@ -1,17 +0,0 @@
#pragma once
#include <string>
#include "selfdrive/modeld/runners/runmodel.h"
#include "selfdrive/modeld/thneed/thneed.h"
class ThneedModel : public RunModel {
public:
ThneedModel(const std::string path, float *_output, size_t _output_size, int runtime, bool use_tf8 = false, cl_context context = NULL);
void *getCLBuffer(const std::string name);
void execute();
private:
Thneed *thneed = NULL;
bool recorded;
float *output;
};

@@ -1,9 +0,0 @@
# distutils: language = c++
from libcpp.string cimport string
from msgq.visionipc.visionipc cimport cl_context
cdef extern from "selfdrive/modeld/runners/thneedmodel.h":
cdef cppclass ThneedModel:
ThneedModel(string, float*, size_t, int, bool, cl_context)

@@ -1,14 +0,0 @@
# distutils: language = c++
# cython: c_string_encoding=ascii, language_level=3
from libcpp cimport bool
from libcpp.string cimport string
from .thneedmodel cimport ThneedModel as cppThneedModel
from selfdrive.modeld.models.commonmodel_pyx cimport CLContext
from selfdrive.modeld.runners.runmodel_pyx cimport RunModel
from selfdrive.modeld.runners.runmodel cimport RunModel as cppRunModel
cdef class ThneedModel(RunModel):
def __cinit__(self, string path, float[:] output, int runtime, bool use_tf8, CLContext context):
self.model = <cppRunModel *> new cppThneedModel(path, &output[0], len(output), runtime, use_tf8, context.context)

@@ -0,0 +1,8 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import to_mv
def qcom_tensor_from_opencl_address(opencl_address, shape, dtype):
cl_buf_desc_ptr = to_mv(opencl_address, 8).cast('Q')[0]
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.
return Tensor.from_blob(rawbuf_ptr, shape, dtype=dtype, device='QCOM')
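The helper above chases two pointers: the cl_mem handle's first 8 bytes point at Qualcomm's buffer descriptor, and the 8-byte word at offset 0xA0 (index 20) of that descriptor is the raw GPU address that Tensor.from_blob wraps zero-copy. A standalone, GPU-free sketch of the same two dereferences using ctypes (the descriptor here is faked in host memory):

import ctypes

# Stand-in for the QCOM buffer descriptor; word 20 (byte offset 0xA0) holds the GPU address.
fake_gpu_ptr = 0xDEADBEEF
desc = (ctypes.c_uint64 * 32)()
desc[20] = fake_gpu_ptr

# Stand-in for the cl_mem struct: its first field points at the descriptor.
cl_mem_struct = ctypes.c_uint64(ctypes.addressof(desc))
opencl_address = ctypes.addressof(cl_mem_struct)

# The same two dereferences as qcom_tensor_from_opencl_address:
cl_buf_desc_ptr = ctypes.c_uint64.from_address(opencl_address).value
rawbuf_ptr = ctypes.c_uint64.from_address(cl_buf_desc_ptr + 0xA0).value
assert rawbuf_ptr == fake_gpu_ptr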

@@ -1,8 +0,0 @@
thneed is an SNPE accelerator. I know SNPE is already an accelerator, but sometimes things need to go even faster..
It runs on the local device, and caches a single model run. Then it replays it, but fast.
thneed slices through abstraction layers like a fish.
You need a thneed.

@@ -1,154 +0,0 @@
#include <cassert>
#include <set>
#include "third_party/json11/json11.hpp"
#include "common/util.h"
#include "common/clutil.h"
#include "common/swaglog.h"
#include "selfdrive/modeld/thneed/thneed.h"
using namespace json11;
extern map<cl_program, string> g_program_source;
void Thneed::load(const char *filename) {
LOGD("Thneed::load: loading from %s\n", filename);
string buf = util::read_file(filename);
int jsz = *(int *)buf.data();
string jsonerr;
string jj(buf.data() + sizeof(int), jsz);
Json jdat = Json::parse(jj, jsonerr);
map<cl_mem, cl_mem> real_mem;
real_mem[NULL] = NULL;
int ptr = sizeof(int)+jsz;
for (auto &obj : jdat["objects"].array_items()) {
auto mobj = obj.object_items();
int sz = mobj["size"].int_value();
cl_mem clbuf = NULL;
if (mobj["buffer_id"].string_value().size() > 0) {
// image buffer must already be allocated
clbuf = real_mem[*(cl_mem*)(mobj["buffer_id"].string_value().data())];
assert(mobj["needs_load"].bool_value() == false);
} else {
if (mobj["needs_load"].bool_value()) {
clbuf = clCreateBuffer(context, CL_MEM_COPY_HOST_PTR | CL_MEM_READ_WRITE, sz, &buf[ptr], NULL);
if (debug >= 1) printf("loading %p %d @ 0x%X\n", clbuf, sz, ptr);
ptr += sz;
} else {
// TODO: is there a faster way to init zeroed out buffers?
void *host_zeros = calloc(sz, 1);
clbuf = clCreateBuffer(context, CL_MEM_COPY_HOST_PTR | CL_MEM_READ_WRITE, sz, host_zeros, NULL);
free(host_zeros);
}
}
assert(clbuf != NULL);
if (mobj["arg_type"] == "image2d_t" || mobj["arg_type"] == "image1d_t") {
cl_image_desc desc = {0};
desc.image_type = (mobj["arg_type"] == "image2d_t") ? CL_MEM_OBJECT_IMAGE2D : CL_MEM_OBJECT_IMAGE1D_BUFFER;
desc.image_width = mobj["width"].int_value();
desc.image_height = mobj["height"].int_value();
desc.image_row_pitch = mobj["row_pitch"].int_value();
assert(sz == desc.image_height*desc.image_row_pitch);
#ifdef QCOM2
desc.buffer = clbuf;
#else
// TODO: we are creating unused buffers on PC
clReleaseMemObject(clbuf);
#endif
cl_image_format format = {0};
format.image_channel_order = CL_RGBA;
format.image_channel_data_type = mobj["float32"].bool_value() ? CL_FLOAT : CL_HALF_FLOAT;
cl_int errcode;
#ifndef QCOM2
if (mobj["needs_load"].bool_value()) {
clbuf = clCreateImage(context, CL_MEM_COPY_HOST_PTR | CL_MEM_READ_WRITE, &format, &desc, &buf[ptr-sz], &errcode);
} else {
clbuf = clCreateImage(context, CL_MEM_READ_WRITE, &format, &desc, NULL, &errcode);
}
#else
clbuf = clCreateImage(context, CL_MEM_READ_WRITE, &format, &desc, NULL, &errcode);
#endif
if (clbuf == NULL) {
LOGE("clError: %s create image %zux%zu rp %zu with buffer %p\n", cl_get_error_string(errcode),
desc.image_width, desc.image_height, desc.image_row_pitch, desc.buffer);
}
assert(clbuf != NULL);
}
real_mem[*(cl_mem*)(mobj["id"].string_value().data())] = clbuf;
}
map<string, cl_program> g_programs;
for (const auto &[name, source] : jdat["programs"].object_items()) {
if (debug >= 1) printf("building %s with size %zu\n", name.c_str(), source.string_value().size());
g_programs[name] = cl_program_from_source(context, device_id, source.string_value());
}
for (auto &obj : jdat["inputs"].array_items()) {
auto mobj = obj.object_items();
int sz = mobj["size"].int_value();
cl_mem aa = real_mem[*(cl_mem*)(mobj["buffer_id"].string_value().data())];
input_clmem.push_back(aa);
input_sizes.push_back(sz);
LOGD("Thneed::load: adding input %s with size %d\n", mobj["name"].string_value().data(), sz);
cl_int cl_err;
void *ret = clEnqueueMapBuffer(command_queue, aa, CL_TRUE, CL_MAP_WRITE, 0, sz, 0, NULL, NULL, &cl_err);
if (cl_err != CL_SUCCESS) LOGE("clError: %s map %p %d\n", cl_get_error_string(cl_err), aa, sz);
assert(cl_err == CL_SUCCESS);
inputs.push_back(ret);
}
for (auto &obj : jdat["outputs"].array_items()) {
auto mobj = obj.object_items();
int sz = mobj["size"].int_value();
LOGD("Thneed::save: adding output with size %d\n", sz);
// TODO: support multiple outputs
output = real_mem[*(cl_mem*)(mobj["buffer_id"].string_value().data())];
assert(output != NULL);
}
for (auto &obj : jdat["binaries"].array_items()) {
string name = obj["name"].string_value();
size_t length = obj["length"].int_value();
if (debug >= 1) printf("binary %s with size %zu\n", name.c_str(), length);
g_programs[name] = cl_program_from_binary(context, device_id, (const uint8_t*)&buf[ptr], length);
ptr += length;
}
for (auto &obj : jdat["kernels"].array_items()) {
auto gws = obj["global_work_size"];
auto lws = obj["local_work_size"];
auto kk = shared_ptr<CLQueuedKernel>(new CLQueuedKernel(this));
kk->name = obj["name"].string_value();
kk->program = g_programs[kk->name];
kk->work_dim = obj["work_dim"].int_value();
for (int i = 0; i < kk->work_dim; i++) {
kk->global_work_size[i] = gws[i].int_value();
kk->local_work_size[i] = lws[i].int_value();
}
kk->num_args = obj["num_args"].int_value();
for (int i = 0; i < kk->num_args; i++) {
string arg = obj["args"].array_items()[i].string_value();
int arg_size = obj["args_size"].array_items()[i].int_value();
kk->args_size.push_back(arg_size);
if (arg_size == 8) {
cl_mem val = *(cl_mem*)(arg.data());
val = real_mem[val];
kk->args.push_back(string((char*)&val, sizeof(val)));
} else {
kk->args.push_back(arg);
}
}
kq.push_back(kk);
}
clFinish(command_queue);
}
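For the record, the container format this removed loader parsed was: a 4-byte little-endian JSON length, the JSON metadata (objects, programs, inputs, outputs, binaries, kernels), then raw buffer and binary bytes consumed in order. A Python sketch of the same header parse, mirroring Thneed::load above:

import json
import struct

def read_thneed(path: str):
    # int jsz = *(int *)buf.data(); string jj(buf.data() + sizeof(int), jsz);
    with open(path, "rb") as f:
        buf = f.read()
    jsz = struct.unpack_from("<i", buf, 0)[0]
    jdat = json.loads(buf[4:4 + jsz])
    blob = buf[4 + jsz:]  # needs_load buffers first, then kernel binaries
    return jdat, blob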

@@ -1,133 +0,0 @@
#pragma once
#ifndef __user
#define __user __attribute__(())
#endif
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
#include <CL/cl.h>
#include "third_party/linux/include/msm_kgsl.h"
using namespace std;
cl_int thneed_clSetKernelArg(cl_kernel kernel, cl_uint arg_index, size_t arg_size, const void *arg_value);
namespace json11 {
class Json;
}
class Thneed;
class GPUMalloc {
public:
GPUMalloc(int size, int fd);
~GPUMalloc();
void *alloc(int size);
private:
uint64_t base;
int remaining;
};
class CLQueuedKernel {
public:
CLQueuedKernel(Thneed *lthneed) { thneed = lthneed; }
CLQueuedKernel(Thneed *lthneed,
cl_kernel _kernel,
cl_uint _work_dim,
const size_t *_global_work_size,
const size_t *_local_work_size);
cl_int exec();
void debug_print(bool verbose);
int get_arg_num(const char *search_arg_name);
cl_program program;
string name;
cl_uint num_args;
vector<string> arg_names;
vector<string> arg_types;
vector<string> args;
vector<int> args_size;
cl_kernel kernel = NULL;
json11::Json to_json() const;
cl_uint work_dim;
size_t global_work_size[3] = {0};
size_t local_work_size[3] = {0};
private:
Thneed *thneed;
};
class CachedIoctl {
public:
virtual void exec() {}
};
class CachedSync: public CachedIoctl {
public:
CachedSync(Thneed *lthneed, string ldata) { thneed = lthneed; data = ldata; }
void exec();
private:
Thneed *thneed;
string data;
};
class CachedCommand: public CachedIoctl {
public:
CachedCommand(Thneed *lthneed, struct kgsl_gpu_command *cmd);
void exec();
private:
void disassemble(int cmd_index);
struct kgsl_gpu_command cache;
unique_ptr<kgsl_command_object[]> cmds;
unique_ptr<kgsl_command_object[]> objs;
Thneed *thneed;
vector<shared_ptr<CLQueuedKernel> > kq;
};
class Thneed {
public:
Thneed(bool do_clinit=false, cl_context _context = NULL);
void stop();
void execute(float **finputs, float *foutput, bool slow=false);
void wait();
vector<cl_mem> input_clmem;
vector<void *> inputs;
vector<size_t> input_sizes;
cl_mem output = NULL;
cl_context context = NULL;
cl_command_queue command_queue;
cl_device_id device_id;
int context_id;
// protected?
bool record = false;
int debug;
int timestamp;
#ifdef QCOM2
unique_ptr<GPUMalloc> ram;
vector<unique_ptr<CachedIoctl> > cmds;
int fd;
#endif
// all CL kernels
void copy_inputs(float **finputs, bool internal=false);
void copy_output(float *foutput);
cl_int clexec();
vector<shared_ptr<CLQueuedKernel> > kq;
// pending CL kernels
vector<shared_ptr<CLQueuedKernel> > ckq;
// loading
void load(const char *filename);
private:
void clinit();
};

@@ -1,216 +0,0 @@
#include "selfdrive/modeld/thneed/thneed.h"
#include <cassert>
#include <cstring>
#include <map>
#include "common/clutil.h"
#include "common/timing.h"
map<pair<cl_kernel, int>, string> g_args;
map<pair<cl_kernel, int>, int> g_args_size;
map<cl_program, string> g_program_source;
void Thneed::stop() {
//printf("Thneed::stop: recorded %lu commands\n", cmds.size());
record = false;
}
void Thneed::clinit() {
device_id = cl_get_device_id(CL_DEVICE_TYPE_DEFAULT);
if (context == NULL) context = CL_CHECK_ERR(clCreateContext(NULL, 1, &device_id, NULL, NULL, &err));
//cl_command_queue_properties props[3] = {CL_QUEUE_PROPERTIES, CL_QUEUE_PROFILING_ENABLE, 0};
cl_command_queue_properties props[3] = {CL_QUEUE_PROPERTIES, 0, 0};
command_queue = CL_CHECK_ERR(clCreateCommandQueueWithProperties(context, device_id, props, &err));
printf("Thneed::clinit done\n");
}
cl_int Thneed::clexec() {
if (debug >= 1) printf("Thneed::clexec: running %lu queued kernels\n", kq.size());
for (auto &k : kq) {
if (record) ckq.push_back(k);
cl_int ret = k->exec();
assert(ret == CL_SUCCESS);
}
return clFinish(command_queue);
}
void Thneed::copy_inputs(float **finputs, bool internal) {
for (int idx = 0; idx < inputs.size(); ++idx) {
if (debug >= 1) printf("copying %lu -- %p -> %p (cl %p)\n", input_sizes[idx], finputs[idx], inputs[idx], input_clmem[idx]);
if (internal) {
// if it's internal, using memcpy is fine since the buffer sync is cached in the ioctl layer
if (finputs[idx] != NULL) memcpy(inputs[idx], finputs[idx], input_sizes[idx]);
} else {
if (finputs[idx] != NULL) CL_CHECK(clEnqueueWriteBuffer(command_queue, input_clmem[idx], CL_TRUE, 0, input_sizes[idx], finputs[idx], 0, NULL, NULL));
}
}
}
void Thneed::copy_output(float *foutput) {
if (output != NULL) {
size_t sz;
clGetMemObjectInfo(output, CL_MEM_SIZE, sizeof(sz), &sz, NULL);
if (debug >= 1) printf("copying %lu for output %p -> %p\n", sz, output, foutput);
CL_CHECK(clEnqueueReadBuffer(command_queue, output, CL_TRUE, 0, sz, foutput, 0, NULL, NULL));
} else {
printf("CAUTION: model output is NULL, does it have no outputs?\n");
}
}
// *********** CLQueuedKernel ***********
CLQueuedKernel::CLQueuedKernel(Thneed *lthneed,
cl_kernel _kernel,
cl_uint _work_dim,
const size_t *_global_work_size,
const size_t *_local_work_size) {
thneed = lthneed;
kernel = _kernel;
work_dim = _work_dim;
assert(work_dim <= 3);
for (int i = 0; i < work_dim; i++) {
global_work_size[i] = _global_work_size[i];
local_work_size[i] = _local_work_size[i];
}
char _name[0x100];
clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME, sizeof(_name), _name, NULL);
name = string(_name);
clGetKernelInfo(kernel, CL_KERNEL_NUM_ARGS, sizeof(num_args), &num_args, NULL);
// get args
for (int i = 0; i < num_args; i++) {
char arg_name[0x100] = {0};
clGetKernelArgInfo(kernel, i, CL_KERNEL_ARG_NAME, sizeof(arg_name), arg_name, NULL);
arg_names.push_back(string(arg_name));
clGetKernelArgInfo(kernel, i, CL_KERNEL_ARG_TYPE_NAME, sizeof(arg_name), arg_name, NULL);
arg_types.push_back(string(arg_name));
args.push_back(g_args[make_pair(kernel, i)]);
args_size.push_back(g_args_size[make_pair(kernel, i)]);
}
// get program
clGetKernelInfo(kernel, CL_KERNEL_PROGRAM, sizeof(program), &program, NULL);
}
int CLQueuedKernel::get_arg_num(const char *search_arg_name) {
for (int i = 0; i < num_args; i++) {
if (arg_names[i] == search_arg_name) return i;
}
printf("failed to find %s in %s\n", search_arg_name, name.c_str());
assert(false);
}
cl_int CLQueuedKernel::exec() {
if (kernel == NULL) {
kernel = clCreateKernel(program, name.c_str(), NULL);
arg_names.clear();
arg_types.clear();
for (int j = 0; j < num_args; j++) {
char arg_name[0x100] = {0};
clGetKernelArgInfo(kernel, j, CL_KERNEL_ARG_NAME, sizeof(arg_name), arg_name, NULL);
arg_names.push_back(string(arg_name));
clGetKernelArgInfo(kernel, j, CL_KERNEL_ARG_TYPE_NAME, sizeof(arg_name), arg_name, NULL);
arg_types.push_back(string(arg_name));
cl_int ret;
if (args[j].size() != 0) {
assert(args[j].size() == args_size[j]);
ret = thneed_clSetKernelArg(kernel, j, args[j].size(), args[j].data());
} else {
ret = thneed_clSetKernelArg(kernel, j, args_size[j], NULL);
}
assert(ret == CL_SUCCESS);
}
}
if (thneed->debug >= 1) {
debug_print(thneed->debug >= 2);
}
return clEnqueueNDRangeKernel(thneed->command_queue,
kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL);
}
void CLQueuedKernel::debug_print(bool verbose) {
printf("%p %56s -- ", kernel, name.c_str());
for (int i = 0; i < work_dim; i++) {
printf("%4zu ", global_work_size[i]);
}
printf(" -- ");
for (int i = 0; i < work_dim; i++) {
printf("%4zu ", local_work_size[i]);
}
printf("\n");
if (verbose) {
for (int i = 0; i < num_args; i++) {
string arg = args[i];
printf(" %s %s", arg_types[i].c_str(), arg_names[i].c_str());
void *arg_value = (void*)arg.data();
int arg_size = arg.size();
if (arg_size == 0) {
printf(" (size) %d", args_size[i]);
} else if (arg_size == 1) {
printf(" = %d", *((char*)arg_value));
} else if (arg_size == 2) {
printf(" = %d", *((short*)arg_value));
} else if (arg_size == 4) {
if (arg_types[i] == "float") {
printf(" = %f", *((float*)arg_value));
} else {
printf(" = %d", *((int*)arg_value));
}
} else if (arg_size == 8) {
cl_mem val = (cl_mem)(*((uintptr_t*)arg_value));
printf(" = %p", val);
if (val != NULL) {
cl_mem_object_type obj_type;
clGetMemObjectInfo(val, CL_MEM_TYPE, sizeof(obj_type), &obj_type, NULL);
if (arg_types[i] == "image2d_t" || arg_types[i] == "image1d_t" || obj_type == CL_MEM_OBJECT_IMAGE2D) {
cl_image_format format;
size_t width, height, depth, array_size, row_pitch, slice_pitch;
cl_mem buf;
clGetImageInfo(val, CL_IMAGE_FORMAT, sizeof(format), &format, NULL);
assert(format.image_channel_order == CL_RGBA);
assert(format.image_channel_data_type == CL_HALF_FLOAT || format.image_channel_data_type == CL_FLOAT);
clGetImageInfo(val, CL_IMAGE_WIDTH, sizeof(width), &width, NULL);
clGetImageInfo(val, CL_IMAGE_HEIGHT, sizeof(height), &height, NULL);
clGetImageInfo(val, CL_IMAGE_ROW_PITCH, sizeof(row_pitch), &row_pitch, NULL);
clGetImageInfo(val, CL_IMAGE_DEPTH, sizeof(depth), &depth, NULL);
clGetImageInfo(val, CL_IMAGE_ARRAY_SIZE, sizeof(array_size), &array_size, NULL);
clGetImageInfo(val, CL_IMAGE_SLICE_PITCH, sizeof(slice_pitch), &slice_pitch, NULL);
assert(depth == 0);
assert(array_size == 0);
assert(slice_pitch == 0);
clGetImageInfo(val, CL_IMAGE_BUFFER, sizeof(buf), &buf, NULL);
size_t sz = 0;
if (buf != NULL) clGetMemObjectInfo(buf, CL_MEM_SIZE, sizeof(sz), &sz, NULL);
printf(" image %zu x %zu rp %zu @ %p buffer %zu", width, height, row_pitch, buf, sz);
} else {
size_t sz;
clGetMemObjectInfo(val, CL_MEM_SIZE, sizeof(sz), &sz, NULL);
printf(" buffer %zu", sz);
}
}
}
printf("\n");
}
}
}
cl_int thneed_clSetKernelArg(cl_kernel kernel, cl_uint arg_index, size_t arg_size, const void *arg_value) {
g_args_size[make_pair(kernel, arg_index)] = arg_size;
if (arg_value != NULL) {
g_args[make_pair(kernel, arg_index)] = string((char*)arg_value, arg_size);
} else {
g_args[make_pair(kernel, arg_index)] = string("");
}
cl_int ret = clSetKernelArg(kernel, arg_index, arg_size, arg_value);
return ret;
}
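The interception pattern above (thneed_clSetKernelArg) records every kernel argument keyed by (kernel, index) so kernels can be re-created later and their arguments replayed. A Python stand-in for the two C++ maps:

# Sketch of the retired arg cache: values keyed by (kernel handle, arg index).
g_args: dict[tuple[int, int], bytes] = {}
g_args_size: dict[tuple[int, int], int] = {}

def record_kernel_arg(kernel: int, arg_index: int, arg_value: bytes | None, arg_size: int):
    g_args_size[(kernel, arg_index)] = arg_size
    g_args[(kernel, arg_index)] = arg_value if arg_value is not None else b""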

@@ -1,32 +0,0 @@
#include "selfdrive/modeld/thneed/thneed.h"
#include <cassert>
#include "common/clutil.h"
#include "common/timing.h"
Thneed::Thneed(bool do_clinit, cl_context _context) {
context = _context;
if (do_clinit) clinit();
char *thneed_debug_env = getenv("THNEED_DEBUG");
debug = (thneed_debug_env != NULL) ? atoi(thneed_debug_env) : 0;
}
void Thneed::execute(float **finputs, float *foutput, bool slow) {
uint64_t tb, te;
if (debug >= 1) tb = nanos_since_boot();
// ****** copy inputs
copy_inputs(finputs);
// ****** run commands
clexec();
// ****** copy outputs
copy_output(foutput);
if (debug >= 1) {
te = nanos_since_boot();
printf("model exec in %lu us\n", (te-tb)/1000);
}
}
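For context, this is the entire PC fast path: no ioctl games, just replaying the recorded CL queue. A minimal usage sketch (hedged: load() is implemented in serialize.cc, and the model and buffer names here are made up):
Thneed t(true);                          // clinit: create the CL context and command queue
t.load("model.thneed");                  // hypothetical serialized model
float *inputs[] = {img_buf, state_buf};  // hypothetical input buffers
t.execute(inputs, output_buf);           // copy_inputs -> clexec -> copy_output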

@@ -1,258 +0,0 @@
#include "selfdrive/modeld/thneed/thneed.h"
#include <dlfcn.h>
#include <sys/mman.h>
#include <cassert>
#include <cerrno>
#include <cstring>
#include <map>
#include <string>
#include "common/clutil.h"
#include "common/timing.h"
Thneed *g_thneed = NULL;
int g_fd = -1;
void hexdump(uint8_t *d, int len) {
assert((len%4) == 0);
printf(" dumping %p len 0x%x\n", d, len);
for (int i = 0; i < len/4; i++) {
if (i != 0 && (i%0x10) == 0) printf("\n");
printf("%8x ", d[i]);
}
printf("\n");
}
// *********** ioctl interceptor ***********
extern "C" {
int (*my_ioctl)(int filedes, unsigned long request, void *argp) = NULL;
#undef ioctl
int ioctl(int filedes, unsigned long request, void *argp) {
request &= 0xFFFFFFFF; // needed on QCOM2
if (my_ioctl == NULL) my_ioctl = reinterpret_cast<decltype(my_ioctl)>(dlsym(RTLD_NEXT, "ioctl"));
Thneed *thneed = g_thneed;
// save the fd
if (request == IOCTL_KGSL_GPUOBJ_ALLOC) g_fd = filedes;
// note that this always runs, even without a thneed object
if (request == IOCTL_KGSL_DRAWCTXT_CREATE) {
struct kgsl_drawctxt_create *create = (struct kgsl_drawctxt_create *)argp;
create->flags &= ~KGSL_CONTEXT_PRIORITY_MASK;
create->flags |= 6 << KGSL_CONTEXT_PRIORITY_SHIFT; // priority from 1-15, 1 is max priority
printf("IOCTL_KGSL_DRAWCTXT_CREATE: creating context with flags 0x%x\n", create->flags);
}
if (thneed != NULL) {
if (request == IOCTL_KGSL_GPU_COMMAND) {
struct kgsl_gpu_command *cmd = (struct kgsl_gpu_command *)argp;
if (thneed->record) {
thneed->timestamp = cmd->timestamp;
thneed->context_id = cmd->context_id;
thneed->cmds.push_back(unique_ptr<CachedCommand>(new CachedCommand(thneed, cmd)));
}
if (thneed->debug >= 1) {
printf("IOCTL_KGSL_GPU_COMMAND(%2zu): flags: 0x%lx context_id: %u timestamp: %u numcmds: %d numobjs: %d\n",
thneed->cmds.size(),
cmd->flags,
cmd->context_id, cmd->timestamp, cmd->numcmds, cmd->numobjs);
}
} else if (request == IOCTL_KGSL_GPUOBJ_SYNC) {
struct kgsl_gpuobj_sync *cmd = (struct kgsl_gpuobj_sync *)argp;
struct kgsl_gpuobj_sync_obj *objs = (struct kgsl_gpuobj_sync_obj *)(cmd->objs);
if (thneed->debug >= 2) {
printf("IOCTL_KGSL_GPUOBJ_SYNC count:%d ", cmd->count);
for (int i = 0; i < cmd->count; i++) {
printf(" -- offset:0x%lx len:0x%lx id:%d op:%d ", objs[i].offset, objs[i].length, objs[i].id, objs[i].op);
}
printf("\n");
}
if (thneed->record) {
thneed->cmds.push_back(unique_ptr<CachedSync>(new
CachedSync(thneed, string((char *)objs, sizeof(struct kgsl_gpuobj_sync_obj)*cmd->count))));
}
} else if (request == IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID) {
struct kgsl_device_waittimestamp_ctxtid *cmd = (struct kgsl_device_waittimestamp_ctxtid *)argp;
if (thneed->debug >= 1) {
printf("IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID: context_id: %d timestamp: %d timeout: %d\n",
cmd->context_id, cmd->timestamp, cmd->timeout);
}
} else if (request == IOCTL_KGSL_SETPROPERTY) {
if (thneed->debug >= 1) {
struct kgsl_device_getproperty *prop = (struct kgsl_device_getproperty *)argp;
printf("IOCTL_KGSL_SETPROPERTY: 0x%x sizebytes:%zu\n", prop->type, prop->sizebytes);
if (thneed->debug >= 2) {
hexdump((uint8_t *)prop->value, prop->sizebytes);
if (prop->type == KGSL_PROP_PWR_CONSTRAINT) {
struct kgsl_device_constraint *constraint = (struct kgsl_device_constraint *)prop->value;
hexdump((uint8_t *)constraint->data, constraint->size);
}
}
}
} else if (request == IOCTL_KGSL_DRAWCTXT_CREATE || request == IOCTL_KGSL_DRAWCTXT_DESTROY) {
// this happens
} else if (request == IOCTL_KGSL_GPUOBJ_ALLOC || request == IOCTL_KGSL_GPUOBJ_FREE) {
// this happens
} else {
if (thneed->debug >= 1) {
printf("other ioctl %lx\n", request);
}
}
}
int ret = my_ioctl(filedes, request, argp);
// NOTE: This error message goes into stdout and messes up pyenv
// if (ret != 0) printf("ioctl returned %d with errno %d\n", ret, errno);
return ret;
}
}
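The block above is standard symbol interposition: because this ioctl is linked into the same binary, it shadows libc's, and dlsym(RTLD_NEXT, "ioctl") recovers the real implementation for pass-through. The same pattern for any other libc call looks like this (illustrative sketch, not part of the diff; dlfcn.h is already included above):
extern "C" int close(int fd) {
  static auto real_close = reinterpret_cast<int (*)(int)>(dlsym(RTLD_NEXT, "close"));
  // observe or record fd here, then pass through to the real close()
  return real_close(fd);
}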
// *********** GPUMalloc ***********
GPUMalloc::GPUMalloc(int size, int fd) {
struct kgsl_gpuobj_alloc alloc;
memset(&alloc, 0, sizeof(alloc));
alloc.size = size;
alloc.flags = 0x10000a00;
ioctl(fd, IOCTL_KGSL_GPUOBJ_ALLOC, &alloc);
void *addr = mmap64(NULL, alloc.mmapsize, 0x3, 0x1, fd, alloc.id*0x1000);  // 0x3 = PROT_READ|PROT_WRITE, 0x1 = MAP_SHARED
assert(addr != MAP_FAILED);
base = (uint64_t)addr;
remaining = size;
}
GPUMalloc::~GPUMalloc() {
// TODO: free the GPU malloced area
}
void *GPUMalloc::alloc(int size) {
void *ret = (void*)base;
size = (size+0xff) & (~0xFF);
assert(size <= remaining);
remaining -= size;
base += size;
return ret;
}
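GPUMalloc is a simple bump allocator over a single GPU mapping: (size+0xff) & (~0xFF) rounds every request up to a 256-byte boundary and base only moves forward; nothing is freed until the object is destroyed. Illustrative use (values made up):
GPUMalloc ram(0x80000, fd);  // one 512 KiB slab, as the Thneed constructor below allocates
void *a = ram.alloc(1);      // consumes 0x100 bytes after rounding
void *b = ram.alloc(0x180);  // consumes 0x200 bytes; b == (char *)a + 0x100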
// *********** CachedSync, at the ioctl layer ***********
void CachedSync::exec() {
struct kgsl_gpuobj_sync cmd;
cmd.objs = (uint64_t)data.data();
cmd.obj_len = data.length();
cmd.count = data.length() / sizeof(struct kgsl_gpuobj_sync_obj);
int ret = ioctl(thneed->fd, IOCTL_KGSL_GPUOBJ_SYNC, &cmd);
assert(ret == 0);
}
// *********** CachedCommand, at the ioctl layer ***********
CachedCommand::CachedCommand(Thneed *lthneed, struct kgsl_gpu_command *cmd) {
thneed = lthneed;
assert(cmd->numsyncs == 0);
memcpy(&cache, cmd, sizeof(cache));
if (cmd->numcmds > 0) {
cmds = make_unique<struct kgsl_command_object[]>(cmd->numcmds);
memcpy(cmds.get(), (void *)cmd->cmdlist, sizeof(struct kgsl_command_object)*cmd->numcmds);
cache.cmdlist = (uint64_t)cmds.get();
for (int i = 0; i < cmd->numcmds; i++) {
void *nn = thneed->ram->alloc(cmds[i].size);
memcpy(nn, (void*)cmds[i].gpuaddr, cmds[i].size);
cmds[i].gpuaddr = (uint64_t)nn;
}
}
if (cmd->numobjs > 0) {
objs = make_unique<struct kgsl_command_object[]>(cmd->numobjs);
memcpy(objs.get(), (void *)cmd->objlist, sizeof(struct kgsl_command_object)*cmd->numobjs);
cache.objlist = (uint64_t)objs.get();
for (int i = 0; i < cmd->numobjs; i++) {
void *nn = thneed->ram->alloc(objs[i].size);
memset(nn, 0, objs[i].size);
objs[i].gpuaddr = (uint64_t)nn;
}
}
kq = thneed->ckq;
thneed->ckq.clear();
}
void CachedCommand::exec() {
cache.timestamp = ++thneed->timestamp;
int ret = ioctl(thneed->fd, IOCTL_KGSL_GPU_COMMAND, &cache);
if (thneed->debug >= 1) printf("CachedCommand::exec got %d\n", ret);
if (thneed->debug >= 2) {
for (auto &it : kq) {
it->debug_print(false);
}
}
assert(ret == 0);
}
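Together, CachedSync and CachedCommand implement the record/replay scheme: one real inference is captured at the ioctl layer, then execute() re-issues the captured ioctls with fresh timestamps instead of going back through OpenCL. A hedged lifecycle sketch (record and stop() are assumed from thneed.h; run_model_once is a stand-in):
Thneed t(true);             // registers itself as g_thneed; the ioctl shim above starts recording
run_model_once();           // stand-in for one normal CL inference; each GPU command is cached
t.stop();                   // assumed from thneed.h: ends recording
t.execute(inputs, output);  // replays CachedCommand/CachedSync directly against the kernel driver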
// *********** Thneed ***********
Thneed::Thneed(bool do_clinit, cl_context _context) {
// TODO: QCOM2 actually requires a different context
//context = _context;
if (do_clinit) clinit();
assert(g_fd != -1);
fd = g_fd;
ram = make_unique<GPUMalloc>(0x80000, fd);
timestamp = -1;
g_thneed = this;
char *thneed_debug_env = getenv("THNEED_DEBUG");
debug = (thneed_debug_env != NULL) ? atoi(thneed_debug_env) : 0;
}
void Thneed::wait() {
struct kgsl_device_waittimestamp_ctxtid wait;
wait.context_id = context_id;
wait.timestamp = timestamp;
wait.timeout = -1;
uint64_t tb = nanos_since_boot();
int wret = ioctl(fd, IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID, &wait);
uint64_t te = nanos_since_boot();
if (debug >= 1) printf("wait %d after %lu us\n", wret, (te-tb)/1000);
}
void Thneed::execute(float **finputs, float *foutput, bool slow) {
uint64_t tb, te;
if (debug >= 1) tb = nanos_since_boot();
// ****** copy inputs
copy_inputs(finputs, true);
// ****** run commands
int i = 0;
for (auto &it : cmds) {
++i;
if (debug >= 1) printf("run %2d @ %7lu us: ", i, (nanos_since_boot()-tb)/1000);
it->exec();
if ((i == cmds.size()) || slow) wait();
}
// ****** copy outputs
copy_output(foutput);
if (debug >= 1) {
te = nanos_since_boot();
printf("model exec in %lu us\n", (te-tb)/1000);
}
}
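Unlike the PC path, replay here blocks on the kernel retire timestamp after the last command, or after every command when slow is set, which serializes the GPU and makes the per-command "run %2d @ ..." timings above meaningful. E.g.:
t.execute(inputs, output, /*slow=*/true);  // one wait() per cached command: per-command profiling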

@@ -36,7 +36,7 @@ CPU usage budget
TEST_DURATION = 25
LOG_OFFSET = 8
MAX_TOTAL_CPU = 265. # total for all 8 cores
MAX_TOTAL_CPU = 275. # total for all 8 cores
PROCS = {
# Baseline CPU usage by process
"selfdrive.controls.controlsd": 16.0,
@@ -50,8 +50,8 @@ PROCS = {
"selfdrive.locationd.paramsd": 9.0,
"./sensord": 7.0,
"selfdrive.controls.radard": 2.0,
"selfdrive.modeld.modeld": 17.0,
"selfdrive.modeld.dmonitoringmodeld": 11.0,
"selfdrive.modeld.modeld": 22.0,
"selfdrive.modeld.dmonitoringmodeld": 21.0,
"system.hardware.hardwared": 4.0,
"selfdrive.locationd.calibrationd": 2.0,
"selfdrive.locationd.torqued": 5.0,
@@ -371,10 +371,9 @@ class TestOnroad:
result += "------------------------------------------------\n"
result += "----------------- Model Timing -----------------\n"
result += "------------------------------------------------\n"
# TODO: this went up when plannerd cpu usage increased, why?
cfgs = [
("modelV2", 0.050, 0.036),
("driverStateV2", 0.050, 0.026),
("modelV2", 0.045, 0.035),
("driverStateV2", 0.045, 0.035),
]
for (s, instant_max, avg_max) in cfgs:
ts = [getattr(m, s).modelExecutionTime for m in self.msgs[s]]

@@ -33,7 +33,7 @@ class Proc:
PROCS = [
Proc(['camerad'], 1.75, msgs=['roadCameraState', 'wideRoadCameraState', 'driverCameraState']),
Proc(['modeld'], 1.12, atol=0.2, msgs=['modelV2']),
Proc(['dmonitoringmodeld'], 0.65, msgs=['driverStateV2']),
Proc(['dmonitoringmodeld'], 0.6, msgs=['driverStateV2']),
Proc(['encoderd'], 0.23, msgs=[]),
]

@@ -70,10 +70,12 @@ procs = [
PythonProcess("micd", "system.micd", iscar),
PythonProcess("timed", "system.timed", always_run, enabled=not PC),
# TODO Make python process once TG allows opening QCOM from child proc
NativeProcess("dmonitoringmodeld", "selfdrive/modeld", ["./dmonitoringmodeld"], driverview, enabled=(not PC or WEBCAM)),
NativeProcess("encoderd", "system/loggerd", ["./encoderd"], only_onroad),
NativeProcess("stream_encoderd", "system/loggerd", ["./encoderd", "--stream"], notcar),
NativeProcess("loggerd", "system/loggerd", ["./loggerd"], logging),
# TODO Make python process once TG allows opening QCOM from child proc
NativeProcess("modeld", "selfdrive/modeld", ["./modeld"], only_onroad),
NativeProcess("sensord", "system/sensord", ["./sensord"], only_onroad, enabled=not PC),
NativeProcess("ui", "selfdrive/ui", ["./ui"], always_run, watchdog_max_dt=(5 if not PC else None)),

@@ -1 +1 @@
Subproject commit 9dda6d260db0255750bacff61e3cee1e580567e1
Subproject commit 480e5e7a1292bf2f84e18edffd06a985c4b48e65

@@ -3,10 +3,10 @@ requires-python = ">=3.11, <=3.12"
resolution-markers = [
"python_full_version < '3.12' and platform_system == 'Darwin'",
"python_full_version < '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
"(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
"(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
"(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
"(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
[[package]]
@@ -501,7 +501,7 @@ name = "ewmhlib"
version = "0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "python-xlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "python-xlib", marker = "sys_platform == 'linux'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -666,6 +666,9 @@ wheels = [
name = "humanfriendly"
version = "10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyreadline3", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 },
@@ -1231,27 +1234,16 @@ dependencies = [
{ name = "sympy" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/95/8d/2634e2959b34aa8a0037989f4229e9abcfa484e9c228f99633b3241768a6/onnxruntime-1.20.1-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:06bfbf02ca9ab5f28946e0f912a562a5f005301d0c419283dc57b3ed7969bb7b", size = 30998725 },
{ url = "https://files.pythonhosted.org/packages/a5/da/c44bf9bd66cd6d9018a921f053f28d819445c4d84b4dd4777271b0fe52a2/onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6243e34d74423bdd1edf0ae9596dd61023b260f546ee17d701723915f06a9f7", size = 11955227 },
{ url = "https://files.pythonhosted.org/packages/11/ac/4120dfb74c8e45cce1c664fc7f7ce010edd587ba67ac41489f7432eb9381/onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5eec64c0269dcdb8d9a9a53dc4d64f87b9e0c19801d9321246a53b7eb5a7d1bc", size = 13331703 },
{ url = "https://files.pythonhosted.org/packages/12/f1/cefacac137f7bb7bfba57c50c478150fcd3c54aca72762ac2c05ce0532c1/onnxruntime-1.20.1-cp311-cp311-win32.whl", hash = "sha256:a19bc6e8c70e2485a1725b3d517a2319603acc14c1f1a017dda0afe6d4665b41", size = 9813977 },
{ url = "https://files.pythonhosted.org/packages/2c/2d/2d4d202c0bcfb3a4cc2b171abb9328672d7f91d7af9ea52572722c6d8d96/onnxruntime-1.20.1-cp311-cp311-win_amd64.whl", hash = "sha256:8508887eb1c5f9537a4071768723ec7c30c28eb2518a00d0adcd32c89dea3221", size = 11329895 },
{ url = "https://files.pythonhosted.org/packages/e5/39/9335e0874f68f7d27103cbffc0e235e32e26759202df6085716375c078bb/onnxruntime-1.20.1-cp312-cp312-macosx_13_0_universal2.whl", hash = "sha256:22b0655e2bf4f2161d52706e31f517a0e54939dc393e92577df51808a7edc8c9", size = 31007580 },
{ url = "https://files.pythonhosted.org/packages/c5/9d/a42a84e10f1744dd27c6f2f9280cc3fb98f869dd19b7cd042e391ee2ab61/onnxruntime-1.20.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f1f56e898815963d6dc4ee1c35fc6c36506466eff6d16f3cb9848cea4e8c8172", size = 11952833 },
{ url = "https://files.pythonhosted.org/packages/47/42/2f71f5680834688a9c81becbe5c5bb996fd33eaed5c66ae0606c3b1d6a02/onnxruntime-1.20.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bb71a814f66517a65628c9e4a2bb530a6edd2cd5d87ffa0af0f6f773a027d99e", size = 13333903 },
]
[[package]]
name = "onnxruntime-gpu"
version = "1.20.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/a5/5c2287d61f359c7342e9d59d1e3dd728a982dea85f846c7af305a801c3ca/onnxruntime_gpu-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1795e8bc6f9a1488a4d51d242edc4232a5ae60ec44ab4d4b0a7c65b3d17fcbff", size = 291519550 },
{ url = "https://files.pythonhosted.org/packages/91/a8/6984a2fb070be372a866108e3e85c9eb6e8f0378a8567a66967d80befb75/onnxruntime_gpu-1.20.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1951f96cd534c6151721e552606d0d792ea6a4c3e57e2f10eed17cca8105e953", size = 291510989 },
{ url = "https://files.pythonhosted.org/packages/c8/f1/aabfdf91d013320aa2fc46cf43c88ca0182860ff15df872b4552254a9680/onnxruntime-1.20.1-cp312-cp312-win32.whl", hash = "sha256:bd386cc9ee5f686ee8a75ba74037750aca55183085bf1941da8efcfe12d5b120", size = 9814562 },
{ url = "https://files.pythonhosted.org/packages/dd/80/76979e0b744307d488c79e41051117634b956612cc731f1028eb17ee7294/onnxruntime-1.20.1-cp312-cp312-win_amd64.whl", hash = "sha256:19c2d843eb074f385e8bbb753a40df780511061a63f9def1b216bf53860223fb", size = 11331482 },
]
[[package]]
@@ -1288,8 +1280,7 @@ dependencies = [
{ name = "libusb1" },
{ name = "numpy" },
{ name = "onnx" },
{ name = "onnxruntime", marker = "platform_machine == 'aarch64' and platform_system == 'Linux'" },
{ name = "onnxruntime-gpu", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "onnxruntime" },
{ name = "psutil" },
{ name = "pyaudio" },
{ name = "pycapnp" },
@@ -1387,8 +1378,7 @@ requires-dist = [
{ name = "natsort", marker = "extra == 'docs'" },
{ name = "numpy", specifier = "<2.0.0" },
{ name = "onnx", specifier = ">=1.14.0" },
{ name = "onnxruntime", marker = "platform_machine == 'aarch64' and platform_system == 'Linux'", specifier = ">=1.16.3" },
{ name = "onnxruntime-gpu", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'", specifier = ">=1.16.3" },
{ name = "onnxruntime", specifier = ">=1.16.3" },
{ name = "parameterized", marker = "extra == 'dev'", specifier = ">=0.8,<0.9" },
{ name = "pre-commit-hooks", marker = "extra == 'testing'" },
{ name = "psutil" },
@@ -1835,10 +1825,10 @@ name = "pymonctl"
version = "0.92"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ewmhlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "pyobjc", marker = "(platform_machine != 'aarch64' and sys_platform == 'darwin') or (platform_system != 'Linux' and sys_platform == 'darwin')" },
{ name = "python-xlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "ewmhlib", marker = "sys_platform == 'linux'" },
{ name = "pyobjc", marker = "sys_platform == 'darwin'" },
{ name = "python-xlib", marker = "sys_platform == 'linux'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -4431,6 +4421,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/33/d91e003b85ff7ab227d0fff236d48c18ada2f0cd49d5e35cb514867ba609/PyQt5_sip-12.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0f83f554727f43dfe92afbf3a8c51e83bb8b78c5f160b635d4359fad681cebe", size = 57957 },
]
[[package]]
name = "pyreadline3"
version = "3.5.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0f/49/4cea918a08f02817aabae639e3d0ac046fef9f9180518a3ad394e22da148/pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7", size = 99839 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178 },
]
[[package]]
name = "pyrect"
version = "0.2.0"
@@ -4460,7 +4459,7 @@ name = "pytest"
version = "8.3.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
@@ -4649,10 +4648,10 @@ name = "pywinbox"
version = "0.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ewmhlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "pyobjc", marker = "(platform_machine != 'aarch64' and sys_platform == 'darwin') or (platform_system != 'Linux' and sys_platform == 'darwin')" },
{ name = "python-xlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "ewmhlib", marker = "sys_platform == 'linux'" },
{ name = "pyobjc", marker = "sys_platform == 'darwin'" },
{ name = "python-xlib", marker = "sys_platform == 'linux'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -4664,11 +4663,11 @@ name = "pywinctl"
version = "0.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ewmhlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "ewmhlib", marker = "sys_platform == 'linux'" },
{ name = "pymonctl" },
{ name = "pyobjc", marker = "(platform_machine != 'aarch64' and sys_platform == 'darwin') or (platform_system != 'Linux' and sys_platform == 'darwin')" },
{ name = "python-xlib", marker = "platform_system != 'Darwin' and sys_platform == 'linux'" },
{ name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "pyobjc", marker = "sys_platform == 'darwin'" },
{ name = "python-xlib", marker = "sys_platform == 'linux'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "pywinbox" },
{ name = "typing-extensions" },
]
