|
|
|
@ -106,7 +106,7 @@ class ModelState: |
|
|
|
|
|
|
|
|
|
# policy inputs |
|
|
|
|
self.numpy_inputs = { |
|
|
|
|
'desire': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32), |
|
|
|
|
'desire_pulse': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32), |
|
|
|
|
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32), |
|
|
|
|
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32), |
|
|
|
|
} |
|
|
|
@ -131,13 +131,13 @@ class ModelState: |
|
|
|
|
def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray], |
|
|
|
|
inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None: |
|
|
|
|
# Model decides when action is completed, so desire input is just a pulse triggered on rising edge |
|
|
|
|
inputs['desire'][0] = 0 |
|
|
|
|
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0) |
|
|
|
|
self.prev_desire[:] = inputs['desire'] |
|
|
|
|
inputs['desire_pulse'][0] = 0 |
|
|
|
|
new_desire = np.where(inputs['desire_pulse'] - self.prev_desire > .99, inputs['desire_pulse'], 0) |
|
|
|
|
self.prev_desire[:] = inputs['desire_pulse'] |
|
|
|
|
|
|
|
|
|
self.full_desire[0,:-1] = self.full_desire[0,1:] |
|
|
|
|
self.full_desire[0,-1] = new_desire |
|
|
|
|
self.numpy_inputs['desire'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2) |
|
|
|
|
self.numpy_inputs['desire_pulse'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2) |
|
|
|
|
|
|
|
|
|
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] |
|
|
|
|
imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.vision_input_names} |
|
|
|
@ -313,7 +313,7 @@ def main(demo=False): |
|
|
|
|
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names} |
|
|
|
|
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names} |
|
|
|
|
inputs:dict[str, np.ndarray] = { |
|
|
|
|
'desire': vec_desire, |
|
|
|
|
'desire_pulse': vec_desire, |
|
|
|
|
'traffic_convention': traffic_convention, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|