diff --git a/selfdrive/modeld/models/driving.cc b/selfdrive/modeld/models/driving.cc index b7d1c1ced9..259aafff4e 100644 --- a/selfdrive/modeld/models/driving.cc +++ b/selfdrive/modeld/models/driving.cc @@ -167,11 +167,11 @@ void fill_path(cereal::ModelData::PathData::Builder path, const float * data, bo path.setStd(std); } -void fill_lead(cereal::ModelData::LeadData::Builder lead, const float * data, int mdn_max_idx) { +void fill_lead(cereal::ModelData::LeadData::Builder lead, const float * data, int mdn_max_idx, int t_offset) { const double x_scale = 10.0; const double y_scale = 10.0; - lead.setProb(sigmoid(data[LEAD_MDN_N*MDN_GROUP_SIZE])); + lead.setProb(sigmoid(data[LEAD_MDN_N*MDN_GROUP_SIZE + t_offset])); lead.setDist(x_scale * data[mdn_max_idx*MDN_GROUP_SIZE]); lead.setStd(x_scale * softplus(data[mdn_max_idx*MDN_GROUP_SIZE + MDN_VALS])); lead.setRelY(y_scale * data[mdn_max_idx*MDN_GROUP_SIZE + 1]); @@ -228,22 +228,24 @@ void model_publish(PubSocket *sock, uint32_t frame_id, // Find the distribution that corresponds to the current lead int mdn_max_idx = 0; + int t_offset = 0; for (int i=1; i net_outputs.lead[mdn_max_idx*MDN_GROUP_SIZE + 8]) { + if (net_outputs.lead[i*MDN_GROUP_SIZE + 8 + t_offset] > net_outputs.lead[mdn_max_idx*MDN_GROUP_SIZE + 8 + t_offset]) { mdn_max_idx = i; } } auto lead = framed.initLead(); - fill_lead(lead, net_outputs.lead, mdn_max_idx); + fill_lead(lead, net_outputs.lead, mdn_max_idx, t_offset); // Find the distribution that corresponds to the lead in 2s mdn_max_idx = 0; + t_offset = 1; for (int i=1; i net_outputs.lead[mdn_max_idx*MDN_GROUP_SIZE + 9]) { + if (net_outputs.lead[i*MDN_GROUP_SIZE + 8 + t_offset] > net_outputs.lead[mdn_max_idx*MDN_GROUP_SIZE + 8 + t_offset]) { mdn_max_idx = i; } } auto lead_future = framed.initLeadFuture(); - fill_lead(lead_future, net_outputs.lead, mdn_max_idx); + fill_lead(lead_future, net_outputs.lead, mdn_max_idx, t_offset); auto meta = framed.initMeta();