use base class in car_kf

pull/1518/head
Willem Melching 5 years ago
parent 53c0214a65
commit 896bd1b5c7
  1. 2
      rednose_repo
  2. 63
      selfdrive/locationd/models/car_kf.py

@ -1 +1 @@
Subproject commit a6c02b647b288f4f7e50326996b38f7e21dc483c Subproject commit e50846d845f7fe542e2538476258b926b2643398

@ -1,16 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys import sys
import math import math
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from selfdrive.locationd.models.constants import ObservationKind from rednose import KalmanFilter
from rednose.helpers.ekf_sym import EKF_sym, gen_code from rednose.helpers.ekf_sym import EKF_sym, gen_code
from selfdrive.locationd.models.constants import ObservationKind
i = 0 i = 0
def _slice(n): def _slice(n):
global i global i
s = slice(i, i + n) s = slice(i, i + n)
@ -31,10 +31,10 @@ class States():
STEER_ANGLE = _slice(1) # [rad] STEER_ANGLE = _slice(1) # [rad]
class CarKalman(): class CarKalman(KalmanFilter):
name = 'car' name = 'car'
x_initial = np.array([ initial_x = np.array([
1.0, 1.0,
15.0, 15.0,
0.0, 0.0,
@ -66,7 +66,6 @@ class CarKalman():
ObservationKind.ROAD_FRAME_X_SPEED: np.atleast_2d(0.1**2), ObservationKind.ROAD_FRAME_X_SPEED: np.atleast_2d(0.1**2),
} }
maha_test_kinds = [] # [ObservationKind.ROAD_FRAME_YAW_RATE, ObservationKind.ROAD_FRAME_XY_SPEED]
global_vars = [ global_vars = [
sp.Symbol('mass'), sp.Symbol('mass'),
sp.Symbol('rotational_inertia'), sp.Symbol('rotational_inertia'),
@ -78,9 +77,8 @@ class CarKalman():
@staticmethod @staticmethod
def generate_code(generated_dir): def generate_code(generated_dir):
dim_state = CarKalman.x_initial.shape[0] dim_state = CarKalman.initial_x.shape[0]
name = CarKalman.name name = CarKalman.name
maha_test_kinds = CarKalman.maha_test_kinds
# globals # globals
m, j, aF, aR, cF_orig, cR_orig = CarKalman.global_vars m, j, aF, aR, cF_orig, cR_orig = CarKalman.global_vars
@ -137,57 +135,18 @@ class CarKalman():
[sp.Matrix([x]), ObservationKind.STIFFNESS, None], [sp.Matrix([x]), ObservationKind.STIFFNESS, None],
] ]
gen_code(generated_dir, name, f_sym, dt, state_sym, obs_eqs, dim_state, dim_state, maha_test_kinds=maha_test_kinds, global_vars=CarKalman.global_vars) gen_code(generated_dir, name, f_sym, dt, state_sym, obs_eqs, dim_state, dim_state, global_vars=CarKalman.global_vars)
def __init__(self, generated_dir, steer_ratio=15, stiffness_factor=1, angle_offset=0): def __init__(self, generated_dir, steer_ratio=15, stiffness_factor=1, angle_offset=0):
self.dim_state = self.x_initial.shape[0] dim_state = self.initial_x.shape[0]
x_init = self.x_initial dim_state_err = self.initial_P_diag.shape[0]
x_init = self.initial_x
x_init[States.STEER_RATIO] = steer_ratio x_init[States.STEER_RATIO] = steer_ratio
x_init[States.STIFFNESS] = stiffness_factor x_init[States.STIFFNESS] = stiffness_factor
x_init[States.ANGLE_OFFSET] = angle_offset x_init[States.ANGLE_OFFSET] = angle_offset
# init filter # init filter
self.filter = EKF_sym(generated_dir, self.name, self.Q, self.x_initial, self.P_initial, self.dim_state, self.dim_state, maha_test_kinds=self.maha_test_kinds, global_vars=self.global_vars) self.filter = EKF_sym(generated_dir, self.name, self.Q, self.initial_x, self.P_initial, dim_state, dim_state_err, global_vars=self.global_vars)
@property
def x(self):
return self.filter.state()
@property
def P(self):
return self.filter.covs()
def predict(self, t):
return self.filter.predict(t)
def rts_smooth(self, estimates):
return self.filter.rts_smooth(estimates, norm_quats=False)
def get_R(self, kind, n):
obs_noise = self.obs_noise[kind]
dim = obs_noise.shape[0]
R = np.zeros((n, dim, dim))
for i in range(n):
R[i, :, :] = obs_noise
return R
def init_state(self, state, covs_diag=None, covs=None, filter_time=None):
if covs_diag is not None:
P = np.diag(covs_diag)
elif covs is not None:
P = covs
else:
P = self.filter.covs()
self.filter.init_state(state, P, filter_time)
def predict_and_observe(self, t, kind, data, R=None):
if len(data) > 0:
data = np.atleast_2d(data)
if R is None:
R = self.get_R(kind, len(data))
self.filter.predict_and_update_batch(t, kind, data, R)
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save