From 5732002296c0592fce64fe191c6c983cbe30c5ff Mon Sep 17 00:00:00 2001 From: Mitchell Goff Date: Thu, 7 Sep 2023 19:46:43 -0700 Subject: [PATCH] Rewrite dmonitoringmodeld in python (#29740) * Added dmonitoringmodeld.py * Removed dmonitoringmodeld.cc * Use ModelRunner helpers from runners/__init__.py * Fixed DriverStateResult field ordering * Some bug fixes * Set calib input * Look ma, no loop! * Bump dmonitoringmodeld cpu usage in test_onroad * Fixed memory leak caused by np.ctypes.data_as * Formatting fixes * chmod +x * remove USE_ONNX_MODEL * Realtime priority 1, formatting fixes old-commit-hash: 503fa121ee83469763e9386b9e5df7a733693e2c --- release/files_common | 5 +- selfdrive/manager/process_config.py | 2 +- selfdrive/modeld/SConscript | 15 --- selfdrive/modeld/dmonitoringmodeld | 12 -- selfdrive/modeld/dmonitoringmodeld.cc | 69 ----------- selfdrive/modeld/dmonitoringmodeld.py | 157 +++++++++++++++++++++++++ selfdrive/modeld/models/dmonitoring.cc | 133 --------------------- selfdrive/modeld/models/dmonitoring.h | 50 -------- selfdrive/modeld/runners/run.h | 4 - selfdrive/test/test_onroad.py | 2 +- 10 files changed, 160 insertions(+), 289 deletions(-) delete mode 100755 selfdrive/modeld/dmonitoringmodeld delete mode 100644 selfdrive/modeld/dmonitoringmodeld.cc create mode 100755 selfdrive/modeld/dmonitoringmodeld.py delete mode 100644 selfdrive/modeld/models/dmonitoring.cc delete mode 100644 selfdrive/modeld/models/dmonitoring.h diff --git a/release/files_common b/release/files_common index 028b233d8b..2d0b5f0514 100644 --- a/release/files_common +++ b/release/files_common @@ -362,10 +362,9 @@ selfdrive/modeld/__init__.py selfdrive/modeld/SConscript selfdrive/modeld/modeld.py selfdrive/modeld/navmodeld.py -selfdrive/modeld/dmonitoringmodeld.cc +selfdrive/modeld/dmonitoringmodeld.py selfdrive/modeld/constants.py selfdrive/modeld/modeld -selfdrive/modeld/dmonitoringmodeld selfdrive/modeld/models/__init__.py selfdrive/modeld/models/*.pxd @@ -378,8 +377,6 @@ selfdrive/modeld/models/driving.cc selfdrive/modeld/models/driving.h selfdrive/modeld/models/supercombo.onnx -selfdrive/modeld/models/dmonitoring.cc -selfdrive/modeld/models/dmonitoring.h selfdrive/modeld/models/dmonitoring_model_q.dlc selfdrive/modeld/models/navmodel_q.dlc diff --git a/selfdrive/manager/process_config.py b/selfdrive/manager/process_config.py index fd57cb9a8d..aee4748a26 100644 --- a/selfdrive/manager/process_config.py +++ b/selfdrive/manager/process_config.py @@ -52,7 +52,7 @@ procs = [ PythonProcess("micd", "system.micd", iscar), PythonProcess("timezoned", "system.timezoned", always_run, enabled=not PC), - NativeProcess("dmonitoringmodeld", "selfdrive/modeld", ["./dmonitoringmodeld"], driverview, enabled=(not PC or WEBCAM)), + PythonProcess("dmonitoringmodeld", "selfdrive.modeld.dmonitoringmodeld", driverview, enabled=(not PC or WEBCAM)), NativeProcess("encoderd", "system/loggerd", ["./encoderd"], only_onroad), NativeProcess("stream_encoderd", "system/loggerd", ["./encoderd", "--stream"], notcar), NativeProcess("loggerd", "system/loggerd", ["./loggerd"], logging), diff --git a/selfdrive/modeld/SConscript b/selfdrive/modeld/SConscript index 1c4d46795e..428fe7d0bf 100644 --- a/selfdrive/modeld/SConscript +++ b/selfdrive/modeld/SConscript @@ -32,12 +32,6 @@ if arch == "Darwin": else: libs += ['OpenCL'] -# Use onnx on PC -if arch != "larch64" and not GetOption('snpe'): - common_src += ['runners/onnxmodel.cc'] - lenv['CFLAGS'].append("-DUSE_ONNX_MODEL") - lenv['CXXFLAGS'].append("-DUSE_ONNX_MODEL") - # Set path definitions for pathdef, fn in {'TRANSFORM': 'transforms/transform.cl', 'LOADYUV': 'transforms/loadyuv.cl', 'ONNXRUNNER': 'runners/onnx_runner.py'}.items(): for xenv in (lenv, lenvCython): @@ -60,15 +54,6 @@ lenvCython.Program('runners/snpemodel_pyx.so', 'runners/snpemodel_pyx.pyx', LIBS lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LIBS=[commonmodel_lib, *cython_libs], FRAMEWORKS=frameworks) lenvCython.Program('models/driving_pyx.so', 'models/driving_pyx.pyx', LIBS=[driving_lib, commonmodel_lib, *cython_libs], FRAMEWORKS=frameworks) -# Compile binaries -lenv['FRAMEWORKS'] = frameworks -common_model = lenv.Object(common_src) - -lenv.Program('_dmonitoringmodeld', [ - "dmonitoringmodeld.cc", - "models/dmonitoring.cc", - ]+common_model, LIBS=libs + snpe_lib) - # Build thneed model if arch == "larch64" or GetOption('pc_thneed'): fn = File("models/supercombo").abspath diff --git a/selfdrive/modeld/dmonitoringmodeld b/selfdrive/modeld/dmonitoringmodeld deleted file mode 100755 index fc007470b2..0000000000 --- a/selfdrive/modeld/dmonitoringmodeld +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash - -DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)" -cd $DIR - -if [ -f /TICI ]; then - export LD_LIBRARY_PATH="/usr/lib/aarch64-linux-gnu:/data/pythonpath/third_party/snpe/larch64:$LD_LIBRARY_PATH" - export ADSP_LIBRARY_PATH="/data/pythonpath/third_party/snpe/dsp/" -else - export LD_LIBRARY_PATH="$DIR/../../third_party/snpe/x86_64-linux-clang:$DIR/../../openpilot/third_party/snpe/x86_64:$LD_LIBRARY_PATH" -fi -exec ./_dmonitoringmodeld diff --git a/selfdrive/modeld/dmonitoringmodeld.cc b/selfdrive/modeld/dmonitoringmodeld.cc deleted file mode 100644 index 2890b44fe0..0000000000 --- a/selfdrive/modeld/dmonitoringmodeld.cc +++ /dev/null @@ -1,69 +0,0 @@ -#include -#include - -#include -#include - -#include "cereal/visionipc/visionipc_client.h" -#include "common/params.h" -#include "common/swaglog.h" -#include "common/util.h" -#include "selfdrive/modeld/models/dmonitoring.h" - -ExitHandler do_exit; - -void run_model(DMonitoringModelState &model, VisionIpcClient &vipc_client) { - PubMaster pm({"driverStateV2"}); - SubMaster sm({"liveCalibration"}); - float calib[CALIB_LEN] = {0}; - // double last = 0; - - while (!do_exit) { - VisionIpcBufExtra extra = {}; - VisionBuf *buf = vipc_client.recv(&extra); - if (buf == nullptr) continue; - - sm.update(0); - if (sm.updated("liveCalibration")) { - auto calib_msg = sm["liveCalibration"].getLiveCalibration().getRpyCalib(); - for (int i = 0; i < CALIB_LEN; i++) { - calib[i] = calib_msg[i]; - } - } - - double t1 = millis_since_boot(); - DMonitoringModelResult model_res = dmonitoring_eval_frame(&model, buf->addr, buf->width, buf->height, buf->stride, buf->uv_offset, calib); - double t2 = millis_since_boot(); - - // send dm packet - dmonitoring_publish(pm, extra.frame_id, model_res, (t2 - t1) / 1000.0, model.output); - - // printf("dmonitoring process: %.2fms, from last %.2fms\n", t2 - t1, t1 - last); - // last = t1; - } -} - -int main(int argc, char **argv) { - setpriority(PRIO_PROCESS, 0, -15); - - // init the models - DMonitoringModelState model; - dmonitoring_init(&model); - - Params().putBool("DmModelInitialized", true); - - LOGW("connecting to driver stream"); - VisionIpcClient vipc_client = VisionIpcClient("camerad", VISION_STREAM_DRIVER, true); - while (!do_exit && !vipc_client.connect(false)) { - util::sleep_for(100); - } - - // run the models - if (vipc_client.connected) { - LOGW("connected with buffer size: %zu", vipc_client.buffers[0].len); - run_model(model, vipc_client); - } - - dmonitoring_free(&model); - return 0; -} diff --git a/selfdrive/modeld/dmonitoringmodeld.py b/selfdrive/modeld/dmonitoringmodeld.py new file mode 100755 index 0000000000..0f9669258c --- /dev/null +++ b/selfdrive/modeld/dmonitoringmodeld.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +import os +import gc +import math +import time +import ctypes +import numpy as np +from pathlib import Path +from typing import Tuple, Dict + +from cereal import messaging +from cereal.messaging import PubMaster, SubMaster +from cereal.visionipc import VisionIpcClient, VisionStreamType, VisionBuf +from openpilot.system.swaglog import cloudlog +from openpilot.common.params import Params +from openpilot.common.realtime import set_realtime_priority +from openpilot.selfdrive.modeld.runners import ModelRunner, Runtime +from openpilot.selfdrive.modeld.models.commonmodel_pyx import sigmoid + +CALIB_LEN = 3 +REG_SCALE = 0.25 +MODEL_WIDTH = 1440 +MODEL_HEIGHT = 960 +OUTPUT_SIZE = 84 +SEND_RAW_PRED = os.getenv('SEND_RAW_PRED') +MODEL_PATHS = { + ModelRunner.SNPE: Path(__file__).parent / 'models/dmonitoring_model_q.dlc', + ModelRunner.ONNX: Path(__file__).parent / 'models/dmonitoring_model.onnx'} + +class DriverStateResult(ctypes.Structure): + _fields_ = [ + ("face_orientation", ctypes.c_float*3), + ("face_position", ctypes.c_float*3), + ("face_orientation_std", ctypes.c_float*3), + ("face_position_std", ctypes.c_float*3), + ("face_prob", ctypes.c_float), + ("_unused_a", ctypes.c_float*8), + ("left_eye_prob", ctypes.c_float), + ("_unused_b", ctypes.c_float*8), + ("right_eye_prob", ctypes.c_float), + ("left_blink_prob", ctypes.c_float), + ("right_blink_prob", ctypes.c_float), + ("sunglasses_prob", ctypes.c_float), + ("occluded_prob", ctypes.c_float), + ("ready_prob", ctypes.c_float*4), + ("not_ready_prob", ctypes.c_float*2)] + +class DMonitoringModelResult(ctypes.Structure): + _fields_ = [ + ("driver_state_lhd", DriverStateResult), + ("driver_state_rhd", DriverStateResult), + ("poor_vision_prob", ctypes.c_float), + ("wheel_on_right_prob", ctypes.c_float)] + +class ModelState: + inputs: Dict[str, np.ndarray] + output: np.ndarray + model: ModelRunner + + def __init__(self): + assert ctypes.sizeof(DMonitoringModelResult) == OUTPUT_SIZE * ctypes.sizeof(ctypes.c_float) + self.output = np.zeros(OUTPUT_SIZE, dtype=np.float32) + self.inputs = { + 'input_imgs': np.zeros(MODEL_HEIGHT * MODEL_WIDTH, dtype=np.uint8), + 'calib': np.zeros(CALIB_LEN, dtype=np.float32)} + + self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.DSP, True, None) + self.model.addInput("input_imgs", None) + self.model.addInput("calib", self.inputs['calib']) + + def run(self, buf:VisionBuf, calib:np.ndarray) -> Tuple[np.ndarray, float]: + self.inputs['calib'][:] = calib + + v_offset = buf.height - MODEL_HEIGHT + h_offset = (buf.width - MODEL_WIDTH) // 2 + buf_data = buf.data.reshape(-1, buf.stride) + input_data = self.inputs['input_imgs'].reshape(MODEL_HEIGHT, MODEL_WIDTH) + input_data[:] = buf_data[v_offset:v_offset+MODEL_HEIGHT, h_offset:h_offset+MODEL_WIDTH] + + t1 = time.perf_counter() + self.model.setInputBuffer("input_imgs", self.inputs['input_imgs'].view(np.float32)) + self.model.execute() + t2 = time.perf_counter() + return self.output, t2 - t1 + + +def fill_driver_state(msg, ds_result: DriverStateResult): + msg.faceOrientation = [x * REG_SCALE for x in ds_result.face_orientation] + msg.faceOrientationStd = [math.exp(x) for x in ds_result.face_orientation_std] + msg.facePosition = [x * REG_SCALE for x in ds_result.face_position[:2]] + msg.facePositionStd = [math.exp(x) for x in ds_result.face_position_std[:2]] + msg.faceProb = sigmoid(ds_result.face_prob) + msg.leftEyeProb = sigmoid(ds_result.left_eye_prob) + msg.rightEyeProb = sigmoid(ds_result.right_eye_prob) + msg.leftBlinkProb = sigmoid(ds_result.left_blink_prob) + msg.rightBlinkProb = sigmoid(ds_result.right_blink_prob) + msg.sunglassesProb = sigmoid(ds_result.sunglasses_prob) + msg.occludedProb = sigmoid(ds_result.occluded_prob) + msg.readyProb = [sigmoid(x) for x in ds_result.ready_prob] + msg.notReadyProb = [sigmoid(x) for x in ds_result.not_ready_prob] + +def get_driverstate_packet(model_output: np.ndarray, frame_id: int, location_ts: int, execution_time: float, dsp_execution_time: float): + model_result = ctypes.cast(model_output.ctypes.data, ctypes.POINTER(DMonitoringModelResult)).contents + msg = messaging.new_message('driverStateV2') + ds = msg.driverStateV2 + ds.frameId = frame_id + ds.modelExecutionTime = execution_time + ds.dspExecutionTime = dsp_execution_time + ds.poorVisionProb = sigmoid(model_result.poor_vision_prob) + ds.wheelOnRightProb = sigmoid(model_result.wheel_on_right_prob) + ds.rawPredictions = model_output.tobytes() if SEND_RAW_PRED else b'' + fill_driver_state(ds.leftDriverData, model_result.driver_state_lhd) + fill_driver_state(ds.rightDriverData, model_result.driver_state_rhd) + return msg + + +def main(): + gc.disable() + set_realtime_priority(1) + + model = ModelState() + cloudlog.warning("models loaded, dmonitoringmodeld starting") + Params().put_bool("DmModelInitialized", True) + + cloudlog.warning("connecting to driver stream") + vipc_client = VisionIpcClient("camerad", VisionStreamType.VISION_STREAM_DRIVER, True) + while not vipc_client.connect(False): + time.sleep(0.1) + assert vipc_client.is_connected() + cloudlog.warning(f"connected with buffer size: {vipc_client.buffer_len}") + + sm = SubMaster(["liveCalibration"]) + pm = PubMaster(["driverStateV2"]) + + calib = np.zeros(CALIB_LEN, dtype=np.float32) + # last = 0 + + while True: + buf = vipc_client.recv() + if buf is None: + continue + + sm.update(0) + if sm.updated["liveCalibration"]: + calib[:] = np.array(sm["liveCalibration"].rpyCalib) + + t1 = time.perf_counter() + model_output, dsp_execution_time = model.run(buf, calib) + t2 = time.perf_counter() + + pm.send("driverStateV2", get_driverstate_packet(model_output, vipc_client.frame_id, vipc_client.timestamp_sof, t2 - t1, dsp_execution_time)) + # print("dmonitoring process: %.2fms, from last %.2fms\n" % (t2 - t1, t1 - last)) + # last = t1 + + +if __name__ == "__main__": + main() diff --git a/selfdrive/modeld/models/dmonitoring.cc b/selfdrive/modeld/models/dmonitoring.cc deleted file mode 100644 index eb5239bef0..0000000000 --- a/selfdrive/modeld/models/dmonitoring.cc +++ /dev/null @@ -1,133 +0,0 @@ -#include - -#include "common/mat.h" -#include "common/modeldata.h" -#include "common/params.h" -#include "common/timing.h" -#include "system/hardware/hw.h" - -#include "selfdrive/modeld/models/dmonitoring.h" - -constexpr int MODEL_WIDTH = 1440; -constexpr int MODEL_HEIGHT = 960; - -template -static inline T *get_buffer(std::vector &buf, const size_t size) { - if (buf.size() < size) buf.resize(size); - return buf.data(); -} - -void dmonitoring_init(DMonitoringModelState* s) { - -#ifdef USE_ONNX_MODEL - s->m = new ONNXModel("models/dmonitoring_model.onnx", &s->output[0], OUTPUT_SIZE, USE_DSP_RUNTIME, true); -#else - s->m = new SNPEModel("models/dmonitoring_model_q.dlc", &s->output[0], OUTPUT_SIZE, USE_DSP_RUNTIME, true); -#endif - - s->m->addInput("input_imgs", NULL, 0); - s->m->addInput("calib", s->calib, CALIB_LEN); -} - -void parse_driver_data(DriverStateResult &ds_res, const DMonitoringModelState* s, int out_idx_offset) { - for (int i = 0; i < 3; ++i) { - ds_res.face_orientation[i] = s->output[out_idx_offset+i] * REG_SCALE; - ds_res.face_orientation_std[i] = exp(s->output[out_idx_offset+6+i]); - } - for (int i = 0; i < 2; ++i) { - ds_res.face_position[i] = s->output[out_idx_offset+3+i] * REG_SCALE; - ds_res.face_position_std[i] = exp(s->output[out_idx_offset+9+i]); - } - for (int i = 0; i < 4; ++i) { - ds_res.ready_prob[i] = sigmoid(s->output[out_idx_offset+35+i]); - } - for (int i = 0; i < 2; ++i) { - ds_res.not_ready_prob[i] = sigmoid(s->output[out_idx_offset+39+i]); - } - ds_res.face_prob = sigmoid(s->output[out_idx_offset+12]); - ds_res.left_eye_prob = sigmoid(s->output[out_idx_offset+21]); - ds_res.right_eye_prob = sigmoid(s->output[out_idx_offset+30]); - ds_res.left_blink_prob = sigmoid(s->output[out_idx_offset+31]); - ds_res.right_blink_prob = sigmoid(s->output[out_idx_offset+32]); - ds_res.sunglasses_prob = sigmoid(s->output[out_idx_offset+33]); - ds_res.occluded_prob = sigmoid(s->output[out_idx_offset+34]); -} - -void fill_driver_data(cereal::DriverStateV2::DriverData::Builder ddata, const DriverStateResult &ds_res) { - ddata.setFaceOrientation(ds_res.face_orientation); - ddata.setFaceOrientationStd(ds_res.face_orientation_std); - ddata.setFacePosition(ds_res.face_position); - ddata.setFacePositionStd(ds_res.face_position_std); - ddata.setFaceProb(ds_res.face_prob); - ddata.setLeftEyeProb(ds_res.left_eye_prob); - ddata.setRightEyeProb(ds_res.right_eye_prob); - ddata.setLeftBlinkProb(ds_res.left_blink_prob); - ddata.setRightBlinkProb(ds_res.right_blink_prob); - ddata.setSunglassesProb(ds_res.sunglasses_prob); - ddata.setOccludedProb(ds_res.occluded_prob); - ddata.setReadyProb(ds_res.ready_prob); - ddata.setNotReadyProb(ds_res.not_ready_prob); -} - -DMonitoringModelResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height, int stride, int uv_offset, float *calib) { - int v_off = height - MODEL_HEIGHT; - int h_off = (width - MODEL_WIDTH) / 2; - int yuv_buf_len = MODEL_WIDTH * MODEL_HEIGHT; - - uint8_t *raw_buf = (uint8_t *) stream_buf; - // vertical crop free - uint8_t *raw_y_start = raw_buf + stride * v_off; - - uint8_t *net_input_buf = get_buffer(s->net_input_buf, yuv_buf_len); - - // here makes a uint8 copy - for (int r = 0; r < MODEL_HEIGHT; ++r) { - memcpy(net_input_buf + r * MODEL_WIDTH, raw_y_start + r * stride + h_off, MODEL_WIDTH); - } - - // printf("preprocess completed. %d \n", yuv_buf_len); - // FILE *dump_yuv_file = fopen("/tmp/rawdump.yuv", "wb"); - // fwrite(net_input_buf, yuv_buf_len, sizeof(uint8_t), dump_yuv_file); - // fclose(dump_yuv_file); - - double t1 = millis_since_boot(); - s->m->setInputBuffer("input_imgs", (float*)net_input_buf, yuv_buf_len / sizeof(float)); - for (int i = 0; i < CALIB_LEN; i++) { - s->calib[i] = calib[i]; - } - s->m->execute(); - double t2 = millis_since_boot(); - - DMonitoringModelResult model_res = {0}; - parse_driver_data(model_res.driver_state_lhd, s, 0); - parse_driver_data(model_res.driver_state_rhd, s, 41); - model_res.poor_vision_prob = sigmoid(s->output[82]); - model_res.wheel_on_right_prob = sigmoid(s->output[83]); - model_res.dsp_execution_time = (t2 - t1) / 1000.; - - return model_res; -} - -void dmonitoring_publish(PubMaster &pm, uint32_t frame_id, const DMonitoringModelResult &model_res, float execution_time, kj::ArrayPtr raw_pred) { - // make msg - MessageBuilder msg; - auto framed = msg.initEvent().initDriverStateV2(); - framed.setFrameId(frame_id); - framed.setModelExecutionTime(execution_time); - framed.setDspExecutionTime(model_res.dsp_execution_time); - - framed.setPoorVisionProb(model_res.poor_vision_prob); - framed.setWheelOnRightProb(model_res.wheel_on_right_prob); - fill_driver_data(framed.initLeftDriverData(), model_res.driver_state_lhd); - fill_driver_data(framed.initRightDriverData(), model_res.driver_state_rhd); - - if (send_raw_pred) { - framed.setRawPredictions(raw_pred.asBytes()); - } - - pm.send("driverStateV2", msg); -} - -void dmonitoring_free(DMonitoringModelState* s) { - delete s->m; -} diff --git a/selfdrive/modeld/models/dmonitoring.h b/selfdrive/modeld/models/dmonitoring.h deleted file mode 100644 index ae2bf05394..0000000000 --- a/selfdrive/modeld/models/dmonitoring.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include - -#include "cereal/messaging/messaging.h" -#include "common/util.h" -#include "selfdrive/modeld/models/commonmodel.h" -#include "selfdrive/modeld/runners/run.h" - -#define CALIB_LEN 3 - -#define OUTPUT_SIZE 84 -#define REG_SCALE 0.25f - -typedef struct DriverStateResult { - float face_orientation[3]; - float face_orientation_std[3]; - float face_position[2]; - float face_position_std[2]; - float face_prob; - float left_eye_prob; - float right_eye_prob; - float left_blink_prob; - float right_blink_prob; - float sunglasses_prob; - float occluded_prob; - float ready_prob[4]; - float not_ready_prob[2]; -} DriverStateResult; - -typedef struct DMonitoringModelResult { - DriverStateResult driver_state_lhd; - DriverStateResult driver_state_rhd; - float poor_vision_prob; - float wheel_on_right_prob; - float dsp_execution_time; -} DMonitoringModelResult; - -typedef struct DMonitoringModelState { - RunModel *m; - float output[OUTPUT_SIZE]; - std::vector net_input_buf; - float calib[CALIB_LEN]; -} DMonitoringModelState; - -void dmonitoring_init(DMonitoringModelState* s); -DMonitoringModelResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height, int stride, int uv_offset, float *calib); -void dmonitoring_publish(PubMaster &pm, uint32_t frame_id, const DMonitoringModelResult &model_res, float execution_time, kj::ArrayPtr raw_pred); -void dmonitoring_free(DMonitoringModelState* s); - diff --git a/selfdrive/modeld/runners/run.h b/selfdrive/modeld/runners/run.h index cae2f5b27a..36ad262a5b 100644 --- a/selfdrive/modeld/runners/run.h +++ b/selfdrive/modeld/runners/run.h @@ -2,7 +2,3 @@ #include "selfdrive/modeld/runners/runmodel.h" #include "selfdrive/modeld/runners/snpemodel.h" - -#if defined(USE_ONNX_MODEL) -#include "selfdrive/modeld/runners/onnxmodel.h" -#endif diff --git a/selfdrive/test/test_onroad.py b/selfdrive/test/test_onroad.py index 9fa48b43d8..553ef1afa6 100755 --- a/selfdrive/test/test_onroad.py +++ b/selfdrive/test/test_onroad.py @@ -37,7 +37,7 @@ PROCS = { "./_sensord": 7.0, "selfdrive.controls.radard": 4.5, "selfdrive.modeld.modeld": 8.0, - "./_dmonitoringmodeld": 5.0, + "selfdrive.modeld.dmonitoringmodeld": 8.0, "selfdrive.modeld.navmodeld": 1.0, "selfdrive.thermald.thermald": 3.87, "selfdrive.locationd.calibrationd": 2.0,