Rewrite modeld in python (#29230)

* Added modeld.py (WIP)

* No more VisionIpcBufExtra

* Started work on cython bindings for runmodel

* Got ONNXModel cython bindings mostly working, added ModelFrame bindings

* Got modeld main loop running without model eval

* Move everything into ModelState

* Doesn't crash!

* Moved ModelState into modeld.py

* Added driving_pyx

* Added cython bindings for message generation

* Moved CLContext definition to visionipc.pxd

* *facepalm*

* Move cl_pyx into commonmodel_pyx

* Split out ONNXModel into a subclass of RunModel

* Added snpemodel/thneedmodel bindings

* Removed modeld.cc

* Fixed scons for macOS

* Fixed sconscript

* Added flag for thneedmodel

* paths are now relative to openpilot root dir

* Set cl kernel paths in SConscript

* Set LD_PRELOAD=libthneed.so to fix ioctl interception

* Run from root dir

* A few more fixes

* A few more minor fixes

* Use C update_calibration for now to exactly match refs

* Add nav_instructions input

* Link driving_pyx.pyx with transformations

* Checked python FirstOrderFilter against C++ FirstOrderFilter

* Set process name to fix test_onroad

* Revert changes to onnxmodel.cc

* Fixed bad onnx_runner.py path in onnxmodel.cc

* Import all constants from driving.h

* logging -> cloudlog

* pylint import-error suppressions no longer needed?

* Loop in SConscript

* Added parens

* Bump modeld cpu usage in test_onroad

* Get rid of use_nav

* use config_realtime_process

* error message from ioctl sniffer was messing up pyenv

* cast distance_idx to int

* Removed cloudlog.infos in model.run

* Fixed rebase conflicts

* Clean up driving.pxd/pyx

* Fixed linter error
old-commit-hash: 72a3c987c0
beeps
Mitchell Goff 2 years ago committed by GitHub
parent 4db56c1247
commit a3fbbb26ac
  1. 2
      release/files_common
  2. 21
      selfdrive/modeld/SConscript
  3. 10
      selfdrive/modeld/modeld
  4. 271
      selfdrive/modeld/modeld.cc
  5. 279
      selfdrive/modeld/modeld.py
  6. 201
      selfdrive/modeld/models/driving.cc
  7. 57
      selfdrive/modeld/models/driving.h
  8. 29
      selfdrive/modeld/models/driving.pxd
  9. 60
      selfdrive/modeld/models/driving_pyx.pyx
  10. 2
      selfdrive/test/test_onroad.py

@ -357,7 +357,7 @@ selfdrive/manager/test/test_manager.py
selfdrive/modeld/.gitignore selfdrive/modeld/.gitignore
selfdrive/modeld/__init__.py selfdrive/modeld/__init__.py
selfdrive/modeld/SConscript selfdrive/modeld/SConscript
selfdrive/modeld/modeld.cc selfdrive/modeld/modeld.py
selfdrive/modeld/navmodeld.cc selfdrive/modeld/navmodeld.cc
selfdrive/modeld/dmonitoringmodeld.cc selfdrive/modeld/dmonitoringmodeld.cc
selfdrive/modeld/constants.py selfdrive/modeld/constants.py

@ -17,7 +17,6 @@ common_src = [
thneed_src_common = [ thneed_src_common = [
"thneed/thneed_common.cc", "thneed/thneed_common.cc",
"thneed/serialize.cc", "thneed/serialize.cc",
"runners/thneedmodel.cc",
] ]
thneed_src_qcom = thneed_src_common + ["thneed/thneed_qcom2.cc"] thneed_src_qcom = thneed_src_common + ["thneed/thneed_qcom2.cc"]
@ -28,10 +27,6 @@ use_thneed = not GetOption('no_thneed')
if arch == "larch64": if arch == "larch64":
libs += ['gsl', 'CB', 'pthread', 'dl'] libs += ['gsl', 'CB', 'pthread', 'dl']
if use_thneed:
common_src += thneed_src_qcom
lenv['CXXFLAGS'].append("-DUSE_THNEED")
else: else:
libs += ['pthread'] libs += ['pthread']
@ -71,11 +66,13 @@ else:
onnxmodel_lib = lenv.Library('onnxmodel', ['runners/onnxmodel.cc']) onnxmodel_lib = lenv.Library('onnxmodel', ['runners/onnxmodel.cc'])
snpemodel_lib = lenv.Library('snpemodel', ['runners/snpemodel.cc']) snpemodel_lib = lenv.Library('snpemodel', ['runners/snpemodel.cc'])
commonmodel_lib = lenv.Library('commonmodel', common_src) commonmodel_lib = lenv.Library('commonmodel', common_src)
driving_lib = lenv.Library('driving', ['models/driving.cc'])
lenvCython.Program('runners/runmodel_pyx.so', 'runners/runmodel_pyx.pyx', LIBS=common_libs, FRAMEWORKS=common_frameworks) lenvCython.Program('runners/runmodel_pyx.so', 'runners/runmodel_pyx.pyx', LIBS=common_libs, FRAMEWORKS=common_frameworks)
lenvCython.Program('runners/onnxmodel_pyx.so', 'runners/onnxmodel_pyx.pyx', LIBS=[onnxmodel_lib, *common_libs], FRAMEWORKS=common_frameworks) lenvCython.Program('runners/onnxmodel_pyx.so', 'runners/onnxmodel_pyx.pyx', LIBS=[onnxmodel_lib, *common_libs], FRAMEWORKS=common_frameworks)
lenvCython.Program('runners/snpemodel_pyx.so', 'runners/snpemodel_pyx.pyx', LIBS=[snpemodel_lib, *common_libs], FRAMEWORKS=common_frameworks) lenvCython.Program('runners/snpemodel_pyx.so', 'runners/snpemodel_pyx.pyx', LIBS=[snpemodel_lib, *common_libs], FRAMEWORKS=common_frameworks)
lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LIBS=[commonmodel_lib, *common_libs], FRAMEWORKS=common_frameworks) lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LIBS=[commonmodel_lib, *common_libs], FRAMEWORKS=common_frameworks)
lenvCython.Program('models/driving_pyx.so', 'models/driving_pyx.pyx', LIBS=[driving_lib, commonmodel_lib, cereal, messaging, *common_libs, 'capnp', 'kj'] + transformations, FRAMEWORKS=common_frameworks)
common_model = lenv.Object(common_src) common_model = lenv.Object(common_src)
@ -102,14 +99,6 @@ if (use_thneed and arch == "larch64") or GetOption('pc_thneed'):
tinygrad_files = sum([lenv.Glob("#"+x) for x in open(File("#release/files_common").abspath).read().split("\n") if x.startswith("tinygrad_repo/")], []) tinygrad_files = sum([lenv.Glob("#"+x) for x in open(File("#release/files_common").abspath).read().split("\n") if x.startswith("tinygrad_repo/")], [])
lenv.Command(fn + ".thneed", [fn + ".onnx"] + tinygrad_files, cmd) lenv.Command(fn + ".thneed", [fn + ".onnx"] + tinygrad_files, cmd)
llenv = lenv.Clone() thneed_lib = env.SharedLibrary('thneed', thneed_src, LIBS=[gpucommon, common, 'zmq', 'OpenCL', 'dl'])
if GetOption('pc_thneed'): thneedmodel_lib = env.Library('thneedmodel', ['runners/thneedmodel.cc'])
llenv['CFLAGS'].append("-DUSE_THNEED") lenvCython.Program('runners/thneedmodel_pyx.so', 'runners/thneedmodel_pyx.pyx', LIBS=envCython["LIBS"]+[thneedmodel_lib, thneed_lib, gpucommon, common, 'dl', 'zmq', 'OpenCL'])
llenv['CXXFLAGS'].append("-DUSE_THNEED")
common_model += llenv.Object(thneed_src_pc)
libs += ['dl']
llenv.Program('_modeld', [
"modeld.cc",
"models/driving.cc",
]+common_model, LIBS=libs + transformations)

@ -1,11 +1,7 @@
#!/bin/sh #!/bin/sh
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)" DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
cd $DIR cd "$DIR/../../"
if [ -f /TICI ]; then export LD_PRELOAD="$DIR/libthneed.so"
export LD_LIBRARY_PATH="/usr/lib/aarch64-linux-gnu:/data/pythonpath/third_party/snpe/larch64:$LD_LIBRARY_PATH" exec "$DIR/modeld.py"
else
export LD_LIBRARY_PATH="$DIR/../../third_party/snpe/x86_64-linux-clang:$DIR/../../openpilot/third_party/snpe/x86_64:$LD_LIBRARY_PATH"
fi
exec ./_modeld

