modeld: parsing and publishing in python (#30273)
* WIP try modeld all in python
* fix plan
* add lane lines stds
* fix lane lines prob
* add lead prob
* add meta
* simplify plan parsing
* add hard brake pred
* add confidence
* fix desire state and desire pred
* check this file for now
* rm prints
* rm debug
* add todos
* add plan_t_idxs
* same as cpp
* removed cython
* add wfd width - rm cpp code
* add new files rm old files
* get metadata at compile time
* forgot this file
* now uses more CPU
* not used
* update readme
* lint
* copy this too
* simplify disengage probs
* update model replay ref commit
* update again
* confidence: remove if statemens
* use publish_state.enqueue
* Revert "use publish_state.enqueue"
This reverts commit d8807c8348
.
* confidence: better shape defs
* use ModelConstants class
* fix confidence
* Parser
* slightly more power too
* no inline ifs :(
* confidence: just use if statements
pull/30033/head^2
parent
09c8866d17
commit
cad17b1255
25 changed files with 478 additions and 768 deletions
@ -1,7 +1,78 @@ |
|||||||
IDX_N = 33 |
import numpy as np |
||||||
|
|
||||||
def index_function(idx, max_val=192, max_idx=32): |
def index_function(idx, max_val=192, max_idx=32): |
||||||
return (max_val) * ((idx/max_idx)**2) |
return (max_val) * ((idx/max_idx)**2) |
||||||
|
|
||||||
|
class ModelConstants: |
||||||
|
# time and distance indices |
||||||
|
IDX_N = 33 |
||||||
|
T_IDXS = [index_function(idx, max_val=10.0) for idx in range(IDX_N)] |
||||||
|
X_IDXS = [index_function(idx, max_val=192.0) for idx in range(IDX_N)] |
||||||
|
LEAD_T_IDXS = [0., 2., 4., 6., 8., 10.] |
||||||
|
LEAD_T_OFFSETS = [0., 2., 4.] |
||||||
|
META_T_IDXS = [2., 4., 6., 8., 10.] |
||||||
|
|
||||||
T_IDXS = [index_function(idx, max_val=10.0) for idx in range(IDX_N)] |
# model inputs constants |
||||||
|
MODEL_FREQ = 20 |
||||||
|
FEATURE_LEN = 512 |
||||||
|
HISTORY_BUFFER_LEN = 99 |
||||||
|
DESIRE_LEN = 8 |
||||||
|
TRAFFIC_CONVENTION_LEN = 2 |
||||||
|
NAV_FEATURE_LEN = 256 |
||||||
|
NAV_INSTRUCTION_LEN = 150 |
||||||
|
DRIVING_STYLE_LEN = 12 |
||||||
|
|
||||||
|
# model outputs constants |
||||||
|
FCW_THRESHOLDS_5MS2 = np.array([.05, .05, .15, .15, .15], dtype=np.float32) |
||||||
|
FCW_THRESHOLDS_3MS2 = np.array([.7, .7], dtype=np.float32) |
||||||
|
|
||||||
|
DISENGAGE_WIDTH = 5 |
||||||
|
POSE_WIDTH = 6 |
||||||
|
WIDE_FROM_DEVICE_WIDTH = 3 |
||||||
|
SIM_POSE_WIDTH = 6 |
||||||
|
LEAD_WIDTH = 4 |
||||||
|
LANE_LINES_WIDTH = 2 |
||||||
|
ROAD_EDGES_WIDTH = 2 |
||||||
|
PLAN_WIDTH = 15 |
||||||
|
DESIRE_PRED_WIDTH = 8 |
||||||
|
|
||||||
|
NUM_LANE_LINES = 4 |
||||||
|
NUM_ROAD_EDGES = 2 |
||||||
|
|
||||||
|
LEAD_TRAJ_LEN = 6 |
||||||
|
DESIRE_PRED_LEN = 4 |
||||||
|
|
||||||
|
PLAN_MHP_N = 5 |
||||||
|
LEAD_MHP_N = 2 |
||||||
|
PLAN_MHP_SELECTION = 1 |
||||||
|
LEAD_MHP_SELECTION = 3 |
||||||
|
|
||||||
|
FCW_THRESHOLD_5MS2_HIGH = 0.15 |
||||||
|
FCW_THRESHOLD_5MS2_LOW = 0.05 |
||||||
|
FCW_THRESHOLD_3MS2 = 0.7 |
||||||
|
|
||||||
|
CONFIDENCE_BUFFER_LEN = 5 |
||||||
|
RYG_GREEN = 0.01165 |
||||||
|
RYG_YELLOW = 0.06157 |
||||||
|
|
||||||
|
# model outputs slices |
||||||
|
class Plan: |
||||||
|
POSITION = slice(0, 3) |
||||||
|
VELOCITY = slice(3, 6) |
||||||
|
ACCELERATION = slice(6, 9) |
||||||
|
T_FROM_CURRENT_EULER = slice(9, 12) |
||||||
|
ORIENTATION_RATE = slice(12, 15) |
||||||
|
|
||||||
|
class Meta: |
||||||
|
ENGAGED = slice(0, 1) |
||||||
|
# next 2, 4, 6, 8, 10 seconds |
||||||
|
GAS_DISENGAGE = slice(1, 36, 7) |
||||||
|
BRAKE_DISENGAGE = slice(2, 36, 7) |
||||||
|
STEER_OVERRIDE = slice(3, 36, 7) |
||||||
|
HARD_BRAKE_3 = slice(4, 36, 7) |
||||||
|
HARD_BRAKE_4 = slice(5, 36, 7) |
||||||
|
HARD_BRAKE_5 = slice(6, 36, 7) |
||||||
|
GAS_PRESS = slice(7, 36, 7) |
||||||
|
# next 0, 2, 4, 6, 8, 10 seconds |
||||||
|
LEFT_BLINKER = slice(36, 48, 2) |
||||||
|
RIGHT_BLINKER = slice(37, 48, 2) |
||||||
|
@ -0,0 +1,181 @@ |
|||||||
|
import capnp |
||||||
|
import numpy as np |
||||||
|
from typing import Dict |
||||||
|
from cereal import log |
||||||
|
from openpilot.selfdrive.modeld.constants import ModelConstants, Plan, Meta |
||||||
|
|
||||||
|
ConfidenceClass = log.ModelDataV2.ConfidenceClass |
||||||
|
|
||||||
|
class PublishState: |
||||||
|
def __init__(self): |
||||||
|
self.disengage_buffer = np.zeros(ModelConstants.CONFIDENCE_BUFFER_LEN*ModelConstants.DISENGAGE_WIDTH, dtype=np.float32) |
||||||
|
self.prev_brake_5ms2_probs = np.zeros(ModelConstants.DISENGAGE_WIDTH, dtype=np.float32) |
||||||
|
self.prev_brake_3ms2_probs = np.zeros(ModelConstants.DISENGAGE_WIDTH, dtype=np.float32) |
||||||
|
|
||||||
|
def fill_xyzt(builder, t, x, y, z, x_std=None, y_std=None, z_std=None): |
||||||
|
builder.t = t |
||||||
|
builder.x = x.tolist() |
||||||
|
builder.y = y.tolist() |
||||||
|
builder.z = z.tolist() |
||||||
|
if x_std is not None: |
||||||
|
builder.xStd = x_std.tolist() |
||||||
|
if y_std is not None: |
||||||
|
builder.yStd = y_std.tolist() |
||||||
|
if z_std is not None: |
||||||
|
builder.zStd = z_std.tolist() |
||||||
|
|
||||||
|
def fill_xyvat(builder, t, x, y, v, a, x_std=None, y_std=None, v_std=None, a_std=None): |
||||||
|
builder.t = t |
||||||
|
builder.x = x.tolist() |
||||||
|
builder.y = y.tolist() |
||||||
|
builder.v = v.tolist() |
||||||
|
builder.a = a.tolist() |
||||||
|
if x_std is not None: |
||||||
|
builder.xStd = x_std.tolist() |
||||||
|
if y_std is not None: |
||||||
|
builder.yStd = y_std.tolist() |
||||||
|
if v_std is not None: |
||||||
|
builder.vStd = v_std.tolist() |
||||||
|
if a_std is not None: |
||||||
|
builder.aStd = a_std.tolist() |
||||||
|
|
||||||
|
def fill_model_msg(msg: capnp._DynamicStructBuilder, net_output_data: Dict[str, np.ndarray], publish_state: PublishState, |
||||||
|
vipc_frame_id: int, vipc_frame_id_extra: int, frame_id: int, frame_drop: float, |
||||||
|
timestamp_eof: int, timestamp_llk: int, model_execution_time: float, |
||||||
|
nav_enabled: bool, valid: bool) -> None: |
||||||
|
frame_age = frame_id - vipc_frame_id if frame_id > vipc_frame_id else 0 |
||||||
|
msg.valid = valid |
||||||
|
|
||||||
|
modelV2 = msg.modelV2 |
||||||
|
modelV2.frameId = vipc_frame_id |
||||||
|
modelV2.frameIdExtra = vipc_frame_id_extra |
||||||
|
modelV2.frameAge = frame_age |
||||||
|
modelV2.frameDropPerc = frame_drop * 100 |
||||||
|
modelV2.timestampEof = timestamp_eof |
||||||
|
modelV2.locationMonoTime = timestamp_llk |
||||||
|
modelV2.modelExecutionTime = model_execution_time |
||||||
|
modelV2.navEnabled = nav_enabled |
||||||
|
|
||||||
|
# plan |
||||||
|
position = modelV2.position |
||||||
|
fill_xyzt(position, ModelConstants.T_IDXS, *net_output_data['plan'][0,:,Plan.POSITION].T, *net_output_data['plan_stds'][0,:,Plan.POSITION].T) |
||||||
|
velocity = modelV2.velocity |
||||||
|
fill_xyzt(velocity, ModelConstants.T_IDXS, *net_output_data['plan'][0,:,Plan.VELOCITY].T) |
||||||
|
acceleration = modelV2.acceleration |
||||||
|
fill_xyzt(acceleration, ModelConstants.T_IDXS, *net_output_data['plan'][0,:,Plan.ACCELERATION].T) |
||||||
|
orientation = modelV2.orientation |
||||||
|
fill_xyzt(orientation, ModelConstants.T_IDXS, *net_output_data['plan'][0,:,Plan.T_FROM_CURRENT_EULER].T) |
||||||
|
orientation_rate = modelV2.orientationRate |
||||||
|
fill_xyzt(orientation_rate, ModelConstants.T_IDXS, *net_output_data['plan'][0,:,Plan.ORIENTATION_RATE].T) |
||||||
|
|
||||||
|
# times at X_IDXS according to model plan |
||||||
|
PLAN_T_IDXS = [np.nan] * ModelConstants.IDX_N |
||||||
|
PLAN_T_IDXS[0] = 0.0 |
||||||
|
plan_x = net_output_data['plan'][0,:,Plan.POSITION][:,0].tolist() |
||||||
|
for xidx in range(1, ModelConstants.IDX_N): |
||||||
|
tidx = 0 |
||||||
|
# increment tidx until we find an element that's further away than the current xidx |
||||||
|
while tidx < ModelConstants.IDX_N - 1 and plan_x[tidx+1] < ModelConstants.X_IDXS[xidx]: |
||||||
|
tidx += 1 |
||||||
|
if tidx == ModelConstants.IDX_N - 1: |
||||||
|
# if the Plan doesn't extend far enough, set plan_t to the max value (10s), then break |
||||||
|
PLAN_T_IDXS[xidx] = ModelConstants.T_IDXS[ModelConstants.IDX_N - 1] |
||||||
|
break |
||||||
|
# interpolate to find `t` for the current xidx |
||||||
|
current_x_val = plan_x[tidx] |
||||||
|
next_x_val = plan_x[tidx+1] |
||||||
|
p = (ModelConstants.X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val) |
||||||
|
PLAN_T_IDXS[xidx] = p * ModelConstants.T_IDXS[tidx+1] + (1 - p) * ModelConstants.T_IDXS[tidx] |
||||||
|
|
||||||
|
# lane lines |
||||||
|
modelV2.init('laneLines', 4) |
||||||
|
for i in range(4): |
||||||
|
lane_line = modelV2.laneLines[i] |
||||||
|
fill_xyzt(lane_line, PLAN_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['lane_lines'][0,i,:,0], net_output_data['lane_lines'][0,i,:,1]) |
||||||
|
modelV2.laneLineStds = net_output_data['lane_lines_stds'][0,:,0,0].tolist() |
||||||
|
modelV2.laneLineProbs = net_output_data['lane_lines_prob'][0,1::2].tolist() |
||||||
|
|
||||||
|
# road edges |
||||||
|
modelV2.init('roadEdges', 2) |
||||||
|
for i in range(2): |
||||||
|
road_edge = modelV2.roadEdges[i] |
||||||
|
fill_xyzt(road_edge, PLAN_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['road_edges'][0,i,:,0], net_output_data['road_edges'][0,i,:,1]) |
||||||
|
modelV2.roadEdgeStds = net_output_data['road_edges_stds'][0,:,0,0].tolist() |
||||||
|
|
||||||
|
# leads |
||||||
|
modelV2.init('leadsV3', 3) |
||||||
|
for i in range(3): |
||||||
|
lead = modelV2.leadsV3[i] |
||||||
|
fill_xyvat(lead, ModelConstants.LEAD_T_IDXS, *net_output_data['lead'][0,i].T, *net_output_data['lead_stds'][0,i].T) |
||||||
|
lead.prob = net_output_data['lead_prob'][0,i].tolist() |
||||||
|
lead.probTime = ModelConstants.LEAD_T_OFFSETS[i] |
||||||
|
|
||||||
|
# meta |
||||||
|
meta = modelV2.meta |
||||||
|
meta.desireState = net_output_data['desire_state'][0].reshape(-1).tolist() |
||||||
|
meta.desirePrediction = net_output_data['desire_pred'][0].reshape(-1).tolist() |
||||||
|
meta.engagedProb = net_output_data['meta'][0,Meta.ENGAGED].item() |
||||||
|
meta.init('disengagePredictions') |
||||||
|
disengage_predictions = meta.disengagePredictions |
||||||
|
disengage_predictions.t = ModelConstants.META_T_IDXS |
||||||
|
disengage_predictions.brakeDisengageProbs = net_output_data['meta'][0,Meta.BRAKE_DISENGAGE].tolist() |
||||||
|
disengage_predictions.gasDisengageProbs = net_output_data['meta'][0,Meta.GAS_DISENGAGE].tolist() |
||||||
|
disengage_predictions.steerOverrideProbs = net_output_data['meta'][0,Meta.STEER_OVERRIDE].tolist() |
||||||
|
disengage_predictions.brake3MetersPerSecondSquaredProbs = net_output_data['meta'][0,Meta.HARD_BRAKE_3].tolist() |
||||||
|
disengage_predictions.brake4MetersPerSecondSquaredProbs = net_output_data['meta'][0,Meta.HARD_BRAKE_4].tolist() |
||||||
|
disengage_predictions.brake5MetersPerSecondSquaredProbs = net_output_data['meta'][0,Meta.HARD_BRAKE_5].tolist() |
||||||
|
|
||||||
|
publish_state.prev_brake_5ms2_probs[:-1] = publish_state.prev_brake_5ms2_probs[1:] |
||||||
|
publish_state.prev_brake_5ms2_probs[-1] = net_output_data['meta'][0,Meta.HARD_BRAKE_5][0] |
||||||
|
publish_state.prev_brake_3ms2_probs[:-1] = publish_state.prev_brake_3ms2_probs[1:] |
||||||
|
publish_state.prev_brake_3ms2_probs[-1] = net_output_data['meta'][0,Meta.HARD_BRAKE_3][0] |
||||||
|
hard_brake_predicted = (publish_state.prev_brake_5ms2_probs > ModelConstants.FCW_THRESHOLDS_5MS2).all() and \ |
||||||
|
(publish_state.prev_brake_3ms2_probs > ModelConstants.FCW_THRESHOLDS_3MS2).all() |
||||||
|
meta.hardBrakePredicted = hard_brake_predicted.item() |
||||||
|
|
||||||
|
# temporal pose |
||||||
|
temporal_pose = modelV2.temporalPose |
||||||
|
temporal_pose.trans = net_output_data['sim_pose'][0,:3].tolist() |
||||||
|
temporal_pose.transStd = net_output_data['sim_pose_stds'][0,:3].tolist() |
||||||
|
temporal_pose.rot = net_output_data['sim_pose'][0,3:].tolist() |
||||||
|
temporal_pose.rotStd = net_output_data['sim_pose_stds'][0,3:].tolist() |
||||||
|
|
||||||
|
# confidence |
||||||
|
if vipc_frame_id % (2*ModelConstants.MODEL_FREQ) == 0: |
||||||
|
# any disengage prob |
||||||
|
brake_disengage_probs = net_output_data['meta'][0,Meta.BRAKE_DISENGAGE] |
||||||
|
gas_disengage_probs = net_output_data['meta'][0,Meta.GAS_DISENGAGE] |
||||||
|
steer_override_probs = net_output_data['meta'][0,Meta.STEER_OVERRIDE] |
||||||
|
any_disengage_probs = 1-((1-brake_disengage_probs)*(1-gas_disengage_probs)*(1-steer_override_probs)) |
||||||
|
# independent disengage prob for each 2s slice |
||||||
|
ind_disengage_probs = np.r_[any_disengage_probs[0], np.diff(any_disengage_probs) / (1 - any_disengage_probs[:-1])] |
||||||
|
# rolling buf for 2, 4, 6, 8, 10s |
||||||
|
publish_state.disengage_buffer[:-ModelConstants.DISENGAGE_WIDTH] = publish_state.disengage_buffer[ModelConstants.DISENGAGE_WIDTH:] |
||||||
|
publish_state.disengage_buffer[-ModelConstants.DISENGAGE_WIDTH:] = ind_disengage_probs |
||||||
|
|
||||||
|
score = 0. |
||||||
|
for i in range(ModelConstants.DISENGAGE_WIDTH): |
||||||
|
score += publish_state.disengage_buffer[i*ModelConstants.DISENGAGE_WIDTH+ModelConstants.DISENGAGE_WIDTH-1-i].item() / ModelConstants.DISENGAGE_WIDTH |
||||||
|
if score < ModelConstants.RYG_GREEN: |
||||||
|
modelV2.confidence = ConfidenceClass.green |
||||||
|
elif score < ModelConstants.RYG_YELLOW: |
||||||
|
modelV2.confidence = ConfidenceClass.yellow |
||||||
|
else: |
||||||
|
modelV2.confidence = ConfidenceClass.red |
||||||
|
|
||||||
|
def fill_pose_msg(msg: capnp._DynamicStructBuilder, net_output_data: Dict[str, np.ndarray], |
||||||
|
vipc_frame_id: int, vipc_dropped_frames: int, timestamp_eof: int, live_calib_seen: bool) -> None: |
||||||
|
msg.valid = live_calib_seen & (vipc_dropped_frames < 1) |
||||||
|
cameraOdometry = msg.cameraOdometry |
||||||
|
|
||||||
|
cameraOdometry.frameId = vipc_frame_id |
||||||
|
cameraOdometry.timestampEof = timestamp_eof |
||||||
|
|
||||||
|
cameraOdometry.trans = net_output_data['pose'][0,:3].tolist() |
||||||
|
cameraOdometry.rot = net_output_data['pose'][0,3:].tolist() |
||||||
|
cameraOdometry.wideFromDeviceEuler = net_output_data['wide_from_device_euler'][0,:].tolist() |
||||||
|
cameraOdometry.roadTransformTrans = net_output_data['road_transform'][0,:3].tolist() |
||||||
|
cameraOdometry.transStd = net_output_data['pose_stds'][0,:3].tolist() |
||||||
|
cameraOdometry.rotStd = net_output_data['pose_stds'][0,3:].tolist() |
||||||
|
cameraOdometry.wideFromDeviceEulerStd = net_output_data['wide_from_device_euler_stds'][0,:].tolist() |
||||||
|
cameraOdometry.roadTransformTransStd = net_output_data['road_transform_stds'][0,:3].tolist() |
@ -0,0 +1,29 @@ |
|||||||
|
#!/usr/bin/env python3 |
||||||
|
import sys |
||||||
|
import pathlib |
||||||
|
import onnx |
||||||
|
import codecs |
||||||
|
import pickle |
||||||
|
from typing import Tuple |
||||||
|
|
||||||
|
def get_name_and_shape(value_info:onnx.ValueInfoProto) -> Tuple[str, Tuple[int,...]]: |
||||||
|
shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) |
||||||
|
name = value_info.name |
||||||
|
return name, shape |
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
model_path = pathlib.Path(sys.argv[1]) |
||||||
|
model = onnx.load(str(model_path)) |
||||||
|
i = [x.key for x in model.metadata_props].index('output_slices') |
||||||
|
output_slices = model.metadata_props[i].value |
||||||
|
|
||||||
|
metadata = {} |
||||||
|
metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64")) |
||||||
|
metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input]) |
||||||
|
metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output]) |
||||||
|
|
||||||
|
metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl') |
||||||
|
with open(metadata_path, 'wb') as f: |
||||||
|
pickle.dump(metadata, f) |
||||||
|
|
||||||
|
print(f'saved metadata to {metadata_path}') |
@ -1,330 +0,0 @@ |
|||||||
#include "selfdrive/modeld/models/driving.h" |
|
||||||
|
|
||||||
#include <cstring> |
|
||||||
|
|
||||||
|
|
||||||
void fill_lead(cereal::ModelDataV2::LeadDataV3::Builder lead, const ModelOutputLeads &leads, int t_idx, float prob_t) { |
|
||||||
std::array<float, LEAD_TRAJ_LEN> lead_t = {0.0, 2.0, 4.0, 6.0, 8.0, 10.0}; |
|
||||||
const auto &best_prediction = leads.get_best_prediction(t_idx); |
|
||||||
lead.setProb(sigmoid(leads.prob[t_idx])); |
|
||||||
lead.setProbTime(prob_t); |
|
||||||
std::array<float, LEAD_TRAJ_LEN> lead_x, lead_y, lead_v, lead_a; |
|
||||||
std::array<float, LEAD_TRAJ_LEN> lead_x_std, lead_y_std, lead_v_std, lead_a_std; |
|
||||||
for (int i=0; i<LEAD_TRAJ_LEN; i++) { |
|
||||||
lead_x[i] = best_prediction.mean[i].x; |
|
||||||
lead_y[i] = best_prediction.mean[i].y; |
|
||||||
lead_v[i] = best_prediction.mean[i].velocity; |
|
||||||
lead_a[i] = best_prediction.mean[i].acceleration; |
|
||||||
lead_x_std[i] = exp(best_prediction.std[i].x); |
|
||||||
lead_y_std[i] = exp(best_prediction.std[i].y); |
|
||||||
lead_v_std[i] = exp(best_prediction.std[i].velocity); |
|
||||||
lead_a_std[i] = exp(best_prediction.std[i].acceleration); |
|
||||||
} |
|
||||||
lead.setT(to_kj_array_ptr(lead_t)); |
|
||||||
lead.setX(to_kj_array_ptr(lead_x)); |
|
||||||
lead.setY(to_kj_array_ptr(lead_y)); |
|
||||||
lead.setV(to_kj_array_ptr(lead_v)); |
|
||||||
lead.setA(to_kj_array_ptr(lead_a)); |
|
||||||
lead.setXStd(to_kj_array_ptr(lead_x_std)); |
|
||||||
lead.setYStd(to_kj_array_ptr(lead_y_std)); |
|
||||||
lead.setVStd(to_kj_array_ptr(lead_v_std)); |
|
||||||
lead.setAStd(to_kj_array_ptr(lead_a_std)); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMeta &meta_data, PublishState &ps) { |
|
||||||
std::array<float, DESIRE_LEN> desire_state_softmax; |
|
||||||
softmax(meta_data.desire_state_prob.array.data(), desire_state_softmax.data(), DESIRE_LEN); |
|
||||||
|
|
||||||
std::array<float, DESIRE_PRED_LEN * DESIRE_LEN> desire_pred_softmax; |
|
||||||
for (int i=0; i<DESIRE_PRED_LEN; i++) { |
|
||||||
softmax(meta_data.desire_pred_prob[i].array.data(), desire_pred_softmax.data() + (i * DESIRE_LEN), DESIRE_LEN); |
|
||||||
} |
|
||||||
|
|
||||||
std::array<float, DISENGAGE_LEN> lat_long_t = {2, 4, 6, 8, 10}; |
|
||||||
std::array<float, DISENGAGE_LEN> gas_disengage_sigmoid, brake_disengage_sigmoid, steer_override_sigmoid, |
|
||||||
brake_3ms2_sigmoid, brake_4ms2_sigmoid, brake_5ms2_sigmoid; |
|
||||||
for (int i=0; i<DISENGAGE_LEN; i++) { |
|
||||||
gas_disengage_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].gas_disengage); |
|
||||||
brake_disengage_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].brake_disengage); |
|
||||||
steer_override_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].steer_override); |
|
||||||
brake_3ms2_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].brake_3ms2); |
|
||||||
brake_4ms2_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].brake_4ms2); |
|
||||||
brake_5ms2_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].brake_5ms2); |
|
||||||
//gas_pressed_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].gas_pressed);
|
|
||||||
} |
|
||||||
|
|
||||||
std::memmove(ps.prev_brake_5ms2_probs.data(), &ps.prev_brake_5ms2_probs[1], 4*sizeof(float)); |
|
||||||
std::memmove(ps.prev_brake_3ms2_probs.data(), &ps.prev_brake_3ms2_probs[1], 2*sizeof(float)); |
|
||||||
ps.prev_brake_5ms2_probs[4] = brake_5ms2_sigmoid[0]; |
|
||||||
ps.prev_brake_3ms2_probs[2] = brake_3ms2_sigmoid[0]; |
|
||||||
|
|
||||||
bool above_fcw_threshold = true; |
|
||||||
for (int i=0; i<ps.prev_brake_5ms2_probs.size(); i++) { |
|
||||||
float threshold = i < 2 ? FCW_THRESHOLD_5MS2_LOW : FCW_THRESHOLD_5MS2_HIGH; |
|
||||||
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_5ms2_probs[i] > threshold; |
|
||||||
} |
|
||||||
for (int i=0; i<ps.prev_brake_3ms2_probs.size(); i++) { |
|
||||||
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_3ms2_probs[i] > FCW_THRESHOLD_3MS2; |
|
||||||
} |
|
||||||
|
|
||||||
auto disengage = meta.initDisengagePredictions(); |
|
||||||
disengage.setT(to_kj_array_ptr(lat_long_t)); |
|
||||||
disengage.setGasDisengageProbs(to_kj_array_ptr(gas_disengage_sigmoid)); |
|
||||||
disengage.setBrakeDisengageProbs(to_kj_array_ptr(brake_disengage_sigmoid)); |
|
||||||
disengage.setSteerOverrideProbs(to_kj_array_ptr(steer_override_sigmoid)); |
|
||||||
disengage.setBrake3MetersPerSecondSquaredProbs(to_kj_array_ptr(brake_3ms2_sigmoid)); |
|
||||||
disengage.setBrake4MetersPerSecondSquaredProbs(to_kj_array_ptr(brake_4ms2_sigmoid)); |
|
||||||
disengage.setBrake5MetersPerSecondSquaredProbs(to_kj_array_ptr(brake_5ms2_sigmoid)); |
|
||||||
|
|
||||||
meta.setEngagedProb(sigmoid(meta_data.engaged_prob)); |
|
||||||
meta.setDesirePrediction(to_kj_array_ptr(desire_pred_softmax)); |
|
||||||
meta.setDesireState(to_kj_array_ptr(desire_state_softmax)); |
|
||||||
meta.setHardBrakePredicted(above_fcw_threshold); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_confidence(cereal::ModelDataV2::Builder &framed, PublishState &ps) { |
|
||||||
if (framed.getFrameId() % (2*MODEL_FREQ) == 0) { |
|
||||||
// update every 2s to match predictions interval
|
|
||||||
auto dbps = framed.getMeta().getDisengagePredictions().getBrakeDisengageProbs(); |
|
||||||
auto dgps = framed.getMeta().getDisengagePredictions().getGasDisengageProbs(); |
|
||||||
auto dsps = framed.getMeta().getDisengagePredictions().getSteerOverrideProbs(); |
|
||||||
|
|
||||||
float any_dp[DISENGAGE_LEN]; |
|
||||||
float dp_ind[DISENGAGE_LEN]; |
|
||||||
|
|
||||||
for (int i = 0; i < DISENGAGE_LEN; i++) { |
|
||||||
any_dp[i] = 1 - ((1-dbps[i])*(1-dgps[i])*(1-dsps[i])); // any disengage prob
|
|
||||||
} |
|
||||||
|
|
||||||
dp_ind[0] = any_dp[0]; |
|
||||||
for (int i = 0; i < DISENGAGE_LEN-1; i++) { |
|
||||||
dp_ind[i+1] = (any_dp[i+1] - any_dp[i]) / (1 - any_dp[i]); // independent disengage prob for each 2s slice
|
|
||||||
} |
|
||||||
|
|
||||||
// rolling buf for 2, 4, 6, 8, 10s
|
|
||||||
std::memmove(&ps.disengage_buffer[0], &ps.disengage_buffer[DISENGAGE_LEN], sizeof(float) * DISENGAGE_LEN * (DISENGAGE_LEN-1)); |
|
||||||
std::memcpy(&ps.disengage_buffer[DISENGAGE_LEN * (DISENGAGE_LEN-1)], &dp_ind[0], sizeof(float) * DISENGAGE_LEN); |
|
||||||
} |
|
||||||
|
|
||||||
float score = 0; |
|
||||||
for (int i = 0; i < DISENGAGE_LEN; i++) { |
|
||||||
score += ps.disengage_buffer[i*DISENGAGE_LEN+DISENGAGE_LEN-1-i] / DISENGAGE_LEN; |
|
||||||
} |
|
||||||
|
|
||||||
if (score < RYG_GREEN) { |
|
||||||
framed.setConfidence(cereal::ModelDataV2::ConfidenceClass::GREEN); |
|
||||||
} else if (score < RYG_YELLOW) { |
|
||||||
framed.setConfidence(cereal::ModelDataV2::ConfidenceClass::YELLOW); |
|
||||||
} else { |
|
||||||
framed.setConfidence(cereal::ModelDataV2::ConfidenceClass::RED); |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
template<size_t size> |
|
||||||
void fill_xyzt(cereal::XYZTData::Builder xyzt, const std::array<float, size> &t, |
|
||||||
const std::array<float, size> &x, const std::array<float, size> &y, const std::array<float, size> &z) { |
|
||||||
xyzt.setT(to_kj_array_ptr(t)); |
|
||||||
xyzt.setX(to_kj_array_ptr(x)); |
|
||||||
xyzt.setY(to_kj_array_ptr(y)); |
|
||||||
xyzt.setZ(to_kj_array_ptr(z)); |
|
||||||
} |
|
||||||
|
|
||||||
template<size_t size> |
|
||||||
void fill_xyzt(cereal::XYZTData::Builder xyzt, const std::array<float, size> &t, |
|
||||||
const std::array<float, size> &x, const std::array<float, size> &y, const std::array<float, size> &z, |
|
||||||
const std::array<float, size> &x_std, const std::array<float, size> &y_std, const std::array<float, size> &z_std) { |
|
||||||
fill_xyzt(xyzt, t, x, y, z); |
|
||||||
xyzt.setXStd(to_kj_array_ptr(x_std)); |
|
||||||
xyzt.setYStd(to_kj_array_ptr(y_std)); |
|
||||||
xyzt.setZStd(to_kj_array_ptr(z_std)); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_plan(cereal::ModelDataV2::Builder &framed, const ModelOutputPlanPrediction &plan) { |
|
||||||
std::array<float, TRAJECTORY_SIZE> pos_x, pos_y, pos_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> pos_x_std, pos_y_std, pos_z_std; |
|
||||||
std::array<float, TRAJECTORY_SIZE> vel_x, vel_y, vel_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> rot_x, rot_y, rot_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> acc_x, acc_y, acc_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> rot_rate_x, rot_rate_y, rot_rate_z; |
|
||||||
|
|
||||||
for (int i=0; i<TRAJECTORY_SIZE; i++) { |
|
||||||
pos_x[i] = plan.mean[i].position.x; |
|
||||||
pos_y[i] = plan.mean[i].position.y; |
|
||||||
pos_z[i] = plan.mean[i].position.z; |
|
||||||
pos_x_std[i] = exp(plan.std[i].position.x); |
|
||||||
pos_y_std[i] = exp(plan.std[i].position.y); |
|
||||||
pos_z_std[i] = exp(plan.std[i].position.z); |
|
||||||
vel_x[i] = plan.mean[i].velocity.x; |
|
||||||
vel_y[i] = plan.mean[i].velocity.y; |
|
||||||
vel_z[i] = plan.mean[i].velocity.z; |
|
||||||
acc_x[i] = plan.mean[i].acceleration.x; |
|
||||||
acc_y[i] = plan.mean[i].acceleration.y; |
|
||||||
acc_z[i] = plan.mean[i].acceleration.z; |
|
||||||
rot_x[i] = plan.mean[i].rotation.x; |
|
||||||
rot_y[i] = plan.mean[i].rotation.y; |
|
||||||
rot_z[i] = plan.mean[i].rotation.z; |
|
||||||
rot_rate_x[i] = plan.mean[i].rotation_rate.x; |
|
||||||
rot_rate_y[i] = plan.mean[i].rotation_rate.y; |
|
||||||
rot_rate_z[i] = plan.mean[i].rotation_rate.z; |
|
||||||
} |
|
||||||
|
|
||||||
fill_xyzt(framed.initPosition(), T_IDXS_FLOAT, pos_x, pos_y, pos_z, pos_x_std, pos_y_std, pos_z_std); |
|
||||||
fill_xyzt(framed.initVelocity(), T_IDXS_FLOAT, vel_x, vel_y, vel_z); |
|
||||||
fill_xyzt(framed.initAcceleration(), T_IDXS_FLOAT, acc_x, acc_y, acc_z); |
|
||||||
fill_xyzt(framed.initOrientation(), T_IDXS_FLOAT, rot_x, rot_y, rot_z); |
|
||||||
fill_xyzt(framed.initOrientationRate(), T_IDXS_FLOAT, rot_rate_x, rot_rate_y, rot_rate_z); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_lane_lines(cereal::ModelDataV2::Builder &framed, const std::array<float, TRAJECTORY_SIZE> &plan_t, |
|
||||||
const ModelOutputLaneLines &lanes) { |
|
||||||
std::array<float, TRAJECTORY_SIZE> left_far_y, left_far_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> left_near_y, left_near_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> right_near_y, right_near_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> right_far_y, right_far_z; |
|
||||||
for (int j=0; j<TRAJECTORY_SIZE; j++) { |
|
||||||
left_far_y[j] = lanes.mean.left_far[j].y; |
|
||||||
left_far_z[j] = lanes.mean.left_far[j].z; |
|
||||||
left_near_y[j] = lanes.mean.left_near[j].y; |
|
||||||
left_near_z[j] = lanes.mean.left_near[j].z; |
|
||||||
right_near_y[j] = lanes.mean.right_near[j].y; |
|
||||||
right_near_z[j] = lanes.mean.right_near[j].z; |
|
||||||
right_far_y[j] = lanes.mean.right_far[j].y; |
|
||||||
right_far_z[j] = lanes.mean.right_far[j].z; |
|
||||||
} |
|
||||||
|
|
||||||
auto lane_lines = framed.initLaneLines(4); |
|
||||||
fill_xyzt(lane_lines[0], plan_t, X_IDXS_FLOAT, left_far_y, left_far_z); |
|
||||||
fill_xyzt(lane_lines[1], plan_t, X_IDXS_FLOAT, left_near_y, left_near_z); |
|
||||||
fill_xyzt(lane_lines[2], plan_t, X_IDXS_FLOAT, right_near_y, right_near_z); |
|
||||||
fill_xyzt(lane_lines[3], plan_t, X_IDXS_FLOAT, right_far_y, right_far_z); |
|
||||||
|
|
||||||
framed.setLaneLineStds({ |
|
||||||
exp(lanes.std.left_far[0].y), |
|
||||||
exp(lanes.std.left_near[0].y), |
|
||||||
exp(lanes.std.right_near[0].y), |
|
||||||
exp(lanes.std.right_far[0].y), |
|
||||||
}); |
|
||||||
|
|
||||||
framed.setLaneLineProbs({ |
|
||||||
sigmoid(lanes.prob.left_far.val), |
|
||||||
sigmoid(lanes.prob.left_near.val), |
|
||||||
sigmoid(lanes.prob.right_near.val), |
|
||||||
sigmoid(lanes.prob.right_far.val), |
|
||||||
}); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_road_edges(cereal::ModelDataV2::Builder &framed, const std::array<float, TRAJECTORY_SIZE> &plan_t, |
|
||||||
const ModelOutputRoadEdges &edges) { |
|
||||||
std::array<float, TRAJECTORY_SIZE> left_y, left_z; |
|
||||||
std::array<float, TRAJECTORY_SIZE> right_y, right_z; |
|
||||||
for (int j=0; j<TRAJECTORY_SIZE; j++) { |
|
||||||
left_y[j] = edges.mean.left[j].y; |
|
||||||
left_z[j] = edges.mean.left[j].z; |
|
||||||
right_y[j] = edges.mean.right[j].y; |
|
||||||
right_z[j] = edges.mean.right[j].z; |
|
||||||
} |
|
||||||
|
|
||||||
auto road_edges = framed.initRoadEdges(2); |
|
||||||
fill_xyzt(road_edges[0], plan_t, X_IDXS_FLOAT, left_y, left_z); |
|
||||||
fill_xyzt(road_edges[1], plan_t, X_IDXS_FLOAT, right_y, right_z); |
|
||||||
|
|
||||||
framed.setRoadEdgeStds({ |
|
||||||
exp(edges.std.left[0].y), |
|
||||||
exp(edges.std.right[0].y), |
|
||||||
}); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_model(cereal::ModelDataV2::Builder &framed, const ModelOutput &net_outputs, PublishState &ps) { |
|
||||||
const auto &best_plan = net_outputs.plans.get_best_prediction(); |
|
||||||
std::array<float, TRAJECTORY_SIZE> plan_t; |
|
||||||
std::fill_n(plan_t.data(), plan_t.size(), NAN); |
|
||||||
plan_t[0] = 0.0; |
|
||||||
for (int xidx=1, tidx=0; xidx<TRAJECTORY_SIZE; xidx++) { |
|
||||||
// increment tidx until we find an element that's further away than the current xidx
|
|
||||||
for (int next_tid = tidx + 1; next_tid < TRAJECTORY_SIZE && best_plan.mean[next_tid].position.x < X_IDXS[xidx]; next_tid++) { |
|
||||||
tidx++; |
|
||||||
} |
|
||||||
if (tidx == TRAJECTORY_SIZE - 1) { |
|
||||||
// if the Plan doesn't extend far enough, set plan_t to the max value (10s), then break
|
|
||||||
plan_t[xidx] = T_IDXS[TRAJECTORY_SIZE - 1]; |
|
||||||
break; |
|
||||||
} |
|
||||||
|
|
||||||
// interpolate to find `t` for the current xidx
|
|
||||||
float current_x_val = best_plan.mean[tidx].position.x; |
|
||||||
float next_x_val = best_plan.mean[tidx+1].position.x; |
|
||||||
float p = (X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val); |
|
||||||
plan_t[xidx] = p * T_IDXS[tidx+1] + (1 - p) * T_IDXS[tidx]; |
|
||||||
} |
|
||||||
|
|
||||||
fill_plan(framed, best_plan); |
|
||||||
fill_lane_lines(framed, plan_t, net_outputs.lane_lines); |
|
||||||
fill_road_edges(framed, plan_t, net_outputs.road_edges); |
|
||||||
|
|
||||||
// meta
|
|
||||||
fill_meta(framed.initMeta(), net_outputs.meta, ps); |
|
||||||
|
|
||||||
// confidence
|
|
||||||
fill_confidence(framed, ps); |
|
||||||
|
|
||||||
// leads
|
|
||||||
auto leads = framed.initLeadsV3(LEAD_MHP_SELECTION); |
|
||||||
std::array<float, LEAD_MHP_SELECTION> t_offsets = {0.0, 2.0, 4.0}; |
|
||||||
for (int i=0; i<LEAD_MHP_SELECTION; i++) { |
|
||||||
fill_lead(leads[i], net_outputs.leads, i, t_offsets[i]); |
|
||||||
} |
|
||||||
|
|
||||||
// temporal pose
|
|
||||||
const auto &v_mean = net_outputs.temporal_pose.velocity_mean; |
|
||||||
const auto &r_mean = net_outputs.temporal_pose.rotation_mean; |
|
||||||
const auto &v_std = net_outputs.temporal_pose.velocity_std; |
|
||||||
const auto &r_std = net_outputs.temporal_pose.rotation_std; |
|
||||||
auto temporal_pose = framed.initTemporalPose(); |
|
||||||
temporal_pose.setTrans({v_mean.x, v_mean.y, v_mean.z}); |
|
||||||
temporal_pose.setRot({r_mean.x, r_mean.y, r_mean.z}); |
|
||||||
temporal_pose.setTransStd({exp(v_std.x), exp(v_std.y), exp(v_std.z)}); |
|
||||||
temporal_pose.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)}); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_model_msg(MessageBuilder &msg, float *net_output_data, PublishState &ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop, |
|
||||||
uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, const bool nav_enabled, const bool valid) { |
|
||||||
const uint32_t frame_age = (frame_id > vipc_frame_id) ? (frame_id - vipc_frame_id) : 0; |
|
||||||
auto framed = msg.initEvent(valid).initModelV2(); |
|
||||||
framed.setFrameId(vipc_frame_id); |
|
||||||
framed.setFrameIdExtra(vipc_frame_id_extra); |
|
||||||
framed.setFrameAge(frame_age); |
|
||||||
framed.setFrameDropPerc(frame_drop * 100); |
|
||||||
framed.setTimestampEof(timestamp_eof); |
|
||||||
framed.setLocationMonoTime(timestamp_llk); |
|
||||||
framed.setModelExecutionTime(model_execution_time); |
|
||||||
framed.setNavEnabled(nav_enabled); |
|
||||||
if (send_raw_pred) { |
|
||||||
framed.setRawPredictions(kj::ArrayPtr<const float>(net_output_data, NET_OUTPUT_SIZE).asBytes()); |
|
||||||
} |
|
||||||
fill_model(framed, *((ModelOutput*) net_output_data), ps); |
|
||||||
} |
|
||||||
|
|
||||||
void fill_pose_msg(MessageBuilder &msg, float *net_output_data, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, const bool valid) { |
|
||||||
const ModelOutput &net_outputs = *((ModelOutput*) net_output_data); |
|
||||||
const auto &v_mean = net_outputs.pose.velocity_mean; |
|
||||||
const auto &r_mean = net_outputs.pose.rotation_mean; |
|
||||||
const auto &t_mean = net_outputs.wide_from_device_euler.mean; |
|
||||||
const auto &v_std = net_outputs.pose.velocity_std; |
|
||||||
const auto &r_std = net_outputs.pose.rotation_std; |
|
||||||
const auto &t_std = net_outputs.wide_from_device_euler.std; |
|
||||||
const auto &road_transform_trans_mean = net_outputs.road_transform.position_mean; |
|
||||||
const auto &road_transform_trans_std = net_outputs.road_transform.position_std; |
|
||||||
|
|
||||||
auto posenetd = msg.initEvent(valid && (vipc_dropped_frames < 1)).initCameraOdometry(); |
|
||||||
posenetd.setTrans({v_mean.x, v_mean.y, v_mean.z}); |
|
||||||
posenetd.setRot({r_mean.x, r_mean.y, r_mean.z}); |
|
||||||
posenetd.setWideFromDeviceEuler({t_mean.x, t_mean.y, t_mean.z}); |
|
||||||
posenetd.setRoadTransformTrans({road_transform_trans_mean.x, road_transform_trans_mean.y, road_transform_trans_mean.z}); |
|
||||||
posenetd.setTransStd({exp(v_std.x), exp(v_std.y), exp(v_std.z)}); |
|
||||||
posenetd.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)}); |
|
||||||
posenetd.setWideFromDeviceEulerStd({exp(t_std.x), exp(t_std.y), exp(t_std.z)}); |
|
||||||
posenetd.setRoadTransformTransStd({exp(road_transform_trans_std.x), exp(road_transform_trans_std.y), exp(road_transform_trans_std.z)}); |
|
||||||
|
|
||||||
posenetd.setTimestampEof(timestamp_eof); |
|
||||||
posenetd.setFrameId(vipc_frame_id); |
|
||||||
} |
|
@ -1,257 +0,0 @@ |
|||||||
#pragma once |
|
||||||
|
|
||||||
#include <array> |
|
||||||
#include <memory> |
|
||||||
|
|
||||||
#include "cereal/messaging/messaging.h" |
|
||||||
#include "common/modeldata.h" |
|
||||||
#include "common/util.h" |
|
||||||
#include "selfdrive/modeld/models/commonmodel.h" |
|
||||||
#include "selfdrive/modeld/runners/run.h" |
|
||||||
|
|
||||||
constexpr int FEATURE_LEN = 512; |
|
||||||
constexpr int HISTORY_BUFFER_LEN = 99; |
|
||||||
constexpr int DESIRE_LEN = 8; |
|
||||||
constexpr int DESIRE_PRED_LEN = 4; |
|
||||||
constexpr int TRAFFIC_CONVENTION_LEN = 2; |
|
||||||
constexpr int NAV_FEATURE_LEN = 256; |
|
||||||
constexpr int NAV_INSTRUCTION_LEN = 150; |
|
||||||
constexpr int DRIVING_STYLE_LEN = 12; |
|
||||||
constexpr int MODEL_FREQ = 20; |
|
||||||
|
|
||||||
constexpr int DISENGAGE_LEN = 5; |
|
||||||
constexpr int BLINKER_LEN = 6; |
|
||||||
constexpr int META_STRIDE = 7; |
|
||||||
|
|
||||||
constexpr int PLAN_MHP_N = 5; |
|
||||||
constexpr int LEAD_MHP_N = 2; |
|
||||||
constexpr int LEAD_TRAJ_LEN = 6; |
|
||||||
constexpr int LEAD_MHP_SELECTION = 3; |
|
||||||
// Padding to get output shape as multiple of 4
|
|
||||||
constexpr int PAD_SIZE = 2; |
|
||||||
|
|
||||||
constexpr float FCW_THRESHOLD_5MS2_HIGH = 0.15; |
|
||||||
constexpr float FCW_THRESHOLD_5MS2_LOW = 0.05; |
|
||||||
constexpr float FCW_THRESHOLD_3MS2 = 0.7; |
|
||||||
|
|
||||||
struct ModelOutputXYZ { |
|
||||||
float x; |
|
||||||
float y; |
|
||||||
float z; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputXYZ) == sizeof(float)*3); |
|
||||||
|
|
||||||
struct ModelOutputYZ { |
|
||||||
float y; |
|
||||||
float z; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputYZ) == sizeof(float)*2); |
|
||||||
|
|
||||||
struct ModelOutputPlanElement { |
|
||||||
ModelOutputXYZ position; |
|
||||||
ModelOutputXYZ velocity; |
|
||||||
ModelOutputXYZ acceleration; |
|
||||||
ModelOutputXYZ rotation; |
|
||||||
ModelOutputXYZ rotation_rate; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputPlanElement) == sizeof(ModelOutputXYZ)*5); |
|
||||||
|
|
||||||
struct ModelOutputPlanPrediction { |
|
||||||
std::array<ModelOutputPlanElement, TRAJECTORY_SIZE> mean; |
|
||||||
std::array<ModelOutputPlanElement, TRAJECTORY_SIZE> std; |
|
||||||
float prob; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputPlanPrediction) == (sizeof(ModelOutputPlanElement)*TRAJECTORY_SIZE*2) + sizeof(float)); |
|
||||||
|
|
||||||
struct ModelOutputPlans { |
|
||||||
std::array<ModelOutputPlanPrediction, PLAN_MHP_N> prediction; |
|
||||||
|
|
||||||
constexpr const ModelOutputPlanPrediction &get_best_prediction() const { |
|
||||||
int max_idx = 0; |
|
||||||
for (int i = 1; i < prediction.size(); i++) { |
|
||||||
if (prediction[i].prob > prediction[max_idx].prob) { |
|
||||||
max_idx = i; |
|
||||||
} |
|
||||||
} |
|
||||||
return prediction[max_idx]; |
|
||||||
} |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputPlans) == sizeof(ModelOutputPlanPrediction)*PLAN_MHP_N); |
|
||||||
|
|
||||||
struct ModelOutputLinesXY { |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> left_far; |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> left_near; |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> right_near; |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> right_far; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLinesXY) == sizeof(ModelOutputYZ)*TRAJECTORY_SIZE*4); |
|
||||||
|
|
||||||
struct ModelOutputLineProbVal { |
|
||||||
float val_deprecated; |
|
||||||
float val; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLineProbVal) == sizeof(float)*2); |
|
||||||
|
|
||||||
struct ModelOutputLinesProb { |
|
||||||
ModelOutputLineProbVal left_far; |
|
||||||
ModelOutputLineProbVal left_near; |
|
||||||
ModelOutputLineProbVal right_near; |
|
||||||
ModelOutputLineProbVal right_far; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLinesProb) == sizeof(ModelOutputLineProbVal)*4); |
|
||||||
|
|
||||||
struct ModelOutputLaneLines { |
|
||||||
ModelOutputLinesXY mean; |
|
||||||
ModelOutputLinesXY std; |
|
||||||
ModelOutputLinesProb prob; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLaneLines) == (sizeof(ModelOutputLinesXY)*2) + sizeof(ModelOutputLinesProb)); |
|
||||||
|
|
||||||
struct ModelOutputEdgessXY { |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> left; |
|
||||||
std::array<ModelOutputYZ, TRAJECTORY_SIZE> right; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputEdgessXY) == sizeof(ModelOutputYZ)*TRAJECTORY_SIZE*2); |
|
||||||
|
|
||||||
struct ModelOutputRoadEdges { |
|
||||||
ModelOutputEdgessXY mean; |
|
||||||
ModelOutputEdgessXY std; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputRoadEdges) == (sizeof(ModelOutputEdgessXY)*2)); |
|
||||||
|
|
||||||
struct ModelOutputLeadElement { |
|
||||||
float x; |
|
||||||
float y; |
|
||||||
float velocity; |
|
||||||
float acceleration; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLeadElement) == sizeof(float)*4); |
|
||||||
|
|
||||||
struct ModelOutputLeadPrediction { |
|
||||||
std::array<ModelOutputLeadElement, LEAD_TRAJ_LEN> mean; |
|
||||||
std::array<ModelOutputLeadElement, LEAD_TRAJ_LEN> std; |
|
||||||
std::array<float, LEAD_MHP_SELECTION> prob; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLeadPrediction) == (sizeof(ModelOutputLeadElement)*LEAD_TRAJ_LEN*2) + (sizeof(float)*LEAD_MHP_SELECTION)); |
|
||||||
|
|
||||||
struct ModelOutputLeads { |
|
||||||
std::array<ModelOutputLeadPrediction, LEAD_MHP_N> prediction; |
|
||||||
std::array<float, LEAD_MHP_SELECTION> prob; |
|
||||||
|
|
||||||
constexpr const ModelOutputLeadPrediction &get_best_prediction(int t_idx) const { |
|
||||||
int max_idx = 0; |
|
||||||
for (int i = 1; i < prediction.size(); i++) { |
|
||||||
if (prediction[i].prob[t_idx] > prediction[max_idx].prob[t_idx]) { |
|
||||||
max_idx = i; |
|
||||||
} |
|
||||||
} |
|
||||||
return prediction[max_idx]; |
|
||||||
} |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputLeads) == (sizeof(ModelOutputLeadPrediction)*LEAD_MHP_N) + (sizeof(float)*LEAD_MHP_SELECTION)); |
|
||||||
|
|
||||||
|
|
||||||
struct ModelOutputPose { |
|
||||||
ModelOutputXYZ velocity_mean; |
|
||||||
ModelOutputXYZ rotation_mean; |
|
||||||
ModelOutputXYZ velocity_std; |
|
||||||
ModelOutputXYZ rotation_std; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputPose) == sizeof(ModelOutputXYZ)*4); |
|
||||||
|
|
||||||
struct ModelOutputWideFromDeviceEuler { |
|
||||||
ModelOutputXYZ mean; |
|
||||||
ModelOutputXYZ std; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputWideFromDeviceEuler) == sizeof(ModelOutputXYZ)*2); |
|
||||||
|
|
||||||
struct ModelOutputTemporalPose { |
|
||||||
ModelOutputXYZ velocity_mean; |
|
||||||
ModelOutputXYZ rotation_mean; |
|
||||||
ModelOutputXYZ velocity_std; |
|
||||||
ModelOutputXYZ rotation_std; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputTemporalPose) == sizeof(ModelOutputXYZ)*4); |
|
||||||
|
|
||||||
struct ModelOutputRoadTransform { |
|
||||||
ModelOutputXYZ position_mean; |
|
||||||
ModelOutputXYZ rotation_mean; |
|
||||||
ModelOutputXYZ position_std; |
|
||||||
ModelOutputXYZ rotation_std; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputRoadTransform) == sizeof(ModelOutputXYZ)*4); |
|
||||||
|
|
||||||
struct ModelOutputDisengageProb { |
|
||||||
float gas_disengage; |
|
||||||
float brake_disengage; |
|
||||||
float steer_override; |
|
||||||
float brake_3ms2; |
|
||||||
float brake_4ms2; |
|
||||||
float brake_5ms2; |
|
||||||
float gas_pressed; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputDisengageProb) == sizeof(float)*7); |
|
||||||
|
|
||||||
struct ModelOutputBlinkerProb { |
|
||||||
float left; |
|
||||||
float right; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputBlinkerProb) == sizeof(float)*2); |
|
||||||
|
|
||||||
struct ModelOutputDesireProb { |
|
||||||
union { |
|
||||||
struct { |
|
||||||
float none; |
|
||||||
float turn_left; |
|
||||||
float turn_right; |
|
||||||
float lane_change_left; |
|
||||||
float lane_change_right; |
|
||||||
float keep_left; |
|
||||||
float keep_right; |
|
||||||
float null; |
|
||||||
}; |
|
||||||
struct { |
|
||||||
std::array<float, DESIRE_LEN> array; |
|
||||||
}; |
|
||||||
}; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputDesireProb) == sizeof(float)*DESIRE_LEN); |
|
||||||
|
|
||||||
struct ModelOutputMeta { |
|
||||||
ModelOutputDesireProb desire_state_prob; |
|
||||||
float engaged_prob; |
|
||||||
std::array<ModelOutputDisengageProb, DISENGAGE_LEN> disengage_prob; |
|
||||||
std::array<ModelOutputBlinkerProb, BLINKER_LEN> blinker_prob; |
|
||||||
std::array<ModelOutputDesireProb, DESIRE_PRED_LEN> desire_pred_prob; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputMeta) == sizeof(ModelOutputDesireProb) + sizeof(float) + (sizeof(ModelOutputDisengageProb)*DISENGAGE_LEN) + (sizeof(ModelOutputBlinkerProb)*BLINKER_LEN) + (sizeof(ModelOutputDesireProb)*DESIRE_PRED_LEN)); |
|
||||||
|
|
||||||
struct ModelOutputFeatures { |
|
||||||
std::array<float, FEATURE_LEN> feature; |
|
||||||
}; |
|
||||||
static_assert(sizeof(ModelOutputFeatures) == (sizeof(float)*FEATURE_LEN)); |
|
||||||
|
|
||||||
struct ModelOutput { |
|
||||||
const ModelOutputPlans plans; |
|
||||||
const ModelOutputLaneLines lane_lines; |
|
||||||
const ModelOutputRoadEdges road_edges; |
|
||||||
const ModelOutputLeads leads; |
|
||||||
const ModelOutputMeta meta; |
|
||||||
const ModelOutputPose pose; |
|
||||||
const ModelOutputWideFromDeviceEuler wide_from_device_euler; |
|
||||||
const ModelOutputTemporalPose temporal_pose; |
|
||||||
const ModelOutputRoadTransform road_transform; |
|
||||||
}; |
|
||||||
|
|
||||||
constexpr int OUTPUT_SIZE = sizeof(ModelOutput) / sizeof(float); |
|
||||||
constexpr int NET_OUTPUT_SIZE = OUTPUT_SIZE + FEATURE_LEN + PAD_SIZE; |
|
||||||
|
|
||||||
struct PublishState { |
|
||||||
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {}; |
|
||||||
std::array<float, 5> prev_brake_5ms2_probs = {}; |
|
||||||
std::array<float, 3> prev_brake_3ms2_probs = {}; |
|
||||||
}; |
|
||||||
|
|
||||||
void fill_model_msg(MessageBuilder &msg, float *net_output_data, PublishState &ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop, |
|
||||||
uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, const bool nav_enabled, const bool valid); |
|
||||||
void fill_pose_msg(MessageBuilder &msg, float *net_outputs, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, const bool valid); |
|
@ -1,25 +0,0 @@ |
|||||||
# distutils: language = c++ |
|
||||||
|
|
||||||
from libcpp cimport bool |
|
||||||
from libc.stdint cimport uint32_t, uint64_t |
|
||||||
|
|
||||||
cdef extern from "cereal/messaging/messaging.h": |
|
||||||
cdef cppclass MessageBuilder: |
|
||||||
size_t getSerializedSize() |
|
||||||
int serializeToBuffer(unsigned char *, size_t) |
|
||||||
|
|
||||||
cdef extern from "selfdrive/modeld/models/driving.h": |
|
||||||
cdef int FEATURE_LEN |
|
||||||
cdef int HISTORY_BUFFER_LEN |
|
||||||
cdef int DESIRE_LEN |
|
||||||
cdef int TRAFFIC_CONVENTION_LEN |
|
||||||
cdef int DRIVING_STYLE_LEN |
|
||||||
cdef int NAV_FEATURE_LEN |
|
||||||
cdef int NAV_INSTRUCTION_LEN |
|
||||||
cdef int OUTPUT_SIZE |
|
||||||
cdef int NET_OUTPUT_SIZE |
|
||||||
cdef int MODEL_FREQ |
|
||||||
cdef struct PublishState: pass |
|
||||||
|
|
||||||
void fill_model_msg(MessageBuilder, float *, PublishState, uint32_t, uint32_t, uint32_t, float, uint64_t, uint64_t, float, bool, bool) |
|
||||||
void fill_pose_msg(MessageBuilder, float *, uint32_t, uint32_t, uint64_t, bool) |
|
@ -1,52 +0,0 @@ |
|||||||
# distutils: language = c++ |
|
||||||
# cython: c_string_encoding=ascii |
|
||||||
|
|
||||||
import numpy as np |
|
||||||
cimport numpy as cnp |
|
||||||
from libcpp cimport bool |
|
||||||
from libc.string cimport memcpy |
|
||||||
from libc.stdint cimport uint32_t, uint64_t |
|
||||||
|
|
||||||
from .commonmodel cimport mat3 |
|
||||||
from .driving cimport FEATURE_LEN as CPP_FEATURE_LEN, HISTORY_BUFFER_LEN as CPP_HISTORY_BUFFER_LEN, DESIRE_LEN as CPP_DESIRE_LEN, \ |
|
||||||
TRAFFIC_CONVENTION_LEN as CPP_TRAFFIC_CONVENTION_LEN, DRIVING_STYLE_LEN as CPP_DRIVING_STYLE_LEN, \ |
|
||||||
NAV_FEATURE_LEN as CPP_NAV_FEATURE_LEN, NAV_INSTRUCTION_LEN as CPP_NAV_INSTRUCTION_LEN, \ |
|
||||||
OUTPUT_SIZE as CPP_OUTPUT_SIZE, NET_OUTPUT_SIZE as CPP_NET_OUTPUT_SIZE, MODEL_FREQ as CPP_MODEL_FREQ |
|
||||||
from .driving cimport MessageBuilder, PublishState as cppPublishState |
|
||||||
from .driving cimport fill_model_msg, fill_pose_msg |
|
||||||
|
|
||||||
FEATURE_LEN = CPP_FEATURE_LEN |
|
||||||
HISTORY_BUFFER_LEN = CPP_HISTORY_BUFFER_LEN |
|
||||||
DESIRE_LEN = CPP_DESIRE_LEN |
|
||||||
TRAFFIC_CONVENTION_LEN = CPP_TRAFFIC_CONVENTION_LEN |
|
||||||
DRIVING_STYLE_LEN = CPP_DRIVING_STYLE_LEN |
|
||||||
NAV_FEATURE_LEN = CPP_NAV_FEATURE_LEN |
|
||||||
NAV_INSTRUCTION_LEN = CPP_NAV_INSTRUCTION_LEN |
|
||||||
OUTPUT_SIZE = CPP_OUTPUT_SIZE |
|
||||||
NET_OUTPUT_SIZE = CPP_NET_OUTPUT_SIZE |
|
||||||
MODEL_FREQ = CPP_MODEL_FREQ |
|
||||||
|
|
||||||
cdef class PublishState: |
|
||||||
cdef cppPublishState state |
|
||||||
|
|
||||||
def create_model_msg(float[:] model_outputs, PublishState ps, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop, |
|
||||||
uint64_t timestamp_eof, uint64_t timestamp_llk, float model_execution_time, bool nav_enabled, bool valid): |
|
||||||
cdef MessageBuilder msg |
|
||||||
fill_model_msg(msg, &model_outputs[0], ps.state, vipc_frame_id, vipc_frame_id_extra, frame_id, frame_drop, |
|
||||||
timestamp_eof, timestamp_llk, model_execution_time, nav_enabled, valid) |
|
||||||
|
|
||||||
output_size = msg.getSerializedSize() |
|
||||||
output_data = bytearray(output_size) |
|
||||||
cdef unsigned char * output_ptr = output_data |
|
||||||
assert msg.serializeToBuffer(output_ptr, output_size) > 0, "output buffer is too small to serialize" |
|
||||||
return bytes(output_data) |
|
||||||
|
|
||||||
def create_pose_msg(float[:] model_outputs, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames, uint64_t timestamp_eof, bool valid): |
|
||||||
cdef MessageBuilder msg |
|
||||||
fill_pose_msg(msg, &model_outputs[0], vipc_frame_id, vipc_dropped_frames, timestamp_eof, valid) |
|
||||||
|
|
||||||
output_size = msg.getSerializedSize() |
|
||||||
output_data = bytearray(output_size) |
|
||||||
cdef unsigned char * output_ptr = output_data |
|
||||||
assert msg.serializeToBuffer(output_ptr, output_size) > 0, "output buffer is too small to serialize" |
|
||||||
return bytes(output_data) |
|
@ -0,0 +1,100 @@ |
|||||||
|
import numpy as np |
||||||
|
from typing import Dict |
||||||
|
from openpilot.selfdrive.modeld.constants import ModelConstants |
||||||
|
|
||||||
|
def sigmoid(x): |
||||||
|
return 1. / (1. + np.exp(-x)) |
||||||
|
|
||||||
|
def softmax(x, axis=-1): |
||||||
|
x -= np.max(x, axis=axis, keepdims=True) |
||||||
|
if x.dtype == np.float32 or x.dtype == np.float64: |
||||||
|
np.exp(x, out=x) |
||||||
|
else: |
||||||
|
x = np.exp(x) |
||||||
|
x /= np.sum(x, axis=axis, keepdims=True) |
||||||
|
return x |
||||||
|
|
||||||
|
class Parser: |
||||||
|
def __init__(self, ignore_missing=False): |
||||||
|
self.ignore_missing = ignore_missing |
||||||
|
|
||||||
|
def check_missing(self, outs, name): |
||||||
|
if name not in outs and not self.ignore_missing: |
||||||
|
raise ValueError(f"Missing output {name}") |
||||||
|
return name not in outs |
||||||
|
|
||||||
|
def parse_categorical_crossentropy(self, name, outs, out_shape=None): |
||||||
|
if self.check_missing(outs, name): |
||||||
|
return |
||||||
|
raw = outs[name] |
||||||
|
if out_shape is not None: |
||||||
|
raw = raw.reshape((raw.shape[0],) + out_shape) |
||||||
|
outs[name] = softmax(raw, axis=-1) |
||||||
|
|
||||||
|
def parse_binary_crossentropy(self, name, outs): |
||||||
|
if self.check_missing(outs, name): |
||||||
|
return |
||||||
|
raw = outs[name] |
||||||
|
outs[name] = sigmoid(raw) |
||||||
|
|
||||||
|
def parse_mdn(self, name, outs, in_N=0, out_N=1, out_shape=None): |
||||||
|
if self.check_missing(outs, name): |
||||||
|
return |
||||||
|
raw = outs[name] |
||||||
|
raw = raw.reshape((raw.shape[0], max(in_N, 1), -1)) |
||||||
|
|
||||||
|
pred_mu = raw[:,:,:(raw.shape[2] - out_N)//2] |
||||||
|
n_values = (raw.shape[2] - out_N)//2 |
||||||
|
pred_mu = raw[:,:,:n_values] |
||||||
|
pred_std = np.exp(raw[:,:,n_values: 2*n_values]) |
||||||
|
|
||||||
|
if in_N > 1: |
||||||
|
weights = np.zeros((raw.shape[0], in_N, out_N), dtype=raw.dtype) |
||||||
|
for i in range(out_N): |
||||||
|
weights[:,:,i - out_N] = softmax(raw[:,:,i - out_N], axis=-1) |
||||||
|
|
||||||
|
if out_N == 1: |
||||||
|
for fidx in range(weights.shape[0]): |
||||||
|
idxs = np.argsort(weights[fidx][:,0])[::-1] |
||||||
|
weights[fidx] = weights[fidx][idxs] |
||||||
|
pred_mu[fidx] = pred_mu[fidx][idxs] |
||||||
|
pred_std[fidx] = pred_std[fidx][idxs] |
||||||
|
full_shape = tuple([raw.shape[0], in_N] + list(out_shape)) |
||||||
|
outs[name + '_weights'] = weights |
||||||
|
outs[name + '_hypotheses'] = pred_mu.reshape(full_shape) |
||||||
|
outs[name + '_stds_hypotheses'] = pred_std.reshape(full_shape) |
||||||
|
|
||||||
|
pred_mu_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype) |
||||||
|
pred_std_final = np.zeros((raw.shape[0], out_N, n_values), dtype=raw.dtype) |
||||||
|
for fidx in range(weights.shape[0]): |
||||||
|
for hidx in range(out_N): |
||||||
|
idxs = np.argsort(weights[fidx,:,hidx])[::-1] |
||||||
|
pred_mu_final[fidx, hidx] = pred_mu[fidx, idxs[0]] |
||||||
|
pred_std_final[fidx, hidx] = pred_std[fidx, idxs[0]] |
||||||
|
else: |
||||||
|
pred_mu_final = pred_mu |
||||||
|
pred_std_final = pred_std |
||||||
|
|
||||||
|
if out_N > 1: |
||||||
|
final_shape = tuple([raw.shape[0], out_N] + list(out_shape)) |
||||||
|
else: |
||||||
|
final_shape = tuple([raw.shape[0],] + list(out_shape)) |
||||||
|
outs[name] = pred_mu_final.reshape(final_shape) |
||||||
|
outs[name + '_stds'] = pred_std_final.reshape(final_shape) |
||||||
|
|
||||||
|
def parse_outputs(self, outs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: |
||||||
|
self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION, |
||||||
|
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) |
||||||
|
self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) |
||||||
|
self.parse_mdn('road_edges', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_ROAD_EDGES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) |
||||||
|
self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) |
||||||
|
self.parse_mdn('road_transform', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) |
||||||
|
self.parse_mdn('sim_pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) |
||||||
|
self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) |
||||||
|
self.parse_mdn('lead', outs, in_N=ModelConstants.LEAD_MHP_N, out_N=ModelConstants.LEAD_MHP_SELECTION, |
||||||
|
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) |
||||||
|
for k in ['lead_prob', 'lane_lines_prob', 'meta']: |
||||||
|
self.parse_binary_crossentropy(k, outs) |
||||||
|
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) |
||||||
|
self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) |
||||||
|
return outs |
@ -1,32 +0,0 @@ |
|||||||
import struct |
|
||||||
import json |
|
||||||
|
|
||||||
def load_thneed(fn): |
|
||||||
with open(fn, "rb") as f: |
|
||||||
json_len = struct.unpack("I", f.read(4))[0] |
|
||||||
jdat = json.loads(f.read(json_len).decode('latin_1')) |
|
||||||
weights = f.read() |
|
||||||
ptr = 0 |
|
||||||
for o in jdat['objects']: |
|
||||||
if o['needs_load']: |
|
||||||
nptr = ptr + o['size'] |
|
||||||
o['data'] = weights[ptr:nptr] |
|
||||||
ptr = nptr |
|
||||||
for o in jdat['binaries']: |
|
||||||
nptr = ptr + o['length'] |
|
||||||
o['data'] = weights[ptr:nptr] |
|
||||||
ptr = nptr |
|
||||||
return jdat |
|
||||||
|
|
||||||
def save_thneed(jdat, fn): |
|
||||||
new_weights = [] |
|
||||||
for o in jdat['objects'] + jdat['binaries']: |
|
||||||
if 'data' in o: |
|
||||||
new_weights.append(o['data']) |
|
||||||
del o['data'] |
|
||||||
new_weights_bytes = b''.join(new_weights) |
|
||||||
with open(fn, "wb") as f: |
|
||||||
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1') |
|
||||||
f.write(struct.pack("I", len(j))) |
|
||||||
f.write(j) |
|
||||||
f.write(new_weights_bytes) |
|
@ -1 +1 @@ |
|||||||
f851c7e7f90eff828a59444d20fac5df8cd7ae0c |
0e0f55cf3bb2cf79b44adf190e6387a83deb6646 |
||||||
|
Loading…
Reference in new issue