modeld: more readable fill_model (#22540)

* more readable fill plan_t_attr

* remove parameter column_offset from fill_xyzt
pull/22551/head
Dean Lee 4 years ago committed by GitHub
parent a437fda2fe
commit ac9bef1f64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 51
      selfdrive/modeld/models/driving.cc

@ -239,8 +239,8 @@ void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const float *meta_da
meta.setHardBrakePredicted(above_fcw_threshold);
}
void fill_xyzt(cereal::ModelDataV2::XYZTData::Builder xyzt, const float * data,
int columns, int column_offset, float * plan_t_arr, bool fill_std) {
void fill_xyzt(cereal::ModelDataV2::XYZTData::Builder xyzt, const float *data, int columns,
float *plan_t_arr = nullptr, bool fill_std = false) {
float x_arr[TRAJECTORY_SIZE] = {};
float y_arr[TRAJECTORY_SIZE] = {};
float z_arr[TRAJECTORY_SIZE] = {};
@ -249,20 +249,20 @@ void fill_xyzt(cereal::ModelDataV2::XYZTData::Builder xyzt, const float * data,
float z_std_arr[TRAJECTORY_SIZE];
float t_arr[TRAJECTORY_SIZE];
for (int i=0; i<TRAJECTORY_SIZE; i++) {
// column_offset == -1 means this data is X indexed not T indexed
if (column_offset >= 0) {
// plan_t_arr != nullptr means this data is X indexed not T indexed
if (plan_t_arr == nullptr) {
t_arr[i] = T_IDXS[i];
x_arr[i] = data[i*columns + 0 + column_offset];
x_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 0 + column_offset];
x_arr[i] = data[i*columns + 0];
x_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 0];
} else {
t_arr[i] = plan_t_arr[i];
x_arr[i] = X_IDXS[i];
x_std_arr[i] = NAN;
}
y_arr[i] = data[i*columns + 1 + column_offset];
y_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 1 + column_offset];
z_arr[i] = data[i*columns + 2 + column_offset];
z_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 2 + column_offset];
y_arr[i] = data[i*columns + 1];
y_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 1];
z_arr[i] = data[i*columns + 2];
z_std_arr[i] = data[columns*(TRAJECTORY_SIZE + i) + 2];
}
xyzt.setX(x_arr);
xyzt.setY(y_arr);
@ -276,40 +276,39 @@ void fill_xyzt(cereal::ModelDataV2::XYZTData::Builder xyzt, const float * data,
}
void fill_model(cereal::ModelDataV2::Builder &framed, const ModelDataRaw &net_outputs) {
// plan
const float *best_plan = get_plan_data(net_outputs.plan);
float plan_t_arr[TRAJECTORY_SIZE];
std::fill_n(plan_t_arr, TRAJECTORY_SIZE, NAN);
plan_t_arr[0] = 0.0;
for (int xidx=1, tidx=0; xidx<TRAJECTORY_SIZE; xidx++) {
// increment tidx until we find an element that's further away than the current xidx
while (tidx < TRAJECTORY_SIZE-1 && best_plan[(tidx+1)*PLAN_MHP_COLUMNS] < X_IDXS[xidx]) {
for (int next_tid = tidx + 1; next_tid < TRAJECTORY_SIZE && best_plan[next_tid*PLAN_MHP_COLUMNS] < X_IDXS[xidx]; next_tid++) {
tidx++;
}
float current_x_val = best_plan[tidx*PLAN_MHP_COLUMNS];
float next_x_val = best_plan[(tidx+1)*PLAN_MHP_COLUMNS];
if (next_x_val < X_IDXS[xidx]) {
if (tidx == TRAJECTORY_SIZE - 1) {
// if the plan doesn't extend far enough, set plan_t to the max value (10s), then break
plan_t_arr[xidx] = T_IDXS[TRAJECTORY_SIZE-1];
plan_t_arr[xidx] = T_IDXS[TRAJECTORY_SIZE - 1];
break;
} else {
// otherwise, interpolate to find `t` for the current xidx
float p = (X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val);
plan_t_arr[xidx] = p * T_IDXS[tidx+1] + (1 - p) * T_IDXS[tidx];
}
// interpolate to find `t` for the current xidx
float current_x_val = best_plan[tidx * PLAN_MHP_COLUMNS];
float next_x_val = best_plan[(tidx+1) * PLAN_MHP_COLUMNS];
float p = (X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val);
plan_t_arr[xidx] = p * T_IDXS[tidx+1] + (1 - p) * T_IDXS[tidx];
}
fill_xyzt(framed.initPosition(), best_plan, PLAN_MHP_COLUMNS, 0, plan_t_arr, true);
fill_xyzt(framed.initVelocity(), best_plan, PLAN_MHP_COLUMNS, 3, plan_t_arr, false);
fill_xyzt(framed.initOrientation(), best_plan, PLAN_MHP_COLUMNS, 9, plan_t_arr, false);
fill_xyzt(framed.initOrientationRate(), best_plan, PLAN_MHP_COLUMNS, 12, plan_t_arr, false);
fill_xyzt(framed.initPosition(), &best_plan[0], PLAN_MHP_COLUMNS, nullptr, true);
fill_xyzt(framed.initVelocity(), &best_plan[3], PLAN_MHP_COLUMNS);
fill_xyzt(framed.initOrientation(), &best_plan[9], PLAN_MHP_COLUMNS);
fill_xyzt(framed.initOrientationRate(), &best_plan[12], PLAN_MHP_COLUMNS);
// lane lines
auto lane_lines = framed.initLaneLines(4);
float lane_line_probs_arr[4];
float lane_line_stds_arr[4];
for (int i = 0; i < 4; i++) {
fill_xyzt(lane_lines[i], &net_outputs.lane_lines[i*TRAJECTORY_SIZE*2], 2, -1, plan_t_arr, false);
fill_xyzt(lane_lines[i], &net_outputs.lane_lines[i*TRAJECTORY_SIZE*2-1], 2, plan_t_arr);
lane_line_probs_arr[i] = sigmoid(net_outputs.lane_lines_prob[i*2+1]);
lane_line_stds_arr[i] = exp(net_outputs.lane_lines[2*TRAJECTORY_SIZE*(4 + i)]);
}
@ -320,7 +319,7 @@ void fill_model(cereal::ModelDataV2::Builder &framed, const ModelDataRaw &net_ou
auto road_edges = framed.initRoadEdges(2);
float road_edge_stds_arr[2];
for (int i = 0; i < 2; i++) {
fill_xyzt(road_edges[i], &net_outputs.road_edges[i*TRAJECTORY_SIZE*2], 2, -1, plan_t_arr, false);
fill_xyzt(road_edges[i], &net_outputs.road_edges[i*TRAJECTORY_SIZE*2-1], 2, plan_t_arr);
road_edge_stds_arr[i] = exp(net_outputs.road_edges[2*TRAJECTORY_SIZE*(2 + i)]);
}
framed.setRoadEdgeStds(road_edge_stds_arr);

Loading…
Cancel
Save