openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

224 lines
6.0 KiB

#include "ekf_sym.h"
#include "logger/logger.h"
using namespace EKFS;
using namespace Eigen;
// Constructs a filter bound to the generated EKF returned by ekf_lookup(name).
//
//   Q                 process noise; must have the same shape as P_initial
//   x_initial         initial state; its row count becomes dim_x
//   P_initial         initial covariance; its row count becomes dim_err
//   dim_main(_err)    size of the main state (error state)
//   N                 number of augmented states; N > 0 turns on the msckf flag
//   dim_augment(_err) size of one augmented state (error state)
//   maha_test_kinds   kinds that should get mahalanobis distance tested for outlier rejection
//   quaternion_idxs   start indices of the quaternions inside the state
//   global_vars       stored as-is
//   max_rewind_age    how far behind the newest checkpoint a late observation may be
EKFSym::EKFSym(std::string name, Map<MatrixXdr> Q, Map<VectorXd> x_initial, Map<MatrixXdr> P_initial, int dim_main,
    int dim_main_err, int N, int dim_augment, int dim_augment_err, std::vector<int> maha_test_kinds,
    std::vector<int> quaternion_idxs, std::vector<std::string> global_vars, double max_rewind_age)
{
  // TODO: add logger

  // Without a generated filter nothing below can work.
  this->ekf = ekf_lookup(name);
  assert(this->ekf != nullptr);

  // Dimensions handed in by the caller.
  this->dim_main = dim_main;
  this->dim_main_err = dim_main_err;
  this->N = N;
  this->dim_augment = dim_augment;
  this->dim_augment_err = dim_augment_err;
  this->msckf = (N > 0);

  // Dimensions implied by the initial state and covariance.
  this->dim_x = x_initial.rows();
  this->dim_err = P_initial.rows();

  // The explicit dimensions must agree with the data that was passed in.
  assert(dim_main + dim_augment * N == this->dim_x);
  assert(dim_main_err + dim_augment_err * N == this->dim_err);
  assert(Q.rows() == P_initial.rows() && Q.cols() == P_initial.cols());

  // kinds that should get mahalanobis distance
  // tested for outlier rejection
  this->maha_test_kinds = maha_test_kinds;

  // quaternions need normalization
  this->quaternion_idxs = quaternion_idxs;

  this->global_vars = global_vars;

  // Process noise
  this->Q = Q;

  this->max_rewind_age = max_rewind_age;

  // The filter time stays NaN until the first prediction supplies one.
  this->init_state(x_initial, P_initial, NAN);
}
// Overwrites the state, covariance and filter time, zeroes the N augment
// timestamps and discards all rewind history.
void EKFSym::init_state(Map<VectorXd> state, Map<MatrixXdr> covs, double init_filter_time) {
  filter_time = init_filter_time;
  x = state;
  P = covs;
  augment_times = VectorXd::Zero(N);
  reset_rewind();
}
// Returns a copy of the current state vector.
VectorXd EKFSym::state() {
  return x;
}
// Returns a copy of the current covariance matrix.
MatrixXdr EKFSym::covs() {
  return P;
}
void EKFSym::set_filter_time(double t) {
this->filter_time = t;
}
double EKFSym::get_filter_time() {
return this->filter_time;
}
void EKFSym::normalize_quaternions() {
for(std::size_t i = 0; i < this->quaternion_idxs.size(); ++i) {
this->normalize_slice(this->quaternion_idxs[i], this->quaternion_idxs[i] + 4);
}
}
void EKFSym::normalize_slice(int slice_start, int slice_end_ex) {
this->x.block(slice_start, 0, slice_end_ex - slice_start, this->x.cols()).normalize();
}
void EKFSym::set_global(std::string global_var, double val) {
this->ekf->sets.at(global_var)(val);
}
std::optional<Estimate> EKFSym::predict_and_update_batch(double t, int kind, std::vector<Map<VectorXd>> z_map,
std::vector<Map<MatrixXdr>> R_map, std::vector<std::vector<double>> extra_args, bool augment)
{
// TODO handle rewinding at this level
std::deque<Observation> rewound;
if (!std::isnan(this->filter_time) && t < this->filter_time) {
if (this->rewind_t.empty() || t < this->rewind_t.front() || t < this->rewind_t.back() - this->max_rewind_age) {
LOGD("observation too old at %f with filter at %f, ignoring!", t, this->filter_time);
return std::nullopt;
}
rewound = this->rewind(t);
}
Observation obs;
obs.t = t;
obs.kind = kind;
obs.extra_args = extra_args;
for (Map<VectorXd> zi : z_map) {
obs.z.push_back(zi);
}
for (Map<MatrixXdr> Ri : R_map) {
obs.R.push_back(Ri);
}
std::optional<Estimate> res = std::make_optional(this->predict_and_update_batch(obs, augment));
// optional fast forward
while (!rewound.empty()) {
this->predict_and_update_batch(rewound.front(), false);
rewound.pop_front();
}
return res;
}
void EKFSym::reset_rewind() {
this->rewind_obscache.clear();
this->rewind_t.clear();
this->rewind_states.clear();
}
// Rolls the filter back to the newest checkpoint whose time is <= t.
//
// Checkpoints newer than t are discarded; their observations are returned in
// chronological order so the caller can replay them afterwards.
//
// Precondition: the history is non-empty and its oldest checkpoint is not
// newer than t. Otherwise the loop below would drain the deques and then call
// back() on an empty container, which is undefined behaviour.
std::deque<Observation> EKFSym::rewind(double t) {
  assert(!this->rewind_t.empty() && !(t < this->rewind_t.front()));

  std::deque<Observation> rewound;
  // rewind observations until t is after previous observation
  while (this->rewind_t.back() > t) {
    // push_front keeps the returned observations oldest-first
    rewound.push_front(this->rewind_obscache.back());
    this->rewind_t.pop_back();
    this->rewind_states.pop_back();
    this->rewind_obscache.pop_back();
  }
  // set the state to the time right before that
  this->filter_time = this->rewind_t.back();
  this->x = this->rewind_states.back().first;
  this->P = this->rewind_states.back().second;
  return rewound;
}
// Records the current filter time, state and covariance together with the
// observation that produced them, so the filter can later be rewound to here.
void EKFSym::checkpoint(Observation& obs) {
  // push to rewinder
  rewind_t.push_back(filter_time);
  rewind_states.emplace_back(x, P);
  rewind_obscache.push_back(obs);

  // only keep a certain number around: one entry is added per call, so
  // dropping the oldest one is enough to stay within REWIND_TO_KEEP
  const bool over_capacity = rewind_t.size() > REWIND_TO_KEEP;
  if (over_capacity) {
    rewind_t.pop_front();
    rewind_states.pop_front();
    rewind_obscache.pop_front();
  }
}
// Runs one predict step to obs.t followed by one update per measurement in obs,
// then checkpoints the result for later rewinding.
//
// obs.z, obs.R and obs.extra_args must all have the same length, and each
// R[i] must be square with the same dimension as z[i].
//
// The returned Estimate holds the state/covariance after prediction (xk1, Pk1),
// after all updates (xk, Pk), and one vector per measurement as returned by update().
// Augmentation is not implemented yet, so `augment` must be false.
Estimate EKFSym::predict_and_update_batch(Observation& obs, bool augment) {
  const std::size_t n_meas = obs.z.size();
  assert(n_meas == obs.R.size());
  assert(n_meas == obs.extra_args.size());

  this->predict(obs.t);

  Estimate res;
  res.t = obs.t;
  res.kind = obs.kind;
  res.z = obs.z;
  res.extra_args = obs.extra_args;
  // snapshot taken after the prediction but before any measurement is applied
  res.xk1 = this->x;
  res.Pk1 = this->P;

  // update batch
  std::vector<VectorXd> y;
  y.reserve(n_meas);
  for (std::size_t i = 0; i < n_meas; ++i) {
    assert(obs.z[i].rows() == obs.R[i].rows());
    assert(obs.z[i].rows() == obs.R[i].cols());
    // update state
    y.push_back(this->update(obs.kind, obs.z[i], obs.R[i], obs.extra_args[i]));
  }

  res.xk = this->x;
  res.Pk = this->P;
  res.y = std::move(y);

  assert(!augment); // TODO
  // if (augment) {
  //   this->augment();
  // }

  this->checkpoint(obs);
  return res;
}
void EKFSym::predict(double t) {
// initialize time
if (std::isnan(this->filter_time)) {
this->filter_time = t;
}
// predict
double dt = t - this->filter_time;
assert(dt >= 0.0);
this->ekf->predict(this->x.data(), this->P.data(), this->Q.data(), dt);
this->normalize_quaternions();
this->filter_time = t;
}
// Applies a single measurement of the given kind through the generated update
// routine, then re-normalizes the quaternions.
//
// The routine receives mutable pointers to x, P, z, R and extra_args; the
// returned vector is z as it stands after that call.
// NOTE(review): presumably the routine overwrites z with the innovation -- confirm.
// For msckf filters, kinds listed in feature_track_kinds get the last
// extra_args.size() entries of z dropped from the result.
VectorXd EKFSym::update(int kind, VectorXd z, MatrixXdr R, std::vector<double> extra_args) {
  ekf->updates.at(kind)(x.data(), P.data(), z.data(), R.data(), extra_args.data());
  normalize_quaternions();

  const auto kinds_end = feature_track_kinds.end();
  const bool is_feature_track =
      msckf && std::find(feature_track_kinds.begin(), kinds_end, kind) != kinds_end;
  if (!is_feature_track) {
    return z;
  }
  return z.head(z.rows() - extra_args.size());
}
// Looks up one of the generated filter's extra routines by name.
// NOTE(review): an unknown name is handled by extra_routines.at(); presumably
// it throws std::out_of_range -- confirm against the container type.
extra_routine_t EKFSym::get_extra_routine(const std::string& routine) {
  return ekf->extra_routines.at(routine);
}