@ -1,271 +0,0 @@
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <cmath>
#include <eigen3/Eigen/Dense>
#include "cereal/messaging/messaging.h"
#include "common/transformations/orientation.hpp"
#include "cereal/visionipc/visionipc_client.h"
#include "common/clutil.h"
#include "common/params.h"
#include "common/swaglog.h"
#include "common/util.h"
#include "system/hardware/hw.h"
#include "selfdrive/modeld/models/driving.h"
#include "selfdrive/modeld/models/nav.h"
ExitHandler do_exit;
// Builds the 3x3 warp matrix used to map a model input frame onto a camera image.
// The result is cam_intrinsics * view_from_device * device_from_calib * calib_from_model,
// copied row by row into a mat3.
//   device_from_calib_euler - euler angles passed to euler2rot() to get the device_from_calib rotation
//   wide_camera             - true selects ECAM_INTRINSIC_MATRIX, false selects FCAM_INTRINSIC_MATRIX
//   bigmodel_frame          - true selects calib_from_sbigmodel, false selects calib_from_medmodel
mat3 update_calibration(Eigen::Vector3d device_from_calib_euler, bool wide_camera, bool bigmodel_frame) {
// Python snippet recording how the hard-coded calib_from_medmodel constants below were obtained:
/*
import numpy as np
from openpilot.common.transformations.model import medmodel_frame_from_calib_frame
medmodel_frame_from_calib_frame = medmodel_frame_from_calib_frame[:, :3]
calib_from_smedmodel_frame = np.linalg.inv(medmodel_frame_from_calib_frame)
*/
static const auto calib_from_medmodel = (Eigen::Matrix<float, 3, 3>() <<
0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
1.09890110e-03, 0.00000000e+00, -2.81318681e-01,
-2.25466395e-20, 1.09890110e-03, -5.23076923e-02).finished();
// NOTE(review): presumably derived the same way from the sbigmodel frame — confirm
static const auto calib_from_sbigmodel = (Eigen::Matrix<float, 3, 3>() <<
0.00000000e+00, 7.31372216e-19, 1.00000000e+00,
2.19780220e-03, 4.11497335e-19, -5.62637363e-01,
-6.66298828e-20, 2.19780220e-03, -3.33626374e-01).finished();
// Axis permutation: output (x, y, z) = input (y, z, x)
static const auto view_from_device = (Eigen::Matrix<float, 3, 3>() <<
0.0, 1.0, 0.0,
0.0, 0.0, 1.0,
1.0, 0.0, 0.0).finished();
const auto cam_intrinsics = Eigen::Matrix<float, 3, 3, Eigen::RowMajor>(wide_camera ? ECAM_INTRINSIC_MATRIX.v : FCAM_INTRINSIC_MATRIX.v);
// euler2rot() result is cast to float before being multiplied with the float matrices
Eigen::Matrix<float, 3, 3, Eigen::RowMajor> device_from_calib = euler2rot(device_from_calib_euler).cast <float> ();
auto calib_from_model = bigmodel_frame ? calib_from_sbigmodel : calib_from_medmodel;
// NOTE(review): `auto` on an Eigen product deduces an expression type rather than a concrete
// matrix; the operands are all named locals/statics here — confirm this is intentional
auto camera_from_calib = cam_intrinsics * view_from_device * device_from_calib;
auto warp_matrix = camera_from_calib * calib_from_model;
mat3 transform = {};
// Flatten row-major: element i of transform.v is (row i / 3, column i % 3)
for (int i=0; i<3*3; i++) {
transform.v[i] = warp_matrix(i / 3, i % 3);
}
return transform;
}
// Main inference loop: receives frames from the vision IPC client(s), gathers the
// auxiliary model inputs from the subscribed messages, evaluates the model through
// model_eval_frame() and publishes "modelV2" and "cameraOdometry".
// Runs until do_exit becomes true.
//   model             - model state handed to model_eval_frame() and model_publish()
//   vipc_client_main  - client for the main camera stream
//   vipc_client_extra - client for the extra camera stream; only read when use_extra_client is true
//   main_wide_camera  - forwarded to update_calibration() when building the main transform
//   use_extra_client  - when false, the main frame and its metadata are reused as the extra frame
void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcClient &vipc_client_extra, bool main_wide_camera, bool use_extra_client) {
// messaging
PubMaster pm({"modelV2", "cameraOdometry"});
SubMaster sm({"lateralPlan", "roadCameraState", "liveCalibration", "driverMonitoringState", "navModel", "navInstruction"});
Params params;
PublishState ps = {};
// setup filter to track dropped frames
FirstOrderFilter frame_dropped_filter(0., 10., 1. / MODEL_FREQ);
uint32_t frame_id = 0, last_vipc_frame_id = 0;
// double last = 0;
uint32_t run_count = 0;
mat3 model_transform_main = {};
mat3 model_transform_extra = {};
bool nav_enabled = false;
bool live_calib_seen = false;
float driving_style[DRIVING_STYLE_LEN] = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0};
float nav_features[NAV_FEATURE_LEN] = {0};
float nav_instructions[NAV_INSTRUCTION_LEN] = {0};
VisionBuf *buf_main = nullptr;
VisionBuf *buf_extra = nullptr;
VisionIpcBufExtra meta_main = {0};
VisionIpcBufExtra meta_extra = {0};
while (!do_exit) {
// Keep receiving frames until we are at least 1 frame ahead of previous extra frame
// NOTE(review): timestamps look like nanoseconds (they are divided by 1e9 when logged below),
// which would make 25000000 equal to 25 ms — confirm
while (meta_main.timestamp_sof < meta_extra.timestamp_sof + 25000000ULL) {
buf_main = vipc_client_main.recv(&meta_main);
if (buf_main == nullptr) break;
}
if (buf_main == nullptr) {
LOGE("vipc_client_main no frame");
continue;
}
if (use_extra_client) {
// Keep receiving extra frames until frame id matches main camera
do {
buf_extra = vipc_client_extra.recv(&meta_extra);
} while (buf_extra != nullptr && meta_main.timestamp_sof > meta_extra.timestamp_sof + 25000000ULL);
if (buf_extra == nullptr) {
LOGE("vipc_client_extra no frame");
continue;
}
// Only logs the mismatch; the frame pair is still processed
if (std::abs((int64_t)meta_main.timestamp_sof - (int64_t)meta_extra.timestamp_sof) > 10000000ULL) {
LOGE("frames out of sync! main: %d (%.5f), extra: %d (%.5f)",
meta_main.frame_id, double(meta_main.timestamp_sof) / 1e9,
meta_extra.frame_id, double(meta_extra.timestamp_sof) / 1e9);
}
} else {
// Use single camera
buf_extra = buf_main;
meta_extra = meta_main;
}
// TODO: path planner timeout?
// Timeout of 0 is passed to the message update
sm.update(0);
int desire = ((int)sm["lateralPlan"].getLateralPlan().getDesire());
bool is_rhd = ((bool)sm["driverMonitoringState"].getDriverMonitoringState().getIsRHD());
frame_id = sm["roadCameraState"].getRoadCameraState().getFrameId();
// Recompute both warp matrices whenever a new calibration message arrived
if (sm.updated("liveCalibration")) {
auto rpy_calib = sm["liveCalibration"].getLiveCalibration().getRpyCalib();
Eigen::Vector3d device_from_calib_euler;
for (int i=0; i<3; i++) {
device_from_calib_euler(i) = rpy_calib[i];
}
model_transform_main = update_calibration(device_from_calib_euler, main_wide_camera, false);
model_transform_extra = update_calibration(device_from_calib_euler, true, true);
live_calib_seen = true;
}
// One-hot encode the desire; an out-of-range value leaves the vector all zero
float vec_desire[DESIRE_LEN] = {0};
if (desire >= 0 && desire < DESIRE_LEN) {
vec_desire[desire] = 1.0;
}
// Enable/disable nav features
// Nav is used only while navModel is valid, its locationMonoTime is less than 1e9
// behind nanos_since_boot(), and the "ExperimentalMode" param is set
uint64_t timestamp_llk = sm["navModel"].getNavModel().getLocationMonoTime();
bool nav_valid = sm["navModel"].getValid() && (nanos_since_boot() - timestamp_llk < 1e9);
bool use_nav = nav_valid && params.getBool("ExperimentalMode");
if (!nav_enabled && use_nav) {
nav_enabled = true;
} else if (nav_enabled && !use_nav) {
// Clear the nav inputs on the transition from enabled to disabled
memset(nav_features, 0, sizeof(float)*NAV_FEATURE_LEN);
memset(nav_instructions, 0, sizeof(float)*NAV_INSTRUCTION_LEN);
nav_enabled = false;
}
if (nav_enabled && sm.updated("navModel")) {
auto nav_model_features = sm["navModel"].getNavModel().getFeatures();
for (int i=0; i<NAV_FEATURE_LEN; i++) {
nav_features[i] = nav_model_features[i];
}
}
if (nav_enabled && sm.updated("navInstruction")) {
memset(nav_instructions, 0, sizeof(float)*NAV_INSTRUCTION_LEN);
auto maneuvers = sm["navInstruction"].getNavInstruction().getAllManeuvers();
// Each maneuver sets one flag: 50 distance buckets (index 25 + distance / 20) with
// 3 direction slots each (0 = neither, 1 = left variants, 2 = right variants)
for (int i=0; i<maneuvers.size(); i++) {
int distance_idx = 25 + (int)(maneuvers[i].getDistance() / 20);
std::string direction = maneuvers[i].getModifier();
int direction_idx = 0;
if (direction == "left" || direction == "slight left" || direction == "sharp left") direction_idx = 1;
if (direction == "right" || direction == "slight right" || direction == "sharp right") direction_idx = 2;
if (distance_idx >= 0 && distance_idx < 50) {
nav_instructions[distance_idx*3 + direction_idx] = 1;
}
}
}
// tracked dropped frames
// NOTE(review): this is unsigned arithmetic, so a frame id that does not advance wraps to a
// large value; min(..., 10U) bounds what the filter sees but prepare_only is still set — confirm intended
uint32_t vipc_dropped_frames = meta_main.frame_id - last_vipc_frame_id - 1;
float frames_dropped = frame_dropped_filter.update((float)std::min(vipc_dropped_frames, 10U));
if (run_count < 10) { // let frame drops warm up
frame_dropped_filter.reset(0);
frames_dropped = 0.;
}
run_count++;
float frame_drop_ratio = frames_dropped / (1 + frames_dropped);
// After a drop, model_eval_frame() is asked to only prepare its inputs
bool prepare_only = vipc_dropped_frames > 0;
if (prepare_only) {
LOGE("skipping model eval. Dropped %d frames", vipc_dropped_frames);
}
double mt1 = millis_since_boot();
ModelOutput *model_output = model_eval_frame(&model, buf_main, buf_extra, model_transform_main, model_transform_extra, vec_desire, is_rhd, driving_style, nav_features, nav_instructions, prepare_only);
double mt2 = millis_since_boot();
// Millisecond difference converted to seconds
float model_execution_time = (mt2 - mt1) / 1000.0;
// Nothing is published when model_eval_frame() returned no output
if (model_output != nullptr) {
model_publish(pm, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio, *model_output, model, ps, meta_main.timestamp_eof, timestamp_llk, model_execution_time,
nav_enabled, live_calib_seen);
posenet_publish(pm, meta_main.frame_id, vipc_dropped_frames, *model_output, meta_main.timestamp_eof, live_calib_seen);
}
// printf("model process: %.2fms, from last %.2fms, vipc_frame_id %u, frame_id, %u, frame_drop %.3f\n", mt2 - mt1, mt1 - last, extra.frame_id, frame_id, frame_drop_ratio);
// last = mt1;
last_vipc_frame_id = meta_main.frame_id;
}
}
// Process entry point: configures scheduling (off-PC only), creates the OpenCL
// context, loads the model, waits for camerad's vision streams, connects the
// vision IPC clients and hands control to run_model(). Releases the model and
// the OpenCL context before returning 0.
int main(int argc, char **argv) {
  if (!Hardware::PC()) {
    int ret;
    ret = util::set_realtime_priority(54);
    assert(ret == 0);
    // Store the affinity result so the assert below checks this call rather
    // than re-checking the stale result of set_realtime_priority().
    ret = util::set_core_affinity({7});
    assert(ret == 0);
  }

  // cl init
  cl_device_id device_id = cl_get_device_id(CL_DEVICE_TYPE_DEFAULT);
  cl_context context = CL_CHECK_ERR(clCreateContext(NULL, 1, &device_id, NULL, NULL, &err));

  // init the models
  ModelState model;
  model_init(&model, device_id, context);
  LOGW("models loaded, modeld starting");

  bool main_wide_camera = false;
  bool use_extra_client = true;  // set to false to use single camera

  // Poll every 100 ms until camerad advertises at least one stream, then pick
  // the camera layout from what is available.
  while (!do_exit) {
    auto streams = VisionIpcClient::getAvailableStreams("camerad", false);
    if (!streams.empty()) {
      // Two clients only when both the road and the wide road streams exist
      use_extra_client = streams.count(VISION_STREAM_WIDE_ROAD) > 0 && streams.count(VISION_STREAM_ROAD) > 0;
      // Without a road stream, the wide road stream becomes the main camera
      main_wide_camera = streams.count(VISION_STREAM_ROAD) == 0;
      break;
    }

    util::sleep_for(100);
  }

  VisionIpcClient vipc_client_main = VisionIpcClient("camerad", main_wide_camera ? VISION_STREAM_WIDE_ROAD : VISION_STREAM_ROAD, true, device_id, context);
  VisionIpcClient vipc_client_extra = VisionIpcClient("camerad", VISION_STREAM_WIDE_ROAD, false, device_id, context);
  LOGW("vision stream set up, main_wide_camera: %d, use_extra_client: %d", main_wide_camera, use_extra_client);

  while (!do_exit && !vipc_client_main.connect(false)) {
    util::sleep_for(100);
  }

  // The extra client is only connected when it is going to be used
  while (!do_exit && use_extra_client && !vipc_client_extra.connect(false)) {
    util::sleep_for(100);
  }

  // run the models
  // vipc_client.connected is false only when do_exit is true
  if (!do_exit) {
    const VisionBuf *b = &vipc_client_main.buffers[0];
    LOGW("connected main cam with buffer size: %zu (%zu x %zu)", b->len, b->width, b->height);

    if (use_extra_client) {
      const VisionBuf *wb = &vipc_client_extra.buffers[0];
      LOGW("connected extra cam with buffer size: %zu (%zu x %zu)", wb->len, wb->width, wb->height);
    }

    run_model(model, vipc_client_main, vipc_client_extra, main_wide_camera, use_extra_client);
  }

  model_free(&model);
  CL_CHECK(clReleaseContext(context));
  return 0;
}

