@ -56,16 +56,24 @@ class ModelState:
prev_desire : np . ndarray # for tracking the rising edge of the pulse
prev_desire : np . ndarray # for tracking the rising edge of the pulse
def __init__ ( self , context : CLContext ) :
def __init__ ( self , context : CLContext ) :
self . frames = { ' input_imgs ' : DrivingModelFrame ( context ) , ' big_input_imgs ' : DrivingModelFrame ( context ) }
self . frames = {
' input_imgs ' : DrivingModelFrame ( context , ModelConstants . TEMPORAL_SKIP ) ,
' big_input_imgs ' : DrivingModelFrame ( context , ModelConstants . TEMPORAL_SKIP )
}
self . prev_desire = np . zeros ( ModelConstants . DESIRE_LEN , dtype = np . float32 )
self . prev_desire = np . zeros ( ModelConstants . DESIRE_LEN , dtype = np . float32 )
self . full_features_buffer = np . zeros ( ( 1 , ModelConstants . FULL_HISTORY_BUFFER_LEN , ModelConstants . FEATURE_LEN ) , dtype = np . float32 )
self . full_desire = np . zeros ( ( 1 , ModelConstants . FULL_HISTORY_BUFFER_LEN , ModelConstants . DESIRE_LEN ) , dtype = np . float32 )
self . full_prev_desired_curv = np . zeros ( ( 1 , ModelConstants . FULL_HISTORY_BUFFER_LEN , ModelConstants . PREV_DESIRED_CURV_LEN ) , dtype = np . float32 )
self . temporal_idxs = slice ( - 1 - ( ModelConstants . TEMPORAL_SKIP * ( ModelConstants . INPUT_HISTORY_BUFFER_LEN - 1 ) ) , None , ModelConstants . TEMPORAL_SKIP )
# policy inputs
# policy inputs
self . numpy_inputs = {
self . numpy_inputs = {
' desire ' : np . zeros ( ( 1 , ModelConstants . FULL_HISTORY_BUFFER_LEN , ModelConstants . DESIRE_LEN ) , dtype = np . float32 ) ,
' desire ' : np . zeros ( ( 1 , ModelConstants . INPUT _HISTORY_BUFFER_LEN, ModelConstants . DESIRE_LEN ) , dtype = np . float32 ) ,
' traffic_convention ' : np . zeros ( ( 1 , ModelConstants . TRAFFIC_CONVENTION_LEN ) , dtype = np . float32 ) ,
' traffic_convention ' : np . zeros ( ( 1 , ModelConstants . TRAFFIC_CONVENTION_LEN ) , dtype = np . float32 ) ,
' lateral_control_params ' : np . zeros ( ( 1 , ModelConstants . LATERAL_CONTROL_PARAMS_LEN ) , dtype = np . float32 ) ,
' lateral_control_params ' : np . zeros ( ( 1 , ModelConstants . LATERAL_CONTROL_PARAMS_LEN ) , dtype = np . float32 ) ,
' prev_desired_curv ' : np . zeros ( ( 1 , ModelConstants . FULL _HISTORY_BUFFER_LEN, ModelConstants . PREV_DESIRED_CURV_LEN ) , dtype = np . float32 ) ,
' prev_desired_curv ' : np . zeros ( ( 1 , ModelConstants . INPUT _HISTORY_BUFFER_LEN, ModelConstants . PREV_DESIRED_CURV_LEN ) , dtype = np . float32 ) ,
' features_buffer ' : np . zeros ( ( 1 , ModelConstants . FULL _HISTORY_BUFFER_LEN, ModelConstants . FEATURE_LEN ) , dtype = np . float32 ) ,
' features_buffer ' : np . zeros ( ( 1 , ModelConstants . INPUT _HISTORY_BUFFER_LEN, ModelConstants . FEATURE_LEN ) , dtype = np . float32 ) ,
}
}
with open ( VISION_METADATA_PATH , ' rb ' ) as f :
with open ( VISION_METADATA_PATH , ' rb ' ) as f :
@ -104,8 +112,9 @@ class ModelState:
new_desire = np . where ( inputs [ ' desire ' ] - self . prev_desire > .99 , inputs [ ' desire ' ] , 0 )
new_desire = np . where ( inputs [ ' desire ' ] - self . prev_desire > .99 , inputs [ ' desire ' ] , 0 )
self . prev_desire [ : ] = inputs [ ' desire ' ]
self . prev_desire [ : ] = inputs [ ' desire ' ]
self . numpy_inputs [ ' desire ' ] [ 0 , : - 1 ] = self . numpy_inputs [ ' desire ' ] [ 0 , 1 : ]
self . full_desire [ 0 , : - 1 ] = self . full_desire [ 0 , 1 : ]
self . numpy_inputs [ ' desire ' ] [ 0 , - 1 ] = new_desire
self . full_desire [ 0 , - 1 ] = new_desire
self . numpy_inputs [ ' desire ' ] [ : ] = self . full_desire . reshape ( ( 1 , ModelConstants . INPUT_HISTORY_BUFFER_LEN , ModelConstants . TEMPORAL_SKIP , - 1 ) ) . max ( axis = 2 )
self . numpy_inputs [ ' traffic_convention ' ] [ : ] = inputs [ ' traffic_convention ' ]
self . numpy_inputs [ ' traffic_convention ' ] [ : ] = inputs [ ' traffic_convention ' ]
self . numpy_inputs [ ' lateral_control_params ' ] [ : ] = inputs [ ' lateral_control_params ' ]
self . numpy_inputs [ ' lateral_control_params ' ] [ : ] = inputs [ ' lateral_control_params ' ]
@ -128,15 +137,17 @@ class ModelState:
self . vision_output = self . vision_run ( * * self . vision_inputs ) . numpy ( ) . flatten ( )
self . vision_output = self . vision_run ( * * self . vision_inputs ) . numpy ( ) . flatten ( )
vision_outputs_dict = self . parser . parse_vision_outputs ( self . slice_outputs ( self . vision_output , self . vision_output_slices ) )
vision_outputs_dict = self . parser . parse_vision_outputs ( self . slice_outputs ( self . vision_output , self . vision_output_slices ) )
self . numpy_inputs [ ' features_buffer ' ] [ 0 , : - 1 ] = self . numpy_inputs [ ' features_buffer ' ] [ 0 , 1 : ]
self . full_features_buffer [ 0 , : - 1 ] = self . full_features_buffer [ 0 , 1 : ]
self . numpy_inputs [ ' features_buffer ' ] [ 0 , - 1 ] = vision_outputs_dict [ ' hidden_state ' ] [ 0 , : ]
self . full_features_buffer [ 0 , - 1 ] = vision_outputs_dict [ ' hidden_state ' ] [ 0 , : ]
self . numpy_inputs [ ' features_buffer ' ] [ : ] = self . full_features_buffer [ 0 , self . temporal_idxs ]
self . policy_output = self . policy_run ( * * self . policy_inputs ) . numpy ( ) . flatten ( )
self . policy_output = self . policy_run ( * * self . policy_inputs ) . numpy ( ) . flatten ( )
policy_outputs_dict = self . parser . parse_policy_outputs ( self . slice_outputs ( self . policy_output , self . policy_output_slices ) )
policy_outputs_dict = self . parser . parse_policy_outputs ( self . slice_outputs ( self . policy_output , self . policy_output_slices ) )
# TODO model only uses last value now
# TODO model only uses last value now
self . numpy_inputs [ ' prev_desired_curv ' ] [ 0 , : - 1 ] = self . numpy_inputs [ ' prev_desired_curv ' ] [ 0 , 1 : ]
self . full_prev_desired_curv [ 0 , : - 1 ] = self . full_prev_desired_curv [ 0 , 1 : ]
self . numpy_inputs [ ' prev_desired_curv ' ] [ 0 , - 1 , : ] = policy_outputs_dict [ ' desired_curvature ' ] [ 0 , : ]
self . full_prev_desired_curv [ 0 , - 1 , : ] = policy_outputs_dict [ ' desired_curvature ' ] [ 0 , : ]
self . numpy_inputs [ ' prev_desired_curv ' ] [ : ] = self . full_prev_desired_curv [ 0 , self . temporal_idxs ]
combined_outputs_dict = { * * vision_outputs_dict , * * policy_outputs_dict }
combined_outputs_dict = { * * vision_outputs_dict , * * policy_outputs_dict }
if SEND_RAW_PRED :
if SEND_RAW_PRED :