|
|
@ -56,7 +56,10 @@ class ModelState: |
|
|
|
prev_desire: np.ndarray # for tracking the rising edge of the pulse |
|
|
|
prev_desire: np.ndarray # for tracking the rising edge of the pulse |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, context: CLContext): |
|
|
|
def __init__(self, context: CLContext): |
|
|
|
self.frames = {'input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP), 'big_input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP)} |
|
|
|
self.frames = { |
|
|
|
|
|
|
|
'input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP), |
|
|
|
|
|
|
|
'big_input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) |
|
|
|
|
|
|
|
} |
|
|
|
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32) |
|
|
|
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32) |
|
|
|
|
|
|
|
|
|
|
|
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) |
|
|
|
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) |
|
|
|