@ -0,0 +1,279 @@
#!/usr/bin/env python3
import sys
import time
import numpy as np
from pathlib import Path
from typing import Dict, Optional
from setproctitle import setproctitle
from cereal.messaging import PubMaster, SubMaster
from cereal.visionipc import VisionIpcClient, VisionStreamType, VisionBuf
from openpilot.system.hardware import PC
from openpilot.system.swaglog import cloudlog
from openpilot.common.params import Params
from openpilot.common.filter_simple import FirstOrderFilter
from openpilot.common.realtime import config_realtime_process
from openpilot.selfdrive.modeld.models.commonmodel_pyx import ModelFrame, CLContext, Runtime
from openpilot.selfdrive.modeld.models.driving_pyx import (
PublishState, create_model_msg, create_pose_msg, update_calibration,
FEATURE_LEN, HISTORY_BUFFER_LEN, DESIRE_LEN, TRAFFIC_CONVENTION_LEN, NAV_FEATURE_LEN, NAV_INSTRUCTION_LEN,
OUTPUT_SIZE, NET_OUTPUT_SIZE, MODEL_FREQ, USE_THNEED)
# Select the model runner backend at import time from the USE_THNEED build flag.
# The runners are imported through the same `openpilot.` package root as the other
# project imports in this file, so the package is not loaded under two different names.
if USE_THNEED:
  from openpilot.selfdrive.modeld.runners.thneedmodel_pyx import ThneedModel as ModelRunner
else:
  from openpilot.selfdrive.modeld.runners.onnxmodel_pyx import ONNXModel as ModelRunner

