import onnx , yaml , tempfile , time , collections , pprint , argparse , json
from pathlib import Path
from tinygrad . frontend . onnx import OnnxRunner , onnx_load
from extra . onnx import get_onnx_ops
from extra . onnx_helpers import validate , get_example_inputs
def get_config(root_path: Path):
  """Merge the top-level JSON objects of every `*config.json` file under `root_path`.

  Files are visited in sorted path order, so when several files define the same key
  the result is deterministic: the value from the later path wins. Files whose
  top-level JSON value is not an object (e.g. a list) are skipped.

  Returns a plain dict (empty if no matching file exists).
  """
  ret = {}
  for path in sorted(root_path.rglob("*config.json")):
    # read_bytes closes the file immediately; json.loads on bytes detects UTF-8/16/32
    # itself instead of relying on the platform's default text encoding
    config = json.loads(path.read_bytes())
    if isinstance(config, dict):
      ret.update(config)
  return ret
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
  """Build example inputs for the model at `onnx_model_path` and run `validate` on them.

  The model is loaded with `onnx_load` and wrapped in an `OnnxRunner` only to read its
  `graph_inputs`. Those, together with `config`, drive `get_example_inputs`. The
  resulting inputs and the `rtol`/`atol` tolerances are forwarded to `validate`.
  """
  graph_inputs = OnnxRunner(onnx_load(onnx_model_path)).graph_inputs
  example_inputs = get_example_inputs(graph_inputs, config)
  validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
def get_tolerances(file_name):  # -> rtol, atol
  """Return the `(rtol, atol)` pair to validate a model with, chosen from its file name.

  "fp16" in the name takes precedence. Any of "int8"/"uint8"/"quantized" selects the
  loose quantized tolerances. Everything else gets the default fp32 tolerances.
  """
  # TODO very high rtol atol
  quantized_markers = ("int8", "uint8", "quantized")
  if "fp16" in file_name:
    tolerances = (9e-2, 9e-2)
  elif any(marker in file_name for marker in quantized_markers):
    tolerances = (4, 4)
  else:
    tolerances = (4e-3, 3e-2)
  return tolerances
