* start again

* need that too

* this actually works

* not needed

* do properly

* still works

* still works

* still good

* all G without ll

* still works

* all still good

* cleanup building

* cleanup sconscript

* new lane planner

* how on earth is this silent too....

* update

* add rotation radius

* update

* pathplanner first pass

* misc fixes

* fix

* need deep_interp

* local again

* fix

* fix test

* very old

* new replay

* interp properly

* correct length

* another horrible silent bug

* like master

* fix that

* do doubles

* different delay compensation

* make robust to empty msg

* make pass with hack for now

* add some extra

* update ref for increased leg

* test cpu usage on this pr

* tiny bit faster

* purge numpy

* update ref

* not needed

* ready for merge

* try again after recompile

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>
old-commit-hash: 158210cde8
commatwo_master
HaraldSchafer 4 years ago committed by GitHub
parent d66a310f3e
commit c6d9b9565a
  1. 22
      common/numpy_helpers.py
  2. 24
      selfdrive/common/modeldata.h
  3. 3
      selfdrive/controls/lib/drive_helpers.py
  4. 88
      selfdrive/controls/lib/lane_planner.py
  5. 1
      selfdrive/controls/lib/lateral_mpc/SConscript
  6. 78
      selfdrive/controls/lib/lateral_mpc/generator.cpp
  7. 64
      selfdrive/controls/lib/lateral_mpc/lateral_mpc.c
  8. 4
      selfdrive/controls/lib/lateral_mpc/lib_mpc_export/acado_common.h
  9. 4
      selfdrive/controls/lib/lateral_mpc/lib_mpc_export/acado_integrator.c
  10. 2
      selfdrive/controls/lib/lateral_mpc/lib_mpc_export/acado_qpoases_interface.cpp
  11. 4
      selfdrive/controls/lib/lateral_mpc/lib_mpc_export/acado_qpoases_interface.hpp
  12. 4
      selfdrive/controls/lib/lateral_mpc/lib_mpc_export/acado_solver.c
  13. 19
      selfdrive/controls/lib/lateral_mpc/libmpc_py.py
  14. 84
      selfdrive/controls/lib/pathplanner.py
  15. 2
      selfdrive/controls/lib/planner.py
  16. 6
      selfdrive/controls/plannerd.py
  17. 41
      selfdrive/controls/tests/test_lateral_mpc.py
  18. 27
      selfdrive/debug/mpc/test_mpc_wobble.py
  19. 11
      selfdrive/modeld/models/driving.cc
  20. 4
      selfdrive/test/longitudinal_maneuvers/plant.py
  21. 2
      selfdrive/test/process_replay/model_replay_ref_commit
  22. 2
      selfdrive/test/process_replay/process_replay.py
  23. 2
      selfdrive/test/process_replay/ref_commit
  24. 3
      selfdrive/test/process_replay/test_processes.py
  25. 11
      tools/replay/lib/ui_helpers.py

@ -0,0 +1,22 @@
import numpy as np
def deep_interp_np(x, xp, fp, axis=None):
if axis is not None:
fp = fp.swapaxes(0,axis)
x = np.atleast_1d(x)
xp = np.array(xp)
if len(xp) < 2:
return np.repeat(fp, len(x), axis=0)
if min(np.diff(xp)) < 0:
raise RuntimeError('Bad x array for interpolation')
j = np.searchsorted(xp, x) - 1
j = np.clip(j, 0, len(xp)-2)
d = np.divide(x - xp[j], xp[j + 1] - xp[j], out=np.ones_like(x, dtype=np.float64), where=xp[j + 1] - xp[j] != 0)
vals_interp = (fp[j].T*(1 - d)).T + (fp[j + 1].T*d).T
if axis is not None:
vals_interp = vals_interp.swapaxes(0,axis)
if len(vals_interp) == 1:
return vals_interp[0]
else:
return vals_interp

@ -1,10 +1,18 @@
#pragma once #pragma once
const int TRAJECTORY_SIZE = 33;
const float MIN_DRAW_DISTANCE = 10.0;
const float MAX_DRAW_DISTANCE = 100.0;
constexpr int MODEL_PATH_DISTANCE = 192; const double T_IDXS[TRAJECTORY_SIZE] = {0. , 0.00976562, 0.0390625 , 0.08789062, 0.15625 ,
constexpr int TRAJECTORY_SIZE = 33; 0.24414062, 0.3515625 , 0.47851562, 0.625 , 0.79101562,
constexpr float MIN_DRAW_DISTANCE = 10.0; 0.9765625 , 1.18164062, 1.40625 , 1.65039062, 1.9140625 ,
constexpr float MAX_DRAW_DISTANCE = 100.0; 2.19726562, 2.5 , 2.82226562, 3.1640625 , 3.52539062,
constexpr int POLYFIT_DEGREE = 4; 3.90625 , 4.30664062, 4.7265625 , 5.16601562, 5.625 ,
constexpr int SPEED_PERCENTILES = 10; 6.10351562, 6.6015625 , 7.11914062, 7.65625 , 8.21289062,
constexpr int DESIRE_PRED_SIZE = 32; 8.7890625 , 9.38476562, 10.};
constexpr int OTHER_META_SIZE = 4; const double X_IDXS[TRAJECTORY_SIZE] = { 0. , 0.1875, 0.75 , 1.6875, 3. , 4.6875,
6.75 , 9.1875, 12. , 15.1875, 18.75 , 22.6875,
27. , 31.6875, 36.75 , 42.1875, 48. , 54.1875,
60.75 , 67.6875, 75. , 82.6875, 90.75 , 99.1875,
108. , 117.1875, 126.75 , 136.6875, 147. , 157.6875,
168.75 , 180.1875, 192.};

@ -7,11 +7,12 @@ V_CRUISE_MAX = 144
V_CRUISE_MIN = 8 V_CRUISE_MIN = 8
V_CRUISE_DELTA = 8 V_CRUISE_DELTA = 8
V_CRUISE_ENABLE_MIN = 40 V_CRUISE_ENABLE_MIN = 40
MPC_N = 16
CAR_ROTATION_RADIUS = 1.5
class MPC_COST_LAT: class MPC_COST_LAT:
PATH = 1.0 PATH = 1.0
LANE = 3.0
HEADING = 1.0 HEADING = 1.0
STEER_RATE = 1.0 STEER_RATE = 1.0