# Model file under models/ next to this script; the extension follows the selected runner.
MODEL_PATH = str(Path(__file__).parent / f"models/supercombo.{'thneed' if USE_THNEED else 'onnx'}")
# NOTE: numpy matmuls don't seem to perfectly match eigen matmuls so the ref test fails, but we should switch to the np version after checking compare_runtime
# from common.transformations.orientation import rot_from_euler
# from common.transformations.model import medmodel_frame_from_calib_frame, sbigmodel_frame_from_calib_frame
# from common.transformations.camera import view_frame_from_device_frame, tici_fcam_intrinsics, tici_ecam_intrinsics
# calib_from_medmodel = np.linalg.inv(medmodel_frame_from_calib_frame[:, :3])
# calib_from_sbigmodel = np.linalg.inv(sbigmodel_frame_from_calib_frame[:, :3])
#
# def update_calibration(device_from_calib_euler: np.ndarray, wide_camera: bool, bigmodel_frame: bool) -> np.ndarray:
# cam_intrinsics = tici_ecam_intrinsics if wide_camera else tici_fcam_intrinsics
# calib_from_model = calib_from_sbigmodel if bigmodel_frame else calib_from_medmodel
# device_from_calib = rot_from_euler(device_from_calib_euler)
# camera_from_calib = cam_intrinsics @ view_frame_from_device_frame @ device_from_calib
# warp_matrix: np.ndarray = camera_from_calib @ calib_from_model
# return warp_matrix
class FrameMeta:
  """Snapshot of the frame metadata currently exposed by a vision IPC client.

  Without a client the instance falls back to the class-level defaults of 0.
  """
  frame_id: int = 0
  timestamp_sof: int = 0
  timestamp_eof: int = 0

  def __init__(self, vipc=None):
    # No client given: leave the class-level zeros in effect
    if vipc is None:
      return
    # Read all three fields first, then store them on the instance
    snapshot = (vipc.frame_id, vipc.timestamp_sof, vipc.timestamp_eof)
    self.frame_id, self.timestamp_sof, self.timestamp_eof = snapshot
class ModelState:
  """Owns the model runner together with its input/output buffers and the two frame preprocessors."""
  frame: ModelFrame
  wide_frame: ModelFrame
  inputs: Dict[str, np.ndarray]
  output: np.ndarray
  prev_desire: np.ndarray  # desire vector seen on the previous run() call, used to detect rising edges
  model: ModelRunner

  def __init__(self, context: CLContext):
    """Allocate the buffers, create the runner for MODEL_PATH and register every model input."""
    self.frame = ModelFrame(context)
    self.wide_frame = ModelFrame(context)
    self.prev_desire = np.zeros(DESIRE_LEN, dtype=np.float32)
    self.output = np.zeros(NET_OUTPUT_SIZE, dtype=np.float32)

    # (input name, flat length) pairs; dict insertion order is the registration order below
    input_sizes = (
      ('desire_pulse', DESIRE_LEN * (HISTORY_BUFFER_LEN+1)),
      ('traffic_convention', TRAFFIC_CONVENTION_LEN),
      ('nav_features', NAV_FEATURE_LEN),
      ('nav_instructions', NAV_INSTRUCTION_LEN),
      ('feature_buffer', HISTORY_BUFFER_LEN * FEATURE_LEN),
    )
    self.inputs = {name: np.zeros(size, dtype=np.float32) for name, size in input_sizes}

    self.model = ModelRunner(MODEL_PATH, self.output, Runtime.GPU, False, context)
    # The two image inputs are registered without a host buffer
    for image_input in ("input_imgs", "big_input_imgs"):
      self.model.addInput(image_input, None)
    for name, buffer in self.inputs.items():
      self.model.addInput(name, buffer)

  def run(self, buf: VisionBuf, wbuf: VisionBuf, transform: np.ndarray, transform_wide: np.ndarray,
          inputs: Dict[str, np.ndarray], prepare_only: bool) -> Optional[np.ndarray]:
    """Update the model inputs from `inputs` and the frames, then execute the model.

    Returns the output buffer, or None when `prepare_only` is set (inputs are
    still updated in that case, but the model is not executed).
    Note that index 0 of inputs['desire_pulse'] is zeroed in place.
    """
    desire = inputs['desire_pulse']
    pulse_history = self.inputs['desire_pulse']

    # The model decides when an action is completed, so the desire input is only a
    # pulse emitted on the rising edge. Slot 0 never pulses.
    desire[0] = 0
    pulse_history[:-DESIRE_LEN] = pulse_history[DESIRE_LEN:]  # shift history by one step
    rising_edge = desire - self.prev_desire > .99
    pulse_history[-DESIRE_LEN:] = np.where(rising_edge, desire, 0)
    self.prev_desire[:] = desire

    # Copy the remaining per-step inputs into the registered buffers
    for name in ('traffic_convention', 'nav_features', 'nav_instructions'):
      self.inputs[name][:] = inputs[name]
    # inputs['driving_style'] is intentionally not forwarded (no such buffer is registered)

    # if getCLBuffer is not None, frame will be None
    main_img = self.frame.prepare(buf, transform.flatten(), self.model.getCLBuffer("input_imgs"))
    self.model.setInputBuffer("input_imgs", main_img)
    if wbuf is not None:
      wide_img = self.wide_frame.prepare(wbuf, transform_wide.flatten(), self.model.getCLBuffer("big_input_imgs"))
      self.model.setInputBuffer("big_input_imgs", wide_img)

    if prepare_only:
      return None

    self.model.execute()

    # Shift the feature history by one step and append the features from this output
    feature_history = self.inputs['feature_buffer']
    feature_history[:-FEATURE_LEN] = feature_history[FEATURE_LEN:]
    feature_history[-FEATURE_LEN:] = self.output[OUTPUT_SIZE:OUTPUT_SIZE+FEATURE_LEN]
    return self.output
