modeld: desire->desire_pulse (#36076)

consistent naming
pull/36072/head
ZwX1616 4 days ago committed by GitHub
parent 375dfe16a8
commit f8ff156869
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 12
      selfdrive/modeld/modeld.py
  2. 4
      selfdrive/modeld/models/driving_policy.onnx

@ -106,7 +106,7 @@ class ModelState:
# policy inputs
self.numpy_inputs = {
'desire': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
'desire_pulse': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32),
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
}
@ -131,13 +131,13 @@ class ModelState:
def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray],
inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None:
# Model decides when action is completed, so desire input is just a pulse triggered on rising edge
inputs['desire'][0] = 0
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0)
self.prev_desire[:] = inputs['desire']
inputs['desire_pulse'][0] = 0
new_desire = np.where(inputs['desire_pulse'] - self.prev_desire > .99, inputs['desire_pulse'], 0)
self.prev_desire[:] = inputs['desire_pulse']
self.full_desire[0,:-1] = self.full_desire[0,1:]
self.full_desire[0,-1] = new_desire
self.numpy_inputs['desire'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
self.numpy_inputs['desire_pulse'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.vision_input_names}
@ -313,7 +313,7 @@ def main(demo=False):
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names}
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names}
inputs:dict[str, np.ndarray] = {
'desire': vec_desire,
'desire_pulse': vec_desire,
'traffic_convention': traffic_convention,
}

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04b763fb71efe57a8a4c4168a8043ecd58939015026ded0dc755ded6905ac251
size 12343523
oid sha256:72e98a95541f200bd2faeae8d718997483696fd4801fc7d718c167b05854707d
size 12343535

Loading…
Cancel
Save