@ -3,103 +3,79 @@ import numpy as np
from cereal import log from cereal import log
CAMERA_OFFSET = 0.06 # m from center car to camera CAMERA_OFFSET = 0.06 # m from center car to camera
TRAJECTORY_SIZE = 33
def compute_path_pinv(length=50):
deg = 3
x = np.arange(length*1.0)
X = np.vstack(tuple(x**n for n in range(deg, -1, -1))).T
pinv = np.linalg.pinv(X)
return pinv
def model_polyfit(points, path_pinv):
return np.dot(path_pinv, [float(x) for x in points])
def eval_poly(poly, x):
return poly[3] + poly[2]*x + poly[1]*x**2 + poly[0]*x**3
class LanePlanner: class LanePlanner:
def __init__(self): def __init__(self):
self.l_poly = [0., 0., 0., 0.] self.lane_t = np.zeros((TRAJECTORY_SIZE,))
self.r_poly = [0., 0., 0., 0.] self.lll_y = np.zeros((TRAJECTORY_SIZE,))
self.p_poly = [0., 0., 0., 0.] self.rll_y = np.zeros((TRAJECTORY_SIZE,))
self.d_poly = [0., 0., 0., 0.]
self.lane_width_estimate = 3.7 self.lane_width_estimate = 3.7
self.lane_width_certainty = 1.0 self.lane_width_certainty = 1.0
self.lane_width = 3.7 self.lane_width = 3.7
self.l_prob = 0. self.lll_prob = 0.
self.r_prob = 0. self.rll_prob = 0.
self.l_std = 0. self.lll_std = 0.
self.r_std = 0. self.rll_std = 0.
self.l_lane_change_prob = 0. self.l_lane_change_prob = 0.
self.r_lane_change_prob = 0. self.r_lane_change_prob = 0.
self._path_pinv = compute_path_pinv()
self.x_points = np.arange(50)
def parse_model(self, md): def parse_model(self, md):
if len(md.leftLane.poly): if len(md.laneLines) == 4 and len(md.laneLines[0].t) == TRAJECTORY_SIZE:
self.l_poly = np.array(md.leftLane.poly) self.ll_t = (np.array(md.laneLines[1].t) + np.array(md.laneLines[2].t))/2
self.l_std = float(md.leftLane.std) # left and right ll x is the same
self.r_poly = np.array(md.rightLane.poly) self.ll_x = md.laneLines[1].x
self.r_std = float(md.rightLane.std) # only offset left and right lane lines; offsetting path does not make sense
self.p_poly = np.array(md.path.poly) self.lll_y = np.array(md.laneLines[1].y) - CAMERA_OFFSET
else: self.rll_y = np.array(md.laneLines[2].y) - CAMERA_OFFSET
self.l_poly = model_polyfit(md.leftLane.points, self._path_pinv) # left line self.lll_prob = md.laneLineProbs[1]
self.r_poly = model_polyfit(md.rightLane.points, self._path_pinv) # right line self.rll_prob = md.laneLineProbs[2]
self.p_poly = model_polyfit(md.path.points, self._path_pinv) # predicted path self.lll_std = md.laneLineStds[1]
self.l_prob = md.leftLane.prob # left line prob self.rll_std = md.laneLineStds[2]
self.r_prob = md.rightLane.prob # right line prob
if len(md.meta.desireState): if len(md.meta.desireState):
self.l_lane_change_prob = md.meta.desireState[log.PathPlan.Desire.laneChangeLeft] self.l_lane_change_prob = md.meta.desireState[log.PathPlan.Desire.laneChangeLeft]
self.r_lane_change_prob = md.meta.desireState[log.PathPlan.Desire.laneChangeRight] self.r_lane_change_prob = md.meta.desireState[log.PathPlan.Desire.laneChangeRight]
def update_d_poly(self, v_ego): def get_d_path(self, v_ego, path_t, path_xyz):
# only offset left and right lane lines; offsetting p_poly does not make sense
self.l_poly[3] += CAMERA_OFFSET
self.r_poly[3] += CAMERA_OFFSET
# Reduce reliance on lanelines that are too far apart or # Reduce reliance on lanelines that are too far apart or
# will be in a few seconds # will be in a few seconds
l_prob, r_prob = self.l_prob, self.r_prob l_prob, r_prob = self.lll_prob, self.rll_prob
width_poly = self.l_poly - self.r_poly width_pts = self.rll_y - self.lll_y
prob_mods = [] prob_mods = []
for t_check in [0.0, 1.5, 3.0]: for t_check in [0.0, 1.5, 3.0]:
width_at_t = eval_poly(width_poly, t_check * (v_ego + 7)) width_at_t = interp(t_check * (v_ego + 7), self.ll_x, width_pts)
prob_mods.append(interp(width_at_t, [4.0, 5.0], [1.0, 0.0])) prob_mods.append(interp(width_at_t, [4.0, 5.0], [1.0, 0.0]))
mod = min(prob_mods) mod = min(prob_mods)
l_prob *= mod l_prob *= mod
r_prob *= mod r_prob *= mod
# Reduce reliance on uncertain lanelines # Reduce reliance on uncertain lanelines
l_std_mod = interp(self.l_std, [.15, .3], [1.0, 0.0]) l_std_mod = interp(self.lll_std, [.15, .3], [1.0, 0.0])
r_std_mod = interp(self.r_std, [.15, .3], [1.0, 0.0]) r_std_mod = interp(self.rll_std, [.15, .3], [1.0, 0.0])
l_prob *= l_std_mod l_prob *= l_std_mod
r_prob *= r_std_mod r_prob *= r_std_mod
# Find current lanewidth # Find current lanewidth
self.lane_width_certainty += 0.05 * (l_prob * r_prob - self.lane_width_certainty) self.lane_width_certainty += 0.05 * (l_prob * r_prob - self.lane_width_certainty)
current_lane_width = abs(self.l_poly[3] - self.r_poly[3]) current_lane_width = abs(self.rll_y[0] - self.lll_y[0])
self.lane_width_estimate += 0.005 * (current_lane_width - self.lane_width_estimate) self.lane_width_estimate += 0.005 * (current_lane_width - self.lane_width_estimate)
speed_lane_width = interp(v_ego, [0., 31.], [2.8, 3.5]) speed_lane_width = interp(v_ego, [0., 31.], [2.8, 3.5])
self.lane_width = self.lane_width_certainty * self.lane_width_estimate + \ self.lane_width = self.lane_width_certainty * self.lane_width_estimate + \
(1 - self.lane_width_certainty) * speed_lane_width (1 - self.lane_width_certainty) * speed_lane_width
clipped_lane_width = min(4.0, self.lane_width) clipped_lane_width = min(4.0, self.lane_width)
path_from_left_lane = self.l_poly.copy() path_from_left_lane = self.lll_y + clipped_lane_width / 2.0
path_from_left_lane[3] -= clipped_lane_width / 2.0 path_from_right_lane = self.rll_y - clipped_lane_width / 2.0
path_from_right_lane = self.r_poly.copy()
path_from_right_lane[3] += clipped_lane_width / 2.0
lr_prob = l_prob + r_prob - l_prob * r_prob lr_prob = l_prob + r_prob - l_prob * r_prob
lane_path_y = (l_prob * path_from_left_lane + r_prob * path_from_right_lane) / (l_prob + r_prob + 0.0001)
d_poly_lane = (l_prob * path_from_left_lane + r_prob * path_from_right_lane) / (l_prob + r_prob + 0.0001) lane_path_y_interp = np.interp(path_t, self.ll_t, lane_path_y)
self.d_poly = lr_prob * d_poly_lane + (1.0 - lr_prob) * self.p_poly.copy() path_xyz[:,1] = lr_prob * lane_path_y_interp + (1.0 - lr_prob) * path_xyz[:,1]
return path_xyz