def main():
  """Run the driving-model daemon loop.

  Sets up the process (log binding, process title, realtime configuration when
  not on PC), loads the model, connects to camerad's vision streams and then,
  for every received frame, assembles the model inputs, runs the model and
  publishes "modelV2" and "cameraOdometry". The loop has no exit condition.
  """
  cloudlog.bind(daemon="selfdrive.modeld.modeld")
  setproctitle("selfdrive.modeld.modeld")
  if not PC:
    config_realtime_process(7, 54)

  cl_context = CLContext()
  model = ModelState(cl_context)
  cloudlog.warning("models loaded, modeld starting")

  # visionipc clients
  # Poll until camerad advertises at least one stream, then pick the camera layout
  while True:
    available_streams = VisionIpcClient.available_streams("camerad", block=False)
    if available_streams:
      # Two clients only when both the road and the wide road streams exist
      use_extra_client = VisionStreamType.VISION_STREAM_WIDE_ROAD in available_streams and VisionStreamType.VISION_STREAM_ROAD in available_streams
      # Without a road stream, the wide road stream becomes the main camera
      main_wide_camera = VisionStreamType.VISION_STREAM_ROAD not in available_streams
      break
    time.sleep(.1)

  vipc_client_main_stream = VisionStreamType.VISION_STREAM_WIDE_ROAD if main_wide_camera else VisionStreamType.VISION_STREAM_ROAD
  vipc_client_main = VisionIpcClient("camerad", vipc_client_main_stream, True, cl_context)
  vipc_client_extra = VisionIpcClient("camerad", VisionStreamType.VISION_STREAM_WIDE_ROAD, False, cl_context)
  cloudlog.warning(f"vision stream set up, main_wide_camera: {main_wide_camera}, use_extra_client: {use_extra_client}")

  while not vipc_client_main.connect(False):
    time.sleep(0.1)
  # Only wait for the extra client when it will actually be used. Waiting
  # unconditionally never finishes when no wide road stream is available.
  while use_extra_client and not vipc_client_extra.connect(False):
    time.sleep(0.1)

  cloudlog.warning(f"connected main cam with buffer size: {vipc_client_main.buffer_len} ({vipc_client_main.width} x {vipc_client_main.height})")
  if use_extra_client:
    cloudlog.warning(f"connected extra cam with buffer size: {vipc_client_extra.buffer_len} ({vipc_client_extra.width} x {vipc_client_extra.height})")

  # messaging
  pm = PubMaster(["modelV2", "cameraOdometry"])
  sm = SubMaster(["lateralPlan", "roadCameraState", "liveCalibration", "driverMonitoringState", "navModel", "navInstruction"])

  state = PublishState()
  params = Params()

  # setup filter to track dropped frames
  frame_dropped_filter = FirstOrderFilter(0., 10., 1. / MODEL_FREQ)
  frame_id = 0
  last_vipc_frame_id = 0
  run_count = 0
  # last = 0.0

  model_transform_main = np.zeros((3, 3), dtype=np.float32)
  model_transform_extra = np.zeros((3, 3), dtype=np.float32)
  live_calib_seen = False
  driving_style = np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32)
  nav_features = np.zeros(NAV_FEATURE_LEN, dtype=np.float32)
  nav_instructions = np.zeros(NAV_INSTRUCTION_LEN, dtype=np.float32)
  buf_main, buf_extra = None, None
  meta_main = FrameMeta()
  meta_extra = FrameMeta()

  while True:
    # Keep receiving frames until we are at least 1 frame ahead of previous extra frame
    while meta_main.timestamp_sof < meta_extra.timestamp_sof + 25000000:
      buf_main = vipc_client_main.recv()
      meta_main = FrameMeta(vipc_client_main)
      if buf_main is None:
        break

    if buf_main is None:
      cloudlog.error("vipc_client_main no frame")
      continue

    if use_extra_client:
      # Keep receiving extra frames until frame id matches main camera
      while True:
        buf_extra = vipc_client_extra.recv()
        meta_extra = FrameMeta(vipc_client_extra)
        if buf_extra is None or meta_main.timestamp_sof < meta_extra.timestamp_sof + 25000000:
          break

      if buf_extra is None:
        cloudlog.error("vipc_client_extra no frame")
        continue

      # Only logs the mismatch; the frame pair is still processed
      if abs(meta_main.timestamp_sof - meta_extra.timestamp_sof) > 10000000:
        cloudlog.error("frames out of sync! main: {} ({:.5f}), extra: {} ({:.5f})".format(
          meta_main.frame_id, meta_main.timestamp_sof / 1e9,
          meta_extra.frame_id, meta_extra.timestamp_sof / 1e9))
    else:
      # Use single camera
      buf_extra = buf_main
      meta_extra = meta_main

    # TODO: path planner timeout?
    sm.update(0)
    desire = sm["lateralPlan"].desire.raw
    is_rhd = sm["driverMonitoringState"].isRHD
    frame_id = sm["roadCameraState"].frameId

    # Recompute both warp matrices whenever a new calibration message arrived
    if sm.updated["liveCalibration"]:
      device_from_calib_euler = np.array(sm["liveCalibration"].rpyCalib, dtype=np.float32)
      model_transform_main = update_calibration(device_from_calib_euler, main_wide_camera, False)
      model_transform_extra = update_calibration(device_from_calib_euler, True, True)
      live_calib_seen = True

    # One-hot: index 0 when is_rhd is false, index 1 when it is true
    traffic_convention = np.zeros(2)
    traffic_convention[int(is_rhd)] = 1

    # One-hot encode the desire; an out-of-range value leaves the vector all zero
    vec_desire = np.zeros(DESIRE_LEN, dtype=np.float32)
    if desire >= 0 and desire < DESIRE_LEN:
      vec_desire[desire] = 1

    # Enable/disable nav features
    timestamp_llk = sm["navModel"].locationMonoTime
    nav_valid = sm.valid["navModel"] # and (nanos_since_boot() - timestamp_llk < 1e9)
    nav_enabled = nav_valid and params.get_bool("ExperimentalMode")

    if not nav_enabled:
      nav_features[:] = 0
      nav_instructions[:] = 0

    if nav_enabled and sm.updated["navModel"]:
      nav_features = np.array(sm["navModel"].features)

    if nav_enabled and sm.updated["navInstruction"]:
      nav_instructions[:] = 0
      # Each maneuver sets one flag: 50 distance buckets (index 25 + distance / 20) with
      # 3 direction slots each (0 = neither, 1 = left variants, 2 = right variants)
      for maneuver in sm["navInstruction"].allManeuvers:
        distance_idx = 25 + int(maneuver.distance / 20)
        direction_idx = 0
        if maneuver.modifier in ("left", "slight left", "sharp left"):
          direction_idx = 1
        if maneuver.modifier in ("right", "slight right", "sharp right"):
          direction_idx = 2
        if 0 <= distance_idx < 50:
          nav_instructions[distance_idx*3 + direction_idx] = 1

    # tracked dropped frames
    vipc_dropped_frames = max(0, meta_main.frame_id - last_vipc_frame_id - 1)
    frames_dropped = frame_dropped_filter.update(min(vipc_dropped_frames, 10))
    if run_count < 10: # let frame drops warm up
      frame_dropped_filter.x = 0.
      frames_dropped = 0.
    run_count = run_count + 1

    frame_drop_ratio = frames_dropped / (1 + frames_dropped)
    # After a drop the model inputs are still prepared, but the model is not executed
    prepare_only = vipc_dropped_frames > 0
    if prepare_only:
      cloudlog.error(f"skipping model eval. Dropped {vipc_dropped_frames} frames")

    inputs: Dict[str, np.ndarray] = {
      'desire_pulse': vec_desire,
      'traffic_convention': traffic_convention,
      'driving_style': driving_style,
      'nav_features': nav_features,
      'nav_instructions': nav_instructions}

    mt1 = time.perf_counter()
    model_output = model.run(buf_main, buf_extra, model_transform_main, model_transform_extra, inputs, prepare_only)
    mt2 = time.perf_counter()
    model_execution_time = mt2 - mt1

    # model.run returns None when prepare_only is set; nothing is published then
    if model_output is not None:
      pm.send("modelV2", create_model_msg(model_output, state, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio,
                                          meta_main.timestamp_eof, timestamp_llk, model_execution_time, nav_enabled, live_calib_seen))
      pm.send("cameraOdometry", create_pose_msg(model_output, meta_main.frame_id, vipc_dropped_frames, meta_main.timestamp_eof, live_calib_seen))

    # print("model process: %.2fms, from last %.2fms, vipc_frame_id %u, frame_id, %u, frame_drop %.3f" %
    #   ((mt2 - mt1)*1000, (mt1 - last)*1000, meta_extra.frame_id, frame_id, frame_drop_ratio))
    # last = mt1
    last_vipc_frame_id = meta_main.frame_id
# Script entry point: run the daemon and turn Ctrl-C into a normal interpreter exit.
if __name__ == "__main__":
  try:
    main()
  except KeyboardInterrupt:
    # Same effect as sys.exit() with no argument
    raise SystemExit