def validate_repos(models: dict[str, tuple[Path, Path]]):
  """Validate every ONNX model in `models`, printing progress and per-model wall time.

  `models` maps a display id to `(root_path, relative_path)`. The model file is
  `root_path / relative_path`, and every `*config.json` under `root_path` is merged (via
  `get_config`) into the config used to build example inputs. Tolerances come from
  `get_tolerances` applied to the model file's stem.

  There is no per-model error handling: the first exception raised by
  `run_huggingface_validate` propagates and stops the loop.
  """
  # count the argument, not the script-level `model_paths` global (NameError when imported)
  print(f"** Validating {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"validating model {model_id}")
    model_path = root_path / relative_path
    config = get_config(root_path)
    rtol, atol = get_tolerances(model_path.stem)
    # perf_counter is monotonic, unlike time.time(), so the elapsed time cannot go negative
    start = time.perf_counter()
    run_huggingface_validate(model_path, config, rtol, atol)
    elapsed = time.perf_counter() - start
    print(f"passed, took {elapsed:.2f}s")
def retrieve_op_stats(models: dict[str, tuple[Path, Path]]) -> dict:
  """Count ONNX op usage across `models` and list the ops missing from `get_onnx_ops()`.

  `models` maps a display id to `(root_path, relative_path)`. The model file is
  `root_path / relative_path`.

  Returns a dict with two keys:
    "unsupported_ops": op name -> sorted list of model ids whose graph uses that op,
                       for ops not in `get_onnx_ops()`
    "op_counter":      list of (op name, occurrence count) pairs, most common first
  """
  op_counter = collections.Counter()
  unsupported_ops = collections.defaultdict(set)
  supported_ops = get_onnx_ops()
  # count the argument, not the script-level `model_paths` global (NameError when imported)
  print(f"** Retrieving stats from {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"examining {model_id}")
    onnx_runner = OnnxRunner(onnx.load(root_path / relative_path))
    for node in onnx_runner.graph_nodes:
      op_counter[node.op] += 1
      if node.op not in supported_ops:
        unsupported_ops[node.op].add(model_id)
    # drop the reference now so the previous runner can be freed before the next model is loaded
    del onnx_runner
  return {
    # sorted: set iteration order is not stable across runs, which made reports hard to diff
    "unsupported_ops": {op: sorted(model_ids) for op, model_ids in unsupported_ops.items()},
    "op_counter": op_counter.most_common(),
  }
def debug_run(model_path, truncate, config, rtol, atol):
  """Validate `model_path`, optionally cut down to a prefix of its graph nodes.

  With `truncate == -1` the model is validated unchanged. Otherwise only
  `graph.node[:truncate + 1]` is kept, and the outputs of the last kept node replace
  the graph outputs, so intermediate values can be compared. The truncated model is
  saved into a temporary directory that is deleted once validation finishes.

  Raises:
    ValueError: if `truncate` keeps no nodes at all.
  """
  if truncate == -1:
    run_huggingface_validate(model_path, config, rtol, atol)
    return

  model = onnx.load(model_path)
  kept_nodes = list(model.graph.node)[:truncate + 1]
  if not kept_nodes:
    raise ValueError(f"truncate={truncate} keeps no nodes of {model_path} ({len(model.graph.node)} nodes total)")
  # the new graph outputs are name-only value infos built from the last kept node's outputs
  new_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in kept_nodes[-1].output]
  model.graph.ClearField("node")
  model.graph.node.extend(kept_nodes)
  model.graph.ClearField("output")
  model.graph.output.extend(new_outputs)

  # a fresh path inside a TemporaryDirectory can be written and reopened on every platform;
  # reopening an still-open NamedTemporaryFile by name fails on Windows
  with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_path = Path(tmp_dir) / Path(model_path).name
    onnx.save(model, str(tmp_path))
    run_huggingface_validate(str(tmp_path), config, rtol, atol)
if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
  # optional so that --debug works without a YAML file, as its help text promises
  parser.add_argument("input", type=str, nargs="?", default=None,
                      help="Path to the input YAML configuration file containing model information.")
  parser.add_argument("--check_ops", action="store_true", default=False,
                      help="Check support for ONNX operations in models from the YAML file")
  parser.add_argument("--validate", action="store_true", default=False,
                      help="Validate correctness of models from the YAML file")
  parser.add_argument("--debug", type=str, default="",
                      help="""Validates without explicitly needing a YAML or models pre-installed.
                      provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
                      provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
                      """)
  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
  args = parser.parse_args()

  # argument validation uses parser.error (exit code 2 + usage) rather than assert, which -O strips
  if not (args.check_ops or args.validate or args.debug):
    parser.error("Please provide either --validate, --check_ops, or --debug.")
  if args.truncate != -1 and not args.debug:
    parser.error("--truncate and --debug should be used together for debugging")
  if (args.check_ops or args.validate) and args.input is None:
    parser.error("an input YAML file is required for --check_ops and --validate")

  if args.check_ops or args.validate:
    with open(args.input, 'r', encoding="utf-8") as f:
      data = yaml.safe_load(f)
    repositories = data["repositories"]
    if any(repo.get("download_path") is None for repo in repositories.values()):
      parser.error("please run `download_models.py` for this yaml")
    # model id -> (local download root, model file path relative to that root)
    model_paths = {
      model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
      for model_id, repo in repositories.items()
      for model in repo["files"]
      if model["file"].endswith(".onnx")
    }
    if args.check_ops:
      pprint.pprint(retrieve_op_stats(model_paths))
    if args.validate:
      validate_repos(model_paths)

  if args.debug:
    from huggingface_hub import snapshot_download
    download_dir = Path(__file__).parent / "models"
    path: list[str] = args.debug.split("/")
    if len(path) == 2:
      # repo id: validate every onnx model inside the repo
      repo_id = "/".join(path)
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      # truncation is only applied to a single model; report the value actually used
      if args.truncate != -1:
        print(f"--truncate={args.truncate} is ignored for a whole repo; pass an onnx model path to truncate")
      for onnx_model in sorted(root_path.rglob("*.onnx")):
        rtol, atol = get_tolerances(onnx_model.name)
        print(f"validating {onnx_model.relative_to(root_path)} with truncate=-1, {rtol=}, {atol=}")
        debug_run(onnx_model, -1, config, rtol, atol)
    else:
      # model path (org/name/.../model.onnx): only validate the specified onnx model
      if len(path) < 3 or not path[-1].endswith(".onnx"):
        parser.error(f"--debug expects a repo id (org/name) or an onnx file inside a repo (org/name/.../model.onnx), got {args.debug!r}")
      onnx_model = path[-1]
      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      rtol, atol = get_tolerances(onnx_model)
      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)