@ -1,6 +1,7 @@
Import('env', 'arch') Import('env', 'arch')
cpp_path = [ cpp_path = [
"#selfdrive",
"#phonelibs/acado/include", "#phonelibs/acado/include",
"#phonelibs/acado/include/acado", "#phonelibs/acado/include/acado",
"#phonelibs/qpoases/INCLUDE", "#phonelibs/qpoases/INCLUDE",

@ -1,10 +1,10 @@
#include <acado_code_generation.hpp> #include <acado_code_generation.hpp>
#include "common/modeldata.h"
#define PI 3.1415926536 #define PI 3.1415926536
#define deg2rad(d) (d/180.0*PI) #define deg2rad(d) (d/180.0*PI)
const int controlHorizon = 50; const int N_steps = 16;
using namespace std; using namespace std;
int main( ) int main( )
@ -20,51 +20,32 @@ int main( )
DifferentialState delta; DifferentialState delta;
OnlineData curvature_factor; OnlineData curvature_factor;
OnlineData v_ref; // m/s OnlineData v_poly_r0, v_poly_r1, v_poly_r2, v_poly_r3;
OnlineData l_poly_r0, l_poly_r1, l_poly_r2, l_poly_r3; OnlineData rotation_radius;
OnlineData r_poly_r0, r_poly_r1, r_poly_r2, r_poly_r3;
OnlineData d_poly_r0, d_poly_r1, d_poly_r2, d_poly_r3;
OnlineData l_prob, r_prob;
OnlineData lane_width;
Control t; Control t;
auto poly_v = v_poly_r0*(xx*xx*xx) + v_poly_r1*(xx*xx) + v_poly_r2*xx + v_poly_r3;
// Equations of motion // Equations of motion
f << dot(xx) == v_ref * cos(psi); f << dot(xx) == poly_v * cos(psi) - rotation_radius * sin(psi) * (poly_v * delta *curvature_factor);
f << dot(yy) == v_ref * sin(psi); f << dot(yy) == poly_v * sin(psi) + rotation_radius * cos(psi) * (poly_v * delta *curvature_factor);
f << dot(psi) == v_ref * delta * curvature_factor; f << dot(psi) == poly_v * delta * curvature_factor;
f << dot(delta) == t; f << dot(delta) == t;
auto lr_prob = l_prob + r_prob - l_prob * r_prob;
auto poly_l = l_poly_r0*(xx*xx*xx) + l_poly_r1*(xx*xx) + l_poly_r2*xx + l_poly_r3;
auto poly_r = r_poly_r0*(xx*xx*xx) + r_poly_r1*(xx*xx) + r_poly_r2*xx + r_poly_r3;
auto poly_d = d_poly_r0*(xx*xx*xx) + d_poly_r1*(xx*xx) + d_poly_r2*xx + d_poly_r3;
auto angle_d = atan(3*d_poly_r0*xx*xx + 2*d_poly_r1*xx + d_poly_r2);
// When the lane is not visible, use an estimate of its position
auto weighted_left_lane = l_prob * poly_l + (1 - l_prob) * (poly_d + lane_width/2.0);
auto weighted_right_lane = r_prob * poly_r + (1 - r_prob) * (poly_d - lane_width/2.0);
auto c_left_lane = exp(-(weighted_left_lane - yy));
auto c_right_lane = exp(weighted_right_lane - yy);
// Running cost // Running cost
Function h; Function h;
// Distance errors // Distance errors
h << poly_d - yy; h << yy;
h << lr_prob * c_left_lane;
h << lr_prob * c_right_lane;
// Heading error // Heading error
h << (v_ref + 1.0 ) * (angle_d - psi); h << (v_poly_r3 + 1.0 ) * psi;
// Angular rate error // Angular rate error
h << (v_ref + 1.0 ) * t; h << (v_poly_r3 + 1.0 ) * t;
BMatrix Q(5,5); Q.setAll(true); BMatrix Q(3,3); Q.setAll(true);
// Q(0,0) = 1.0; // Q(0,0) = 1.0;
// Q(1,1) = 1.0; // Q(1,1) = 1.0;
// Q(2,2) = 1.0; // Q(2,2) = 1.0;
@ -75,34 +56,21 @@ int main( )
Function hN; Function hN;
// Distance errors // Distance errors
hN << poly_d - yy; hN << yy;
hN << l_prob * c_left_lane;
hN << r_prob * c_right_lane;
// Heading errors // Heading errors
hN << (2.0 * v_ref + 1.0 ) * (angle_d - psi); hN << (2.0 * v_poly_r3 + 1.0 ) * psi;
BMatrix QN(4,4); QN.setAll(true); BMatrix QN(2,2); QN.setAll(true);
// QN(0,0) = 1.0; // QN(0,0) = 1.0;
// QN(1,1) = 1.0; // QN(1,1) = 1.0;
// QN(2,2) = 1.0; // QN(2,2) = 1.0;
// QN(3,3) = 1.0; // QN(3,3) = 1.0;
// Non uniform time grid double T_IDXS_ARR[N_steps + 1];
// First 5 timesteps are 0.05, after that it's 0.15 memcpy(T_IDXS_ARR, T_IDXS, (N_steps + 1) * sizeof(double));
DMatrix numSteps(20, 1); Grid times(N_steps + 1, T_IDXS_ARR);
for (int i = 0; i < 5; i++){ OCP ocp(times);
numSteps(i) = 1;
}
for (int i = 5; i < 20; i++){
numSteps(i) = 3;
}
// Setup Optimal Control Problem
const double tStart = 0.0;
const double tEnd = 2.5;
OCP ocp( tStart, tEnd, numSteps);
ocp.subjectTo(f); ocp.subjectTo(f);
ocp.minimizeLSQ(Q, h); ocp.minimizeLSQ(Q, h);
@ -112,14 +80,14 @@ int main( )
ocp.subjectTo( deg2rad(-90) <= psi <= deg2rad(90)); ocp.subjectTo( deg2rad(-90) <= psi <= deg2rad(90));
// more than absolute max steer angle // more than absolute max steer angle
ocp.subjectTo( deg2rad(-50) <= delta <= deg2rad(50)); ocp.subjectTo( deg2rad(-50) <= delta <= deg2rad(50));
ocp.setNOD(17); ocp.setNOD(6);
OCPexport mpc(ocp); OCPexport mpc(ocp);
mpc.set( HESSIAN_APPROXIMATION, GAUSS_NEWTON ); mpc.set( HESSIAN_APPROXIMATION, GAUSS_NEWTON );
mpc.set( DISCRETIZATION_TYPE, MULTIPLE_SHOOTING ); mpc.set( DISCRETIZATION_TYPE, MULTIPLE_SHOOTING );
mpc.set( INTEGRATOR_TYPE, INT_RK4 ); mpc.set( INTEGRATOR_TYPE, INT_RK4 );
mpc.set( NUM_INTEGRATOR_STEPS, 1 * controlHorizon); mpc.set( NUM_INTEGRATOR_STEPS, 2500);
mpc.set( MAX_NUM_QP_ITERATIONS, 500); mpc.set( MAX_NUM_QP_ITERATIONS, 1000);
mpc.set( CG_USE_VARIABLE_WEIGHTING_MATRIX, YES); mpc.set( CG_USE_VARIABLE_WEIGHTING_MATRIX, YES);
mpc.set( SPARSE_QP_SOLUTION, CONDENSING ); mpc.set( SPARSE_QP_SOLUTION, CONDENSING );

@ -1,6 +1,6 @@
#include "acado_common.h" #include "acado_common.h"
#include "acado_auxiliary_functions.h" #include "acado_auxiliary_functions.h"
#include "common/modeldata.h"
#include <stdio.h> #include <stdio.h>
#define NX ACADO_NX /* Number of differential state variables. */ #define NX ACADO_NX /* Number of differential state variables. */
@ -20,7 +20,6 @@ typedef struct {
double x, y, psi, delta, t; double x, y, psi, delta, t;
} state_t; } state_t;
typedef struct { typedef struct {
double x[N+1]; double x[N+1];
double y[N+1]; double y[N+1];
@ -30,35 +29,28 @@ typedef struct {
double cost; double cost;
} log_t; } log_t;
void init_weights(double pathCost, double laneCost, double headingCost, double steerRateCost){ void init_weights(double pathCost, double headingCost, double steerRateCost){
int i; int i;
const int STEP_MULTIPLIER = 3; const int STEP_MULTIPLIER = 3.0;
for (i = 0; i < N; i++) { for (i = 0; i < N; i++) {
int f = 1; double f = 20 * (T_IDXS[i+1] - T_IDXS[i]);
if (i > 4){
f = STEP_MULTIPLIER;
}
// Setup diagonal entries // Setup diagonal entries
acadoVariables.W[NY*NY*i + (NY+1)*0] = pathCost * f; acadoVariables.W[NY*NY*i + (NY+1)*0] = pathCost * f;
acadoVariables.W[NY*NY*i + (NY+1)*1] = laneCost * f; acadoVariables.W[NY*NY*i + (NY+1)*1] = headingCost * f;
acadoVariables.W[NY*NY*i + (NY+1)*2] = laneCost * f; acadoVariables.W[NY*NY*i + (NY+1)*2] = steerRateCost * f;
acadoVariables.W[NY*NY*i + (NY+1)*3] = headingCost * f;
acadoVariables.W[NY*NY*i + (NY+1)*4] = steerRateCost * f;
} }
acadoVariables.WN[(NYN+1)*0] = pathCost * STEP_MULTIPLIER; acadoVariables.WN[(NYN+1)*0] = pathCost * STEP_MULTIPLIER;
acadoVariables.WN[(NYN+1)*1] = laneCost * STEP_MULTIPLIER; acadoVariables.WN[(NYN+1)*1] = headingCost * STEP_MULTIPLIER;
acadoVariables.WN[(NYN+1)*2] = laneCost * STEP_MULTIPLIER;
acadoVariables.WN[(NYN+1)*3] = headingCost * STEP_MULTIPLIER;
} }
void init(double pathCost, double laneCost, double headingCost, double steerRateCost){ void init(double pathCost, double headingCost, double steerRateCost){
acado_initializeSolver(); acado_initializeSolver();
int i; int i;
/* Initialize the states and controls. */ /* Initialize the states and controls. */
for (i = 0; i < NX * (N + 1); ++i) acadoVariables.x[ i ] = 0.0; for (i = 0; i < NX * (N + 1); ++i) acadoVariables.x[ i ] = 0.0;
for (i = 0; i < NU * N; ++i) acadoVariables.u[ i ] = 0.1; for (i = 0; i < NU * N; ++i) acadoVariables.u[ i ] = 0.0;
/* Initialize the measurements/reference. */ /* Initialize the measurements/reference. */
for (i = 0; i < NY * N; ++i) acadoVariables.y[ i ] = 0.0; for (i = 0; i < NY * N; ++i) acadoVariables.y[ i ] = 0.0;
@ -67,40 +59,32 @@ void init(double pathCost, double laneCost, double headingCost, double steerRate
/* MPC: initialize the current state feedback. */ /* MPC: initialize the current state feedback. */
for (i = 0; i < NX; ++i) acadoVariables.x0[ i ] = 0.0; for (i = 0; i < NX; ++i) acadoVariables.x0[ i ] = 0.0;
init_weights(pathCost, laneCost, headingCost, steerRateCost); init_weights(pathCost, headingCost, steerRateCost);
} }
int run_mpc(state_t * x0, log_t * solution, int run_mpc(state_t * x0, log_t * solution, double v_poly[4],
double l_poly[4], double r_poly[4], double d_poly[4], double curvature_factor, double rotation_radius, double target_y[N+1], double target_psi[N+1]){
double l_prob, double r_prob, double curvature_factor, double v_ref, double lane_width){
int i; int i;
for (i = 0; i <= NOD * N; i+= NOD){ for (i = 0; i <= NOD * N; i+= NOD){
acadoVariables.od[i] = curvature_factor; acadoVariables.od[i] = curvature_factor;
acadoVariables.od[i+1] = v_ref;
acadoVariables.od[i+2] = l_poly[0];
acadoVariables.od[i+3] = l_poly[1];
acadoVariables.od[i+4] = l_poly[2];
acadoVariables.od[i+5] = l_poly[3];
acadoVariables.od[i+6] = r_poly[0]; acadoVariables.od[i+1] = v_poly[0];
acadoVariables.od[i+7] = r_poly[1]; acadoVariables.od[i+2] = v_poly[1];
acadoVariables.od[i+8] = r_poly[2]; acadoVariables.od[i+3] = v_poly[2];
acadoVariables.od[i+9] = r_poly[3]; acadoVariables.od[i+4] = v_poly[3];
acadoVariables.od[i+10] = d_poly[0]; acadoVariables.od[i+5] = rotation_radius;
acadoVariables.od[i+11] = d_poly[1];
acadoVariables.od[i+12] = d_poly[2];
acadoVariables.od[i+13] = d_poly[3];
acadoVariables.od[i+14] = l_prob;
acadoVariables.od[i+15] = r_prob;
acadoVariables.od[i+16] = lane_width;
} }
for (i = 0; i < N; i+= 1){
acadoVariables.y[NY*i + 0] = target_y[i];
acadoVariables.y[NY*i + 1] = (v_poly[3] + 1.0) * target_psi[i];
acadoVariables.y[NY*i + 2] = 0.0;
}
acadoVariables.yN[0] = target_y[N];
acadoVariables.yN[1] = (2.0 * v_poly[3] + 1.0) * target_psi[N];
acadoVariables.x0[0] = x0->x; acadoVariables.x0[0] = x0->x;
acadoVariables.x0[1] = x0->y; acadoVariables.x0[1] = x0->y;

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:b175a66de26ad7bd788086a2d6a7ef6243eb2a0aac1ddcff39b00554a8960d97 oid sha256:e15604230fe8c48c3448ec978b3b5a0c80b21cade787931acce50602190fca7b
size 8823 size 8755

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:5848ec6e7975d6fee93187e0f41d6cba57cc0ebee6edf63ebddf3c7ad6f8f52c oid sha256:2bd358ab623df9fbf4182ff955f04505df4abd83c2a0afd21a66d71f34aeda08
size 18622 size 25742

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:77977740e5e95a7a0e86ec4cc903a09fa528934d1221f7100499176006b6b8fd oid sha256:415810c92f48f825f81fb1c9fee16ed2997edf66ad51859e31ebcb9c1c034d7e
size 1948 size 1948

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:a5f24abe53c09556bfd27179c9255ce4674d88c38e6554d10e99b60ddd10d0c5 oid sha256:030e60a7796b3730a96d7157800ecc2d2390b8dbe2ebcd81a849b490cce3942a
size 1821 size 1822

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:a2c030dd09379475b0247609d8a02f161f3e468e85480740d4abcf9c80868de0 oid sha256:ee16cb2f641439c28e352ac0fe967a5cea95e7807074e40523d2e1f259fe84f5
size 390405 size 245177

@ -11,21 +11,22 @@ ffi.cdef("""
typedef struct { typedef struct {
double x, y, psi, delta, t; double x, y, psi, delta, t;
} state_t; } state_t;
int N = 16;
typedef struct { typedef struct {
double x[21]; double x[N+1];
double y[21]; double y[N+1];
double psi[21]; double psi[N+1];
double delta[21]; double delta[N+1];
double rate[20]; double rate[N];
double cost; double cost;
} log_t; } log_t;
void init(double pathCost, double laneCost, double headingCost, double steerRateCost); void init(double pathCost, double headingCost, double steerRateCost);
void init_weights(double pathCost, double laneCost, double headingCost, double steerRateCost); void init_weights(double pathCost, double headingCost, double steerRateCost);
int run_mpc(state_t * x0, log_t * solution, int run_mpc(state_t * x0, log_t * solution,
double l_poly[4], double r_poly[4], double d_poly[4], double v_poly[4], double curvature_factor, double rotation_radius,
double l_prob, double r_prob, double curvature_factor, double v_ref, double lane_width); double target_y[N+1], double target_psi[N+1]);
""") """)
libmpc = ffi.dlopen(libmpc_fn) libmpc = ffi.dlopen(libmpc_fn)

@ -1,10 +1,12 @@
import os import os
import math import math
import numpy as np
from common.realtime import sec_since_boot, DT_MDL from common.realtime import sec_since_boot, DT_MDL
from common.numpy_fast import interp
from selfdrive.swaglog import cloudlog from selfdrive.swaglog import cloudlog
from selfdrive.controls.lib.lateral_mpc import libmpc_py from selfdrive.controls.lib.lateral_mpc import libmpc_py
from selfdrive.controls.lib.drive_helpers import MPC_COST_LAT from selfdrive.controls.lib.drive_helpers import MPC_COST_LAT, MPC_N, CAR_ROTATION_RADIUS
from selfdrive.controls.lib.lane_planner import LanePlanner from selfdrive.controls.lib.lane_planner import LanePlanner, TRAJECTORY_SIZE
from selfdrive.config import Conversions as CV from selfdrive.config import Conversions as CV
from common.params import Params from common.params import Params
import cereal.messaging as messaging import cereal.messaging as messaging
@ -40,13 +42,6 @@ DESIRES = {
} }
def calc_states_after_delay(states, v_ego, steer_angle, curvature_factor, steer_ratio, delay):
states[0].x = v_ego * delay
states[0].psi = v_ego * curvature_factor * math.radians(steer_angle) / steer_ratio * delay
states[0].y = states[0].x * math.sin(states[0].psi / 2)
return states
class PathPlanner(): class PathPlanner():
def __init__(self, CP): def __init__(self, CP):
self.LP = LanePlanner() self.LP = LanePlanner()
@ -63,9 +58,13 @@ class PathPlanner():
self.lane_change_ll_prob = 1.0 self.lane_change_ll_prob = 1.0
self.prev_one_blinker = False self.prev_one_blinker = False
self.path_xyz = np.zeros((TRAJECTORY_SIZE,3))
self.plan_yaw = np.zeros((TRAJECTORY_SIZE,))
self.t_idxs = np.zeros((TRAJECTORY_SIZE,))
def setup_mpc(self): def setup_mpc(self):
self.libmpc = libmpc_py.libmpc self.libmpc = libmpc_py.libmpc
self.libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.LANE, MPC_COST_LAT.HEADING, self.steer_rate_cost) self.libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.HEADING, self.steer_rate_cost)
self.mpc_solution = libmpc_py.ffi.new("log_t *") self.mpc_solution = libmpc_py.ffi.new("log_t *")
self.cur_state = libmpc_py.ffi.new("state_t *") self.cur_state = libmpc_py.ffi.new("state_t *")
@ -96,7 +95,12 @@ class PathPlanner():
curvature_factor = VM.curvature_factor(v_ego) curvature_factor = VM.curvature_factor(v_ego)
self.LP.parse_model(sm['model']) md = sm['modelV2']
self.LP.parse_model(sm['modelV2'])
if len(md.position.x) == TRAJECTORY_SIZE and len(md.orientation.x) == TRAJECTORY_SIZE:
self.path_xyz = np.column_stack([md.position.x, md.position.y, md.position.z])
self.t_idxs = list(md.position.t)
self.plan_yaw = list(md.orientation.z)
# Lane change logic # Lane change logic
one_blinker = sm['carState'].leftBlinker != sm['carState'].rightBlinker one_blinker = sm['carState'].leftBlinker != sm['carState'].rightBlinker
@ -161,35 +165,52 @@ class PathPlanner():
# Turn off lanes during lane change # Turn off lanes during lane change
if desire == log.PathPlan.Desire.laneChangeRight or desire == log.PathPlan.Desire.laneChangeLeft: if desire == log.PathPlan.Desire.laneChangeRight or desire == log.PathPlan.Desire.laneChangeLeft:
self.LP.l_prob *= self.lane_change_ll_prob self.LP.lll_prob *= self.lane_change_ll_prob
self.LP.r_prob *= self.lane_change_ll_prob self.LP.rll_prob *= self.lane_change_ll_prob
self.LP.update_d_poly(v_ego) d_path_xyz = self.LP.get_d_path(v_ego, self.t_idxs, self.path_xyz)
y_pts = np.interp(self.t_idxs[:MPC_N+1], np.linalg.norm(d_path_xyz, axis=1)/v_ego, d_path_xyz[:,1])
# account for actuation delay heading_pts = np.interp(self.t_idxs[:MPC_N+1], np.linalg.norm(self.path_xyz, axis=1)/v_ego, self.plan_yaw)
self.cur_state = calc_states_after_delay(self.cur_state, v_ego, angle_steers - angle_offset, curvature_factor, VM.sR, CP.steerActuatorDelay)
# init state
self.cur_state.x = 0.0
self.cur_state.y = 0.0
self.cur_state.psi = 0.0
# TODO negative sign, still run mpc in ENU, make NED
self.cur_state.delta = -math.radians(angle_steers - angle_offset) / VM.sR
v_ego_mpc = max(v_ego, 5.0) # avoid mpc roughness due to low speed v_ego_mpc = max(v_ego, 5.0) # avoid mpc roughness due to low speed
v_poly = [0.0, 0.0, 0.0, v_ego_mpc]
assert len(v_poly) == 4
assert len(y_pts) == MPC_N + 1
assert len(heading_pts) == MPC_N + 1
self.libmpc.run_mpc(self.cur_state, self.mpc_solution, self.libmpc.run_mpc(self.cur_state, self.mpc_solution,
list(self.LP.l_poly), list(self.LP.r_poly), list(self.LP.d_poly), v_poly,
self.LP.l_prob, self.LP.r_prob, curvature_factor, v_ego_mpc, self.LP.lane_width) curvature_factor,
CAR_ROTATION_RADIUS,
list(y_pts),
list(heading_pts))
# TODO this needs more thought, use .2s extra for now to estimate other delays
delay = CP.steerActuatorDelay + .2
# TODO negative sign, still run mpc in ENU, make NED
next_delta = -interp(DT_MDL + delay, self.t_idxs[:MPC_N+1], self.mpc_solution.delta)
next_rate = -interp(delay, self.t_idxs[:MPC_N], self.mpc_solution.rate)
# reset to current steer angle if not active or overriding # reset to current steer angle if not active or overriding
if active: if active:
delta_desired = self.mpc_solution[0].delta[1] delta_desired = next_delta
rate_desired = math.degrees(self.mpc_solution[0].rate[0] * VM.sR) rate_desired = math.degrees(next_rate * VM.sR)
else: else:
delta_desired = math.radians(angle_steers - angle_offset) / VM.sR delta_desired = math.radians(angle_steers - angle_offset) / VM.sR
rate_desired = 0.0 rate_desired = 0.0
self.cur_state[0].delta = delta_desired
self.angle_steers_des_mpc = float(math.degrees(delta_desired * VM.sR) + angle_offset) self.angle_steers_des_mpc = float(math.degrees(delta_desired * VM.sR) + angle_offset)
# Check for infeasable MPC solution # Check for infeasable MPC solution
mpc_nans = any(math.isnan(x) for x in self.mpc_solution[0].delta) mpc_nans = any(math.isnan(x) for x in self.mpc_solution.delta)
t = sec_since_boot() t = sec_since_boot()
if mpc_nans: if mpc_nans:
self.libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.LANE, MPC_COST_LAT.HEADING, CP.steerRateCost) self.libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.HEADING, CP.steerRateCost)
self.cur_state[0].delta = math.radians(angle_steers - angle_offset) / VM.sR self.cur_state[0].delta = math.radians(angle_steers - angle_offset) / VM.sR
if t > self.last_cloudlog_t + 5.0: if t > self.last_cloudlog_t + 5.0:
@ -201,15 +222,14 @@ class PathPlanner():
else: else:
self.solution_invalid_cnt = 0 self.solution_invalid_cnt = 0
plan_solution_valid = self.solution_invalid_cnt < 2 plan_solution_valid = self.solution_invalid_cnt < 2
plan_send = messaging.new_message('pathPlan') plan_send = messaging.new_message('pathPlan')
plan_send.valid = sm.all_alive_and_valid(service_list=['carState', 'controlsState', 'liveParameters', 'model']) plan_send.valid = sm.all_alive_and_valid(service_list=['carState', 'controlsState', 'liveParameters', 'modelV2'])
plan_send.pathPlan.laneWidth = float(self.LP.lane_width) plan_send.pathPlan.laneWidth = float(self.LP.lane_width)
plan_send.pathPlan.dPoly = [float(x) for x in self.LP.d_poly] plan_send.pathPlan.dPoly = [0,0,0,0]
plan_send.pathPlan.lPoly = [float(x) for x in self.LP.l_poly] plan_send.pathPlan.lPoly = [0,0,0,0]
plan_send.pathPlan.lProb = float(self.LP.l_prob) plan_send.pathPlan.rPoly = [0,0,0,0]
plan_send.pathPlan.rPoly = [float(x) for x in self.LP.r_poly] plan_send.pathPlan.lProb = float(self.LP.lll_prob)
plan_send.pathPlan.rProb = float(self.LP.r_prob) plan_send.pathPlan.rProb = float(self.LP.rll_prob)
plan_send.pathPlan.angleSteers = float(self.angle_steers_des_mpc) plan_send.pathPlan.angleSteers = float(self.angle_steers_des_mpc)
plan_send.pathPlan.rateSteers = float(rate_desired) plan_send.pathPlan.rateSteers = float(rate_desired)

@ -186,7 +186,7 @@ class Planner():
plan_send.valid = sm.all_alive_and_valid(service_list=['carState', 'controlsState', 'radarState']) plan_send.valid = sm.all_alive_and_valid(service_list=['carState', 'controlsState', 'radarState'])
plan_send.plan.mdMonoTime = sm.logMonoTime['model'] plan_send.plan.mdMonoTime = sm.logMonoTime['modelV2']
plan_send.plan.radarStateMonoTime = sm.logMonoTime['radarState'] plan_send.plan.radarStateMonoTime = sm.logMonoTime['radarState']
# longitudal plan # longitudal plan

@ -23,8 +23,8 @@ def plannerd_thread(sm=None, pm=None):
VM = VehicleModel(CP) VM = VehicleModel(CP)
if sm is None: if sm is None:
sm = messaging.SubMaster(['carState', 'controlsState', 'radarState', 'model', 'liveParameters'], sm = messaging.SubMaster(['carState', 'controlsState', 'radarState', 'modelV2', 'liveParameters'],
poll=['radarState', 'model']) poll=['radarState', 'modelV2'])
if pm is None: if pm is None:
pm = messaging.PubMaster(['plan', 'liveLongitudinalMpc', 'pathPlan', 'liveMpc']) pm = messaging.PubMaster(['plan', 'liveLongitudinalMpc', 'pathPlan', 'liveMpc'])
@ -37,7 +37,7 @@ def plannerd_thread(sm=None, pm=None):
while True: while True:
sm.update() sm.update()
if sm.updated['model']: if sm.updated['modelV2']:
PP.update(sm, pm, CP, VM) PP.update(sm, pm, CP, VM)
if sm.updated['radarState']: if sm.updated['radarState']:
PL.update(sm, pm, CP, VM, PP) PL.update(sm, pm, CP, VM, PP)

@ -3,48 +3,36 @@ import numpy as np
from selfdrive.car.honda.interface import CarInterface from selfdrive.car.honda.interface import CarInterface
from selfdrive.controls.lib.lateral_mpc import libmpc_py from selfdrive.controls.lib.lateral_mpc import libmpc_py
from selfdrive.controls.lib.vehicle_model import VehicleModel from selfdrive.controls.lib.vehicle_model import VehicleModel
from selfdrive.controls.lib.drive_helpers import MPC_N, CAR_ROTATION_RADIUS
def run_mpc(v_ref=30., x_init=0., y_init=0., psi_init=0., delta_init=0., def run_mpc(v_ref=30., x_init=0., y_init=0., psi_init=0., delta_init=0.,
l_prob=1., r_prob=1., p_prob=1.,
poly_l=np.array([0., 0., 0., 1.8]), poly_r=np.array([0., 0., 0., -1.8]), poly_p=np.array([0., 0., 0., 0.]),
lane_width=3.6, poly_shift=0.): lane_width=3.6, poly_shift=0.):
libmpc = libmpc_py.libmpc libmpc = libmpc_py.libmpc
libmpc.init(1.0, 3.0, 1.0, 1.0) libmpc.init(1.0, 1.0, 1.0)
mpc_solution = libmpc_py.ffi.new("log_t *") mpc_solution = libmpc_py.ffi.new("log_t *")
p_l = poly_l.copy() y_pts = poly_shift * np.ones(MPC_N + 1)
p_l[3] += poly_shift heading_pts = np.zeros(MPC_N + 1)
p_r = poly_r.copy()
p_r[3] += poly_shift
p_p = poly_p.copy()
p_p[3] += poly_shift
d_poly = p_p
CP = CarInterface.get_params("HONDA CIVIC 2016 TOURING") CP = CarInterface.get_params("HONDA CIVIC 2016 TOURING")
VM = VehicleModel(CP) VM = VehicleModel(CP)
curvature_factor = VM.curvature_factor(v_ref) curvature_factor = VM.curvature_factor(v_ref)
l_poly = libmpc_py.ffi.new("double[4]", list(map(float, p_l)))
r_poly = libmpc_py.ffi.new("double[4]", list(map(float, p_r)))
d_poly = libmpc_py.ffi.new("double[4]", list(map(float, d_poly)))
cur_state = libmpc_py.ffi.new("state_t *") cur_state = libmpc_py.ffi.new("state_t *")
cur_state[0].x = x_init cur_state.x = x_init
cur_state[0].y = y_init cur_state.y = y_init
cur_state[0].psi = psi_init cur_state.psi = psi_init
cur_state[0].delta = delta_init cur_state.delta = delta_init
# converge in no more than 20 iterations # converge in no more than 20 iterations
for _ in range(20): for _ in range(20):
libmpc.run_mpc(cur_state, mpc_solution, l_poly, r_poly, d_poly, l_prob, r_prob, libmpc.run_mpc(cur_state, mpc_solution, [0,0,0,v_ref],
curvature_factor, v_ref, lane_width) curvature_factor, CAR_ROTATION_RADIUS,
list(y_pts), list(heading_pts))
return mpc_solution return mpc_solution
@ -100,13 +88,6 @@ class TestLateralMpc(unittest.TestCase):
sol.append(run_mpc(psi_init=psi_init)) sol.append(run_mpc(psi_init=psi_init))
self._assert_simmetry(sol) self._assert_simmetry(sol)
def test_prob_symmetry(self):
sol = []
lane_width = 3.
for r_prob in [0., 1.]:
sol.append(run_mpc(r_prob=r_prob, l_prob=1.-r_prob, lane_width=lane_width))
self._assert_simmetry(sol)
def test_y_shift_vs_poly_shift(self): def test_y_shift_vs_poly_shift(self):
shift = 1. shift = 1.
sol = [] sol = []

@ -2,11 +2,11 @@
# type: ignore # type: ignore
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from selfdrive.controls.lib.lateral_mpc import libmpc_py from selfdrive.controls.lib.lateral_mpc import libmpc_py
from selfdrive.controls.lib.drive_helpers import MPC_COST_LAT from selfdrive.controls.lib.drive_helpers import MPC_COST_LAT, MPC_N, CAR_ROTATION_RADIUS
import math import math
libmpc = libmpc_py.libmpc libmpc = libmpc_py.libmpc
libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.LANE, MPC_COST_LAT.HEADING, 1.) libmpc.init(MPC_COST_LAT.PATH, MPC_COST_LAT.HEADING, 1.)
cur_state = libmpc_py.ffi.new("state_t *") cur_state = libmpc_py.ffi.new("state_t *")
cur_state[0].x = 0.0 cur_state[0].x = 0.0
@ -24,30 +24,15 @@ times = []
curvature_factor = 0.3 curvature_factor = 0.3
v_ref = 1.0 * 20.12 # 45 mph v_ref = 1.0 * 20.12 # 45 mph
LANE_WIDTH = 3.7
p = [0.0, 0.0, 0.0, 0.0]
p_l = p[:]
p_l[3] += LANE_WIDTH / 2.0
p_r = p[:]
p_r[3] -= LANE_WIDTH / 2.0
l_poly = libmpc_py.ffi.new("double[4]", p_l)
r_poly = libmpc_py.ffi.new("double[4]", p_r)
p_poly = libmpc_py.ffi.new("double[4]", p)
l_prob = 1.0
r_prob = 1.0
p_prob = 1.0
for i in range(1): for i in range(1):
cur_state[0].delta = math.radians(510. / 13.) cur_state[0].delta = math.radians(510. / 13.)
libmpc.run_mpc(cur_state, mpc_solution, l_poly, r_poly, p_poly, l_prob, r_prob, libmpc.run_mpc(cur_state, mpc_solution, [0,0,0,v_ref],
curvature_factor, v_ref, LANE_WIDTH) curvature_factor, CAR_ROTATION_RADIUS,
[0.0]*MPC_N, [0.0]*MPC_N)
timesi = [] timesi = []
ct = 0 ct = 0
for i in range(21): for i in range(MPC_N + 1):
timesi.append(ct) timesi.append(ct)
if i <= 4: if i <= 4:
ct += 0.05 ct += 0.05

@ -10,6 +10,11 @@
#include "driving.h" #include "driving.h"
#include "clutil.h" #include "clutil.h"
constexpr int MODEL_PATH_DISTANCE = 192;
constexpr int POLYFIT_DEGREE = 4;
constexpr int DESIRE_PRED_SIZE = 32;
constexpr int OTHER_META_SIZE = 4;
constexpr int MODEL_WIDTH = 512; constexpr int MODEL_WIDTH = 512;
constexpr int MODEL_HEIGHT = 256; constexpr int MODEL_HEIGHT = 256;
constexpr int MODEL_FRAME_SIZE = MODEL_WIDTH * MODEL_HEIGHT * 3 / 2; constexpr int MODEL_FRAME_SIZE = MODEL_WIDTH * MODEL_HEIGHT * 3 / 2;
@ -28,8 +33,6 @@ constexpr int LEAD_MHP_GROUP_SIZE = (2*LEAD_MHP_VALS + LEAD_MHP_SELECTION);
constexpr int POSE_SIZE = 12; constexpr int POSE_SIZE = 12;
constexpr int MIN_VALID_LEN = 10; constexpr int MIN_VALID_LEN = 10;
constexpr int TRAJECTORY_TIME = 10;
constexpr float TRAJECTORY_DISTANCE = 192.0;
constexpr int PLAN_IDX = 0; constexpr int PLAN_IDX = 0;
constexpr int LL_IDX = PLAN_IDX + PLAN_MHP_N*PLAN_MHP_GROUP_SIZE; constexpr int LL_IDX = PLAN_IDX + PLAN_MHP_N*PLAN_MHP_GROUP_SIZE;
constexpr int LL_PROB_IDX = LL_IDX + 4*2*2*33; constexpr int LL_PROB_IDX = LL_IDX + 4*2*2*33;
@ -49,8 +52,6 @@ constexpr int OUTPUT_SIZE = POSE_IDX + POSE_SIZE;
// #define DUMP_YUV // #define DUMP_YUV
Eigen::Matrix<float, MODEL_PATH_DISTANCE, POLYFIT_DEGREE - 1> vander; Eigen::Matrix<float, MODEL_PATH_DISTANCE, POLYFIT_DEGREE - 1> vander;
float X_IDXS[TRAJECTORY_SIZE];
float T_IDXS[TRAJECTORY_SIZE];
void model_init(ModelState* s, cl_device_id device_id, cl_context context) { void model_init(ModelState* s, cl_device_id device_id, cl_context context) {
frame_init(&s->frame, MODEL_WIDTH, MODEL_HEIGHT, device_id, context); frame_init(&s->frame, MODEL_WIDTH, MODEL_HEIGHT, device_id, context);
@ -77,8 +78,6 @@ void model_init(ModelState* s, cl_device_id device_id, cl_context context) {
// Build Vandermonde matrix // Build Vandermonde matrix
for(int i = 0; i < TRAJECTORY_SIZE; i++) { for(int i = 0; i < TRAJECTORY_SIZE; i++) {
for(int j = 0; j < POLYFIT_DEGREE - 1; j++) { for(int j = 0; j < POLYFIT_DEGREE - 1; j++) {
X_IDXS[i] = (TRAJECTORY_DISTANCE/1024.0) * (pow(i,2));
T_IDXS[i] = (TRAJECTORY_TIME/1024.0) * (pow(i,2));
vander(i, j) = pow(X_IDXS[i], POLYFIT_DEGREE-j-1); vander(i, j) = pow(X_IDXS[i], POLYFIT_DEGREE-j-1);
} }
} }

@ -110,10 +110,12 @@ class Plant():
self.rate = rate self.rate = rate
if not Plant.messaging_initialized: if not Plant.messaging_initialized:
Plant.pm = messaging.PubMaster(['frame', 'frontFrame', 'ubloxRaw'])
Plant.pm = messaging.PubMaster(['frame', 'frontFrame', 'ubloxRaw', 'modelV2'])
Plant.logcan = messaging.pub_sock('can') Plant.logcan = messaging.pub_sock('can')
Plant.sendcan = messaging.sub_sock('sendcan') Plant.sendcan = messaging.sub_sock('sendcan')
Plant.model = messaging.pub_sock('model') Plant.model = messaging.pub_sock('model')
Plant.front_frame = messaging.pub_sock('frontFrame')
Plant.live_params = messaging.pub_sock('liveParameters') Plant.live_params = messaging.pub_sock('liveParameters')
Plant.live_location_kalman = messaging.pub_sock('liveLocationKalman') Plant.live_location_kalman = messaging.pub_sock('liveLocationKalman')
Plant.health = messaging.pub_sock('health') Plant.health = messaging.pub_sock('health')

@ -1 +1 @@
852c79998828975cce184114537b0067b80bc608 4d71a89ccbfd351cbe58fcf217ee2eefa48eee2d

@ -243,7 +243,7 @@ CONFIGS = [
ProcessConfig( ProcessConfig(
proc_name="plannerd", proc_name="plannerd",
pub_sub={ pub_sub={
"model": ["pathPlan"], "radarState": ["plan"], "modelV2": ["pathPlan"], "radarState": ["plan"],
"carState": [], "controlsState": [], "liveParameters": [], "carState": [], "controlsState": [], "liveParameters": [],
}, },
ignore=["logMonoTime", "valid", "plan.processingDelay"], ignore=["logMonoTime", "valid", "plan.processingDelay"],

@ -1 +1 @@
3964f847c722e6e6a4b3876bbe9e9c8a354fb7f8 859c964a01f994fb5873d5383af725af3263b4fd

@ -160,6 +160,9 @@ if __name__ == "__main__":
if (procs_whitelisted and cfg.proc_name not in args.whitelist_procs) or \ if (procs_whitelisted and cfg.proc_name not in args.whitelist_procs) or \
(not procs_whitelisted and cfg.proc_name in args.blacklist_procs): (not procs_whitelisted and cfg.proc_name in args.blacklist_procs):
continue continue
# TODO remove this hack
if cfg.proc_name == 'plannerd' and car_brand in ["GM", "SUBARU", "VOLKSWAGEN", "NISSAN"]:
continue
cmp_log_fn = os.path.join(process_replay_dir, "%s_%s_%s.bz2" % (segment, cfg.proc_name, ref_commit)) cmp_log_fn = os.path.join(process_replay_dir, "%s_%s_%s.bz2" % (segment, cfg.proc_name, ref_commit))
results[segment][cfg.proc_name] = test_process(cfg, lr, cmp_log_fn, args.ignore_fields, args.ignore_msgs) results[segment][cfg.proc_name] = test_process(cfg, lr, cmp_log_fn, args.ignore_fields, args.ignore_msgs)

@ -10,8 +10,6 @@ from common.transformations.camera import (eon_f_frame_size, eon_f_focal_length,
tici_f_frame_size, tici_f_focal_length) tici_f_frame_size, tici_f_focal_length)
from selfdrive.config import RADAR_TO_CAMERA from selfdrive.config import RADAR_TO_CAMERA
from selfdrive.config import UIParams as UP from selfdrive.config import UIParams as UP
from selfdrive.controls.lib.lane_planner import (compute_path_pinv,
model_polyfit)
from tools.lib.lazy_property import lazy_property from tools.lib.lazy_property import lazy_property
RED = (255, 0, 0) RED = (255, 0, 0)
@ -23,8 +21,6 @@ WHITE = (255, 255, 255)
_PATH_X = np.arange(192.) _PATH_X = np.arange(192.)
_PATH_XD = np.arange(192.) _PATH_XD = np.arange(192.)
_PATH_PINV = compute_path_pinv(50)
_FULL_FRAME_SIZE = { _FULL_FRAME_SIZE = {
} }
@ -247,14 +243,11 @@ def draw_var(y, x, var, color, img, calibration, top_down):
class ModelPoly(object): class ModelPoly(object):
def __init__(self, model_path): def __init__(self, model_path):
if len(model_path.points) == 0 and len(model_path.poly) == 0: if len(model_path.poly) == 0:
self.valid = False self.valid = False
return return
if len(model_path.poly): self.poly = np.array(model_path.poly)
self.poly = np.array(model_path.poly)
else:
self.poly = model_polyfit(model_path.points, _PATH_PINV)
self.prob = model_path.prob self.prob = model_path.prob
self.std = model_path.std self.std = model_path.std

Loading…
Cancel
Save