@ -12,112 +12,47 @@
#include "common/params.h" #include "common/params.h"
#include "common/timing.h" #include "common/timing.h"
#include "common/swaglog.h" #include "common/swaglog.h"
#include "common/transformations/orientation.hpp"
// #define DUMP_YUV
mat3 update_calibration(float *device_from_calib_euler, bool wide_camera, bool bigmodel_frame) {
void model_init(ModelState* s, cl_device_id device_id, cl_context context) { /*
s->frame = new ModelFrame(device_id, context); import numpy as np
s->wide_frame = new ModelFrame(device_id, context); from common.transformations.model import medmodel_frame_from_calib_frame
medmodel_frame_from_calib_frame = medmodel_frame_from_calib_frame[:, :3]
#ifdef USE_THNEED calib_from_smedmodel_frame = np.linalg.inv(medmodel_frame_from_calib_frame)
s->m = std::make_unique<ThneedModel>("models/supercombo.thneed", */
#elif USE_ONNX_MODEL static const auto calib_from_medmodel = (Eigen::Matrix<float, 3, 3>() <<
s->m = std::make_unique<ONNXModel>("models/supercombo.onnx", 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
#else 1.09890110e-03, 0.00000000e+00, -2.81318681e-01,
s->m = std::make_unique<SNPEModel>("models/supercombo.dlc", -2.25466395e-20, 1.09890110e-03, -5.23076923e-02).finished();
#endif
&s->output[0], NET_OUTPUT_SIZE, USE_GPU_RUNTIME, false, context); static const auto calib_from_sbigmodel = (Eigen::Matrix<float, 3, 3>() <<
0.00000000e+00, 7.31372216e-19, 1.00000000e+00,
s->m->addInput("input_imgs", NULL, 0); 2.19780220e-03, 4.11497335e-19, -5.62637363e-01,
s->m->addInput("big_input_imgs", NULL, 0); -6.66298828e-20, 2.19780220e-03, -3.33626374e-01).finished();
// TODO: the input is important here, still need to fix this static const auto view_from_device = (Eigen::Matrix<float, 3, 3>() <<
#ifdef DESIRE 0.0, 1.0, 0.0,
s->m->addInput("desire_pulse", s->pulse_desire, DESIRE_LEN*(HISTORY_BUFFER_LEN+1)); 0.0, 0.0, 1.0,
#endif 1.0, 0.0, 0.0).finished();
#ifdef TRAFFIC_CONVENTION Eigen::Vector3d device_from_calib_euler_vec;
s->m->addInput("traffic_convention", s->traffic_convention, TRAFFIC_CONVENTION_LEN); for (int i=0; i<3; i++) {
#endif device_from_calib_euler_vec(i) = device_from_calib_euler[i];
#ifdef DRIVING_STYLE
s->m->addInput("driving_style", s->driving_style, DRIVING_STYLE_LEN);
#endif
#ifdef NAV
s->m->addInput("nav_features", s->nav_features, NAV_FEATURE_LEN);
s->m->addInput("nav_instructions", s->nav_instructions, NAV_INSTRUCTION_LEN);
#endif
#ifdef TEMPORAL
s->m->addInput("feature_buffer", &s->feature_buffer[0], TEMPORAL_SIZE);
#endif
}
ModelOutput* model_eval_frame(ModelState* s, VisionBuf* buf, VisionBuf* wbuf, const mat3 &transform, const mat3 &transform_wide,
float *desire_in, bool is_rhd, float *driving_style, float *nav_features, float *nav_instructions, bool prepare_only) {
#ifdef DESIRE
std::memmove(&s->pulse_desire[0], &s->pulse_desire[DESIRE_LEN], sizeof(float) * DESIRE_LEN*HISTORY_BUFFER_LEN);
if (desire_in != NULL) {
for (int i = 1; i < DESIRE_LEN; i++) {
// Model decides when action is completed
// so desire input is just a pulse triggered on rising edge
if (desire_in[i] - s->prev_desire[i] > .99) {
s->pulse_desire[DESIRE_LEN*HISTORY_BUFFER_LEN+i] = desire_in[i];
} else {
s->pulse_desire[DESIRE_LEN*HISTORY_BUFFER_LEN+i] = 0.0;
}
s->prev_desire[i] = desire_in[i];
}
}
LOGT("Desire enqueued");
#endif
#ifdef NAV
std::memcpy(s->nav_features, nav_features, sizeof(float)*NAV_FEATURE_LEN);
std::memcpy(s->nav_instructions, nav_instructions, sizeof(float)*NAV_INSTRUCTION_LEN);
#endif
#ifdef DRIVING_STYLE
std::memcpy(s->driving_style, driving_style, sizeof(float)*DRIVING_STYLE_LEN);
#endif
int rhd_idx = is_rhd;
s->traffic_convention[rhd_idx] = 1.0;
s->traffic_convention[1-rhd_idx] = 0.0;
// if getInputBuf is not NULL, net_input_buf will be
auto net_input_buf = s->frame->prepare(buf->buf_cl, buf->width, buf->height, buf->stride, buf->uv_offset, transform, static_cast<cl_mem*>(s->m->getCLBuffer("input_imgs")));
s->m->setInputBuffer("input_imgs", net_input_buf, s->frame->buf_size);
LOGT("Image added");
if (wbuf != nullptr) {
auto net_extra_buf = s->wide_frame->prepare(wbuf->buf_cl, wbuf->width, wbuf->height, wbuf->stride, wbuf->uv_offset, transform_wide, static_cast<cl_mem*>(s->m->getCLBuffer("big_input_imgs")));
s->m->setInputBuffer("big_input_imgs", net_extra_buf, s->wide_frame->buf_size);
LOGT("Extra image added");
}
if (prepare_only) {
return nullptr;
} }
s->m->execute(); const auto cam_intrinsics = Eigen::Matrix<float, 3, 3, Eigen::RowMajor>(wide_camera ? ECAM_INTRINSIC_MATRIX.v : FCAM_INTRINSIC_MATRIX.v);
LOGT("Execution finished"); Eigen::Matrix<float, 3, 3, Eigen::RowMajor> device_from_calib = euler2rot(device_from_calib_euler_vec).cast <float> ();
auto calib_from_model = bigmodel_frame ? calib_from_sbigmodel : calib_from_medmodel;
#ifdef TEMPORAL auto camera_from_calib = cam_intrinsics * view_from_device * device_from_calib;
std::memmove(&s->feature_buffer[0], &s->feature_buffer[FEATURE_LEN], sizeof(float) * FEATURE_LEN*(HISTORY_BUFFER_LEN-1)); auto warp_matrix = camera_from_calib * calib_from_model;
std::memcpy(&s->feature_buffer[FEATURE_LEN*(HISTORY_BUFFER_LEN-1)], &s->output[OUTPUT_SIZE], sizeof(float) * FEATURE_LEN);
LOGT("Features enqueued");
#endif
return (ModelOutput*)&s->output; mat3 transform = {};
} for (int i=0; i<3*3; i++) {
transform.v[i] = warp_matrix(i / 3, i % 3);
void model_free(ModelState* s) { }
delete s->frame; return transform;
delete s->wide_frame;
} }
void fill_lead(cereal::ModelDataV2::LeadDataV3::Builder lead, const ModelOutputLeads &leads, int t_idx, float prob_t) { void fill_lead(cereal::ModelDataV2::LeadDataV3::Builder lead, const ModelOutputLeads &leads, int t_idx, float prob_t) {
@ -403,11 +338,9 @@ void fill_model(cereal::ModelDataV2::Builder &framed, const ModelOutput &net_out
temporal_pose.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)}); temporal_pose.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)});
} }
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop, void fill_model_msg(MessageBuilder &msg, float *net_output_data, PublishState &ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk, uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, const bool nav_enabled, const bool valid) {
float model_execution_time, const bool nav_enabled, const bool valid) {
const uint32_t frame_age = (frame_id > vipc_frame_id) ? (frame_id - vipc_frame_id) : 0; const uint32_t frame_age = (frame_id > vipc_frame_id) ? (frame_id - vipc_frame_id) : 0;
MessageBuilder msg;
auto framed = msg.initEvent(valid).initModelV2(); auto framed = msg.initEvent(valid).initModelV2();
framed.setFrameId(vipc_frame_id); framed.setFrameId(vipc_frame_id);
framed.setFrameIdExtra(vipc_frame_id_extra); framed.setFrameIdExtra(vipc_frame_id_extra);
@ -418,36 +351,32 @@ void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id
framed.setModelExecutionTime(model_execution_time); framed.setModelExecutionTime(model_execution_time);
framed.setNavEnabled(nav_enabled); framed.setNavEnabled(nav_enabled);
if (send_raw_pred) { if (send_raw_pred) {
framed.setRawPredictions((kj::ArrayPtr<const float>(s.output.data(), s.output.size())).asBytes()); framed.setRawPredictions(kj::ArrayPtr<const float>(net_output_data, NET_OUTPUT_SIZE).asBytes());
} }
fill_model(framed, net_outputs, ps); fill_model(framed, *((ModelOutput*) net_output_data), ps);
pm.send("modelV2", msg);
} }
void posenet_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, void fill_pose_msg(MessageBuilder &msg, float *net_output_data, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, const bool valid) {
const ModelOutput &net_outputs, uint64_t timestamp_eof, const bool valid) { const ModelOutput &net_outputs = *((ModelOutput*) net_output_data);
MessageBuilder msg; const auto &v_mean = net_outputs.pose.velocity_mean;
const auto &v_mean = net_outputs.pose.velocity_mean; const auto &r_mean = net_outputs.pose.rotation_mean;
const auto &r_mean = net_outputs.pose.rotation_mean; const auto &t_mean = net_outputs.wide_from_device_euler.mean;
const auto &t_mean = net_outputs.wide_from_device_euler.mean; const auto &v_std = net_outputs.pose.velocity_std;
const auto &v_std = net_outputs.pose.velocity_std; const auto &r_std = net_outputs.pose.rotation_std;
const auto &r_std = net_outputs.pose.rotation_std; const auto &t_std = net_outputs.wide_from_device_euler.std;
const auto &t_std = net_outputs.wide_from_device_euler.std; const auto &road_transform_trans_mean = net_outputs.road_transform.position_mean;
const auto &road_transform_trans_mean = net_outputs.road_transform.position_mean; const auto &road_transform_trans_std = net_outputs.road_transform.position_std;
const auto &road_transform_trans_std = net_outputs.road_transform.position_std;
auto posenetd = msg.initEvent(valid && (vipc_dropped_frames < 1)).initCameraOdometry();
auto posenetd = msg.initEvent(valid && (vipc_dropped_frames < 1)).initCameraOdometry(); posenetd.setTrans({v_mean.x, v_mean.y, v_mean.z});
posenetd.setTrans({v_mean.x, v_mean.y, v_mean.z}); posenetd.setRot({r_mean.x, r_mean.y, r_mean.z});
posenetd.setRot({r_mean.x, r_mean.y, r_mean.z}); posenetd.setWideFromDeviceEuler({t_mean.x, t_mean.y, t_mean.z});
posenetd.setWideFromDeviceEuler({t_mean.x, t_mean.y, t_mean.z}); posenetd.setRoadTransformTrans({road_transform_trans_mean.x, road_transform_trans_mean.y, road_transform_trans_mean.z});
posenetd.setRoadTransformTrans({road_transform_trans_mean.x, road_transform_trans_mean.y, road_transform_trans_mean.z}); posenetd.setTransStd({exp(v_std.x), exp(v_std.y), exp(v_std.z)});
posenetd.setTransStd({exp(v_std.x), exp(v_std.y), exp(v_std.z)}); posenetd.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)});
posenetd.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)}); posenetd.setWideFromDeviceEulerStd({exp(t_std.x), exp(t_std.y), exp(t_std.z)});
posenetd.setWideFromDeviceEulerStd({exp(t_std.x), exp(t_std.y), exp(t_std.z)}); posenetd.setRoadTransformTransStd({exp(road_transform_trans_std.x), exp(road_transform_trans_std.y), exp(road_transform_trans_std.z)});
posenetd.setRoadTransformTransStd({exp(road_transform_trans_std.x), exp(road_transform_trans_std.y), exp(road_transform_trans_std.z)});
posenetd.setTimestampEof(timestamp_eof);
posenetd.setTimestampEof(timestamp_eof); posenetd.setFrameId(vipc_frame_id);
posenetd.setFrameId(vipc_frame_id);
pm.send("cameraOdometry", msg);
} }

@ -4,19 +4,16 @@
#include <memory> #include <memory>
#include "cereal/messaging/messaging.h" #include "cereal/messaging/messaging.h"
#include "cereal/visionipc/visionipc_client.h"
#include "common/mat.h" #include "common/mat.h"
#include "common/modeldata.h" #include "common/modeldata.h"
#include "common/util.h" #include "common/util.h"
#include "selfdrive/modeld/models/commonmodel.h"
#include "selfdrive/modeld/models/nav.h" #include "selfdrive/modeld/models/nav.h"
#include "selfdrive/modeld/runners/run.h"
// gate this here #ifdef USE_THNEED
#define TEMPORAL constexpr bool CPP_USE_THNEED = true;
#define DESIRE #else
#define TRAFFIC_CONVENTION constexpr bool CPP_USE_THNEED = false;
#define NAV #endif
constexpr int FEATURE_LEN = 128; constexpr int FEATURE_LEN = 128;
constexpr int HISTORY_BUFFER_LEN = 99; constexpr int HISTORY_BUFFER_LEN = 99;
@ -31,10 +28,8 @@ constexpr int BLINKER_LEN = 6;
constexpr int META_STRIDE = 7; constexpr int META_STRIDE = 7;
constexpr int PLAN_MHP_N = 5; constexpr int PLAN_MHP_N = 5;
constexpr int LEAD_MHP_N = 2; constexpr int LEAD_MHP_N = 2;
constexpr int LEAD_TRAJ_LEN = 6; constexpr int LEAD_TRAJ_LEN = 6;
constexpr int LEAD_PRED_DIM = 4;
constexpr int LEAD_MHP_SELECTION = 3; constexpr int LEAD_MHP_SELECTION = 3;
// Padding to get output shape as multiple of 4 // Padding to get output shape as multiple of 4
constexpr int PAD_SIZE = 2; constexpr int PAD_SIZE = 2;
@ -253,49 +248,15 @@ struct ModelOutput {
}; };
constexpr int OUTPUT_SIZE = sizeof(ModelOutput) / sizeof(float); constexpr int OUTPUT_SIZE = sizeof(ModelOutput) / sizeof(float);
#ifdef TEMPORAL
constexpr int TEMPORAL_SIZE = HISTORY_BUFFER_LEN * FEATURE_LEN;
#else
constexpr int TEMPORAL_SIZE = 0;
#endif
constexpr int NET_OUTPUT_SIZE = OUTPUT_SIZE + FEATURE_LEN + PAD_SIZE; constexpr int NET_OUTPUT_SIZE = OUTPUT_SIZE + FEATURE_LEN + PAD_SIZE;
// TODO: convert remaining arrays to std::array and update model runners
struct ModelState {
ModelFrame *frame = nullptr;
ModelFrame *wide_frame = nullptr;
std::array<float, HISTORY_BUFFER_LEN * FEATURE_LEN> feature_buffer = {};
std::array<float, NET_OUTPUT_SIZE> output = {};
std::unique_ptr<RunModel> m;
#ifdef DESIRE
float prev_desire[DESIRE_LEN] = {};
float pulse_desire[DESIRE_LEN*(HISTORY_BUFFER_LEN+1)] = {};
#endif
#ifdef TRAFFIC_CONVENTION
float traffic_convention[TRAFFIC_CONVENTION_LEN] = {};
#endif
#ifdef DRIVING_STYLE
float driving_style[DRIVING_STYLE_LEN] = {};
#endif
#ifdef NAV
float nav_features[NAV_FEATURE_LEN] = {};
float nav_instructions[NAV_INSTRUCTION_LEN] = {};
#endif
};
struct PublishState { struct PublishState {
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {}; std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {};
std::array<float, 5> prev_brake_5ms2_probs = {}; std::array<float, 5> prev_brake_5ms2_probs = {};
std::array<float, 3> prev_brake_3ms2_probs = {}; std::array<float, 3> prev_brake_3ms2_probs = {};
}; };
void model_init(ModelState* s, cl_device_id device_id, cl_context context); mat3 update_calibration(float *device_from_calib_euler, bool wide_camera, bool bigmodel_frame);
ModelOutput *model_eval_frame(ModelState* s, VisionBuf* buf, VisionBuf* buf_wide, const mat3 &transform, const mat3 &transform_wide, void fill_model_msg(MessageBuilder &msg, float *net_output_data, PublishState &ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
float *desire_in, bool is_rhd, float *driving_style, float *nav_features, float *nav_instructions, bool prepare_only); uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, const bool nav_enabled, const bool valid);
void model_free(ModelState* s); void fill_pose_msg(MessageBuilder &msg, float *net_outputs, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, const bool valid);
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk,
float model_execution_time, const bool nav_enabled, const bool valid);
void posenet_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames,
const ModelOutput &net_outputs, uint64_t timestamp_eof, const bool valid);

@ -0,0 +1,29 @@
# distutils: language = c++
from libcpp cimport bool
from libc.stdint cimport uint32_t, uint64_t
from .commonmodel cimport mat3
# Minimal Cython view of the C++ MessageBuilder from cereal's messaging header.
# Only the two methods needed to turn a filled message into bytes are declared.
cdef extern from "cereal/messaging/messaging.h":
  cdef cppclass MessageBuilder:
    # Size queried by the wrappers before they allocate the output buffer.
    size_t getSerializedSize()
    # Serializes into (buffer, buffer size); the wrappers that call this treat
    # a non-positive return value as failure.
    int serializeToBuffer(unsigned char *, size_t)
# Constants, types and functions pulled in from the C++ driving model header.
cdef extern from "selfdrive/modeld/models/driving.h":
  # Sizes defined on the C++ side (FEATURE_LEN, HISTORY_BUFFER_LEN, OUTPUT_SIZE and
  # NET_OUTPUT_SIZE are `constexpr int` in the header).
  # NOTE(review): the remaining *_LEN constants and MODEL_FREQ are presumably
  # declared in the same header or one it includes -- confirm.
  cdef int FEATURE_LEN
  cdef int HISTORY_BUFFER_LEN
  cdef int DESIRE_LEN
  cdef int TRAFFIC_CONVENTION_LEN
  cdef int DRIVING_STYLE_LEN
  cdef int NAV_FEATURE_LEN
  cdef int NAV_INSTRUCTION_LEN
  cdef int OUTPUT_SIZE
  cdef int NET_OUTPUT_SIZE
  cdef int MODEL_FREQ
  # `constexpr bool` in the header: true only when built with USE_THNEED defined.
  cdef bool CPP_USE_THNEED
  # Declared without fields: Cython only stores it and hands it back to C++.
  cdef struct PublishState: pass
  # C++ prototype: mat3 update_calibration(float *device_from_calib_euler, bool wide_camera, bool bigmodel_frame)
  mat3 update_calibration(float *, bool, bool)
  # The C++ prototypes of the two functions below take `MessageBuilder &` (and
  # `PublishState &` for fill_model_msg); they are declared by value here, and the
  # wrappers pass local variables so the C++ side fills them in place.
  void fill_model_msg(MessageBuilder, float *, PublishState, uint32_t, uint32_t, uint32_t, float, uint64_t, uint64_t, float, bool, bool)
  void fill_pose_msg(MessageBuilder, float *, uint32_t, uint32_t, uint64_t, bool)

@ -0,0 +1,60 @@
# distutils: language = c++
# cython: c_string_encoding=ascii
import numpy as np
cimport numpy as cnp
from libcpp cimport bool
from libc.string cimport memcpy
from libc.stdint cimport uint32_t, uint64_t
from .commonmodel cimport mat3
from .driving cimport FEATURE_LEN as CPP_FEATURE_LEN, HISTORY_BUFFER_LEN as CPP_HISTORY_BUFFER_LEN, DESIRE_LEN as CPP_DESIRE_LEN, \
TRAFFIC_CONVENTION_LEN as CPP_TRAFFIC_CONVENTION_LEN, DRIVING_STYLE_LEN as CPP_DRIVING_STYLE_LEN, \
NAV_FEATURE_LEN as CPP_NAV_FEATURE_LEN, NAV_INSTRUCTION_LEN as CPP_NAV_INSTRUCTION_LEN, \
OUTPUT_SIZE as CPP_OUTPUT_SIZE, NET_OUTPUT_SIZE as CPP_NET_OUTPUT_SIZE, MODEL_FREQ as CPP_MODEL_FREQ, CPP_USE_THNEED
from .driving cimport MessageBuilder, PublishState as cppPublishState
from .driving cimport fill_model_msg, fill_pose_msg, update_calibration as cpp_update_calibration
# The cimported CPP_* names are C-level only; assigning them to module globals
# makes the same values available to Python code that imports this module.
FEATURE_LEN = CPP_FEATURE_LEN
HISTORY_BUFFER_LEN = CPP_HISTORY_BUFFER_LEN
DESIRE_LEN = CPP_DESIRE_LEN
TRAFFIC_CONVENTION_LEN = CPP_TRAFFIC_CONVENTION_LEN
DRIVING_STYLE_LEN = CPP_DRIVING_STYLE_LEN
NAV_FEATURE_LEN = CPP_NAV_FEATURE_LEN
NAV_INSTRUCTION_LEN = CPP_NAV_INSTRUCTION_LEN
OUTPUT_SIZE = CPP_OUTPUT_SIZE
NET_OUTPUT_SIZE = CPP_NET_OUTPUT_SIZE
MODEL_FREQ = CPP_MODEL_FREQ
USE_THNEED = CPP_USE_THNEED
cdef class PublishState:
  """Python-visible holder for the C++ PublishState struct.

  The struct is stored by value and is passed to fill_model_msg by
  create_model_msg; it is not readable from Python.
  """
  # Embedded C++ struct; opaque on the Cython side (declared without fields).
  cdef cppPublishState state
def update_calibration(float[:] device_from_calib_euler, bool wide_camera, bool bigmodel_frame):
  """Call the C++ update_calibration and return its mat3 result as a numpy array.

  Args:
    device_from_calib_euler: float buffer whose first element's address is
      handed to C++ (the C++ side reads three values from it).
    wide_camera: forwarded unchanged to the C++ function.
    bigmodel_frame: forwarded unchanged to the C++ function.

  Returns:
    A float32 numpy array of shape (3, 3) holding the nine values of the
    returned mat3 in their stored order.
  """
  cdef mat3 warp = cpp_update_calibration(&device_from_calib_euler[0], wide_camera, bigmodel_frame)
  flat = np.empty(9, dtype=np.float32)
  cdef float[:] flat_view = flat
  cdef int k
  # Element-wise copy of the nine matrix entries into the numpy-owned buffer.
  for k in range(9):
    flat_view[k] = warp.v[k]
  return flat.reshape(3, 3)
def create_model_msg(float[:] model_outputs, PublishState ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
                     uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, bool nav_enabled, bool valid):
  """Build a modelV2 event from raw model outputs and return it serialized.

  Args:
    model_outputs: float buffer; the address of its first element is passed to
      C++, which reinterprets it as a ModelOutput struct.
      NOTE(review): no length or contiguity check is done here -- presumably
      the caller passes at least NET_OUTPUT_SIZE contiguous floats; confirm.
    ps: publish state, passed to fill_model_msg together with the outputs.
    vipc_frame_id, vipc_frame_id_extra, frame_id, frame_drop, timestamp_eof,
    timestamp_llk, model_execution_time, nav_enabled, valid: forwarded
      unchanged to fill_model_msg.

  Returns:
    The serialized message as a bytes object.

  Raises:
    AssertionError: if serializeToBuffer reports failure (non-positive result).
  """
  cdef MessageBuilder msg
  fill_model_msg(msg, &model_outputs[0], ps.state, vipc_frame_id, vipc_frame_id_extra, frame_id, frame_drop,
                 timestamp_eof, timestamp_llk, model_execution_time, nav_enabled, valid)

  output_size = msg.getSerializedSize()
  output_data = bytearray(output_size)
  cdef unsigned char * output_ptr = output_data
  # Serialization is a required side effect, so it must not live inside an
  # `assert`: with assertions compiled out the call would vanish and the
  # function would return a zero-filled buffer.
  cdef int written = msg.serializeToBuffer(output_ptr, output_size)
  if written <= 0:
    # Same exception type and message as the former assert, for compatibility.
    raise AssertionError("output buffer is too small to serialize")
  return bytes(output_data)
def create_pose_msg(float[:] model_outputs, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, bool valid):
  """Build a cameraOdometry event from raw model outputs and return it serialized.

  Args:
    model_outputs: float buffer; the address of its first element is passed to
      C++, which reinterprets it as a ModelOutput struct.
      NOTE(review): no length or contiguity check is done here -- presumably
      the caller passes a full, contiguous model output buffer; confirm.
    vipc_frame_id, vipc_dropped_frames, timestamp_eof, valid: forwarded
      unchanged to fill_pose_msg.

  Returns:
    The serialized message as a bytes object.

  Raises:
    AssertionError: if serializeToBuffer reports failure (non-positive result).
  """
  cdef MessageBuilder msg
  fill_pose_msg(msg, &model_outputs[0], vipc_frame_id, vipc_dropped_frames, timestamp_eof, valid)

  output_size = msg.getSerializedSize()
  output_data = bytearray(output_size)
  cdef unsigned char * output_ptr = output_data
  # Serialization is a required side effect, so it must not live inside an
  # `assert`: with assertions compiled out the call would vanish and the
  # function would return a zero-filled buffer.
  cdef int written = msg.serializeToBuffer(output_ptr, output_size)
  if written <= 0:
    # Same exception type and message as the former assert, for compatibility.
    raise AssertionError("output buffer is too small to serialize")
  return bytes(output_data)

@ -36,7 +36,7 @@ PROCS = {
"selfdrive.locationd.paramsd": 9.0, "selfdrive.locationd.paramsd": 9.0,
"./_sensord": 7.0, "./_sensord": 7.0,
"selfdrive.controls.radard": 4.5, "selfdrive.controls.radard": 4.5,
"./_modeld": 4.48, "selfdrive.modeld.modeld": 8.0,
"./_dmonitoringmodeld": 5.0, "./_dmonitoringmodeld": 5.0,
"./_navmodeld": 1.0, "./_navmodeld": 1.0,
"selfdrive.thermald.thermald": 3.87, "selfdrive.thermald.thermald": 3.87,

Loading…
Cancel
Save