openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
 
 
 
 
 
 

137 lines
6.3 KiB

import onnx, yaml, tempfile, time, collections, pprint, argparse, json
from pathlib import Path
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx import get_onnx_ops
from extra.onnx_helpers import validate, get_example_inputs
def get_config(root_path: Path):
    """Merge every ``*config.json`` file found under ``root_path`` into one dict.

    Files are discovered recursively with ``Path.rglob``. Only files whose
    top-level JSON value is an object are merged; anything else (for example a
    JSON list) is skipped. If the same key occurs in more than one file, the
    file visited later overwrites the earlier value; the visit order is
    whatever ``rglob`` yields.

    Args:
        root_path: Directory to search.

    Returns:
        A dict holding the union of all matching config objects; empty when no
        matching file exists.
    """
    ret = {}
    for path in root_path.rglob("*config.json"):
        # Open through a context manager so each handle is closed right away
        # instead of lingering until garbage collection.
        with path.open() as f:
            config = json.load(f)
        if isinstance(config, dict):
            ret.update(config)
    return ret
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
    """Build example inputs for one ONNX model and hand it to ``validate``.

    The model is loaded with ``onnx.load`` and wrapped in an ``OnnxRunner``
    only to read its ``graph_inputs``; those, together with ``config``, are
    passed to ``get_example_inputs``. ``validate`` then receives the model
    path (not the loaded model), the generated inputs and both tolerances.

    Args:
        onnx_model_path: Path of the ONNX file to check.
        config: Merged model configuration forwarded to ``get_example_inputs``.
        rtol: Relative tolerance forwarded to ``validate``.
        atol: Absolute tolerance forwarded to ``validate``.
    """
    runner = OnnxRunner(onnx.load(onnx_model_path))
    example_inputs = get_example_inputs(runner.graph_inputs, config)
    validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
def get_tolerances(file_name):  # -> rtol, atol
    """Choose ``(rtol, atol)`` from precision markers in ``file_name``.

    The checks are substring matches applied in order: ``"fp16"`` wins first,
    then any of the quantization markers, and everything else falls back to
    the default pair.

    Args:
        file_name: Name of the ONNX file (any string is accepted).

    Returns:
        A ``(rtol, atol)`` tuple.
    """
    # TODO very high rtol atol
    quantized_markers = ("int8", "uint8", "quantized")
    if "fp16" in file_name:
        return 9e-2, 9e-2
    for marker in quantized_markers:
        if marker in file_name:
            return 4, 4
    return 4e-3, 3e-2
def validate_repos(models: dict[str, tuple[Path, Path]]):
    """Validate every model in ``models`` and print how long each one took.

    For each entry the ONNX file is ``root_path / relative_path``; the config
    is gathered from ``root_path`` with ``get_config`` and the tolerances are
    picked from the file stem with ``get_tolerances``. Exceptions raised by
    ``run_huggingface_validate`` are not caught, so the first failing model
    stops the run.

    Args:
        models: Maps a model id to ``(root_path, relative_path)``.
    """
    # Count the argument itself. The previous code read a ``model_paths``
    # global that only exists when the file is executed as a script.
    print(f"** Validating {len(models)} models **")
    for model_id, (root_path, relative_path) in models.items():
        print(f"validating model {model_id}")
        model_path = root_path / relative_path
        onnx_file_name = model_path.stem
        config = get_config(root_path)
        rtol, atol = get_tolerances(onnx_file_name)
        st = time.time()
        run_huggingface_validate(model_path, config, rtol, atol)
        et = time.time() - st
        print(f"passed, took {et:.2f}s")
def retrieve_op_stats(models: dict[str, tuple[Path, Path]]) -> dict:
    """Count ONNX op usage across ``models`` and list the unsupported ops.

    Each model at ``root_path / relative_path`` is loaded into an
    ``OnnxRunner`` and its ``graph_nodes`` are walked. An op counts as
    unsupported when ``node.op`` is not in the collection returned by
    ``get_onnx_ops()``.

    Args:
        models: Maps a model id to ``(root_path, relative_path)``.

    Returns:
        A dict with two keys:
        ``"unsupported_ops"`` maps each unsupported op to the sorted list of
        model ids that contain it; ``"op_counter"`` is a list of
        ``(op, count)`` pairs ordered from most to least common.
    """
    ret = {}
    op_counter = collections.Counter()
    unsupported_ops = collections.defaultdict(set)
    supported_ops = get_onnx_ops()
    # Count the argument itself. The previous code read a ``model_paths``
    # global that only exists when the file is executed as a script.
    print(f"** Retrieving stats from {len(models)} models **")
    for model_id, (root_path, relative_path) in models.items():
        print(f"examining {model_id}")
        model_path = root_path / relative_path
        onnx_runner = OnnxRunner(onnx.load(model_path))
        for node in onnx_runner.graph_nodes:
            op_counter[node.op] += 1
            if node.op not in supported_ops:
                unsupported_ops[node.op].add(model_id)
        # Drop the reference before the next model is loaded
        # (presumably to keep peak memory down).
        del onnx_runner
    # Sets have no stable order; sort so repeated runs print identical output.
    ret["unsupported_ops"] = {k: sorted(v) for k, v in unsupported_ops.items()}
    ret["op_counter"] = op_counter.most_common()
    return ret
def debug_run(model_path, truncate, config, rtol, atol):
    """Validate a model, optionally cut off after node index ``truncate``.

    With ``truncate == -1`` the model at ``model_path`` is validated as is.
    Otherwise the graph keeps nodes ``0..truncate`` (inclusive), its outputs
    are replaced by the outputs of the last kept node, and the result is
    written to a temporary file (same suffix as ``model_path``) that is then
    validated.

    Args:
        model_path: Path object of the ONNX file; ``.suffix`` is read from it.
        truncate: Index of the last node to keep, or ``-1`` for no truncation.
        config: Model configuration forwarded to ``run_huggingface_validate``.
        rtol: Relative tolerance forwarded to ``run_huggingface_validate``.
        atol: Absolute tolerance forwarded to ``run_huggingface_validate``.
    """
    if truncate == -1:
        run_huggingface_validate(model_path, config, rtol, atol)
        return

    model = onnx.load(model_path)
    kept_nodes = list(model.graph.node)[:truncate + 1]
    # The outputs of the last kept node become the outputs of the whole graph.
    truncated_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in kept_nodes[-1].output]

    graph = model.graph
    graph.ClearField("node")
    graph.node.extend(kept_nodes)
    graph.ClearField("output")
    graph.output.extend(truncated_outputs)

    # Validation runs inside the ``with`` because the file is removed on exit.
    with tempfile.NamedTemporaryFile(suffix=model_path.suffix) as tmp:
        onnx.save(model, tmp.name)
        run_huggingface_validate(tmp.name, config, rtol, atol)
if __name__ == "__main__":
    # Command-line entry point. Three modes, selectable together:
    #   --check_ops : print op statistics for the models listed in the YAML
    #   --validate  : validate the models listed in the YAML
    #   --debug     : download a repo (or one model) from the hub and validate it
    parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
    # The positional is optional because --debug works without a YAML; the
    # YAML-based modes check for it explicitly below.
    parser.add_argument("input", type=str, nargs="?", default=None,
                        help="Path to the input YAML configuration file containing model information.")
    parser.add_argument("--check_ops", action="store_true", default=False,
                        help="Check support for ONNX operations in models from the YAML file")
    parser.add_argument("--validate", action="store_true", default=False,
                        help="Validate correctness of models from the YAML file")
    parser.add_argument("--debug", type=str, default="",
                        help="""Validates without explicitly needing a YAML or models pre-installed.
provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
""")
    parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
    args = parser.parse_args()

    if not (args.check_ops or args.validate or args.debug):
        parser.error("Please provide either --validate, --check_ops, or --debug.")
    if args.truncate != -1 and not args.debug:
        parser.error("--truncate and --debug should be used together for debugging")

    if args.check_ops or args.validate:
        if args.input is None:
            parser.error("--check_ops and --validate require the input YAML file")
        with open(args.input, 'r') as f:
            data = yaml.safe_load(f)
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if any(repo["download_path"] is None for repo in data["repositories"].values()):
            raise SystemExit("please run `download_models.py` for this yaml")
        # Kept at module level under this name for functions that read it as a global.
        model_paths = {
            model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
            for model_id, repo in data["repositories"].items()
            for model in repo["files"]
            if model["file"].endswith(".onnx")
        }
        if args.check_ops:
            pprint.pprint(retrieve_op_stats(model_paths))
        if args.validate:
            validate_repos(model_paths)

    if args.debug:
        from huggingface_hub import snapshot_download
        download_dir = Path(__file__).parent / "models"
        path: list[str] = args.debug.split("/")
        if len(path) == 2:
            # repo id
            # validates all onnx models inside repo
            repo_id = "/".join(path)
            # "*.onnx_data" needs the wildcard; a bare ".onnx_data" only matches
            # a file with exactly that name, so external-data files were skipped.
            root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
            snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
            config = get_config(root_path)
            for onnx_model in root_path.rglob("*.onnx"):
                rtol, atol = get_tolerances(onnx_model.name)
                print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
                # Pass the requested truncation through so the run matches the
                # message printed above (it used to be hard-coded to -1).
                debug_run(onnx_model, args.truncate, config, rtol, atol)
        else:
            # model id
            # only validate the specified onnx model
            # A model path needs "<owner>/<repo>/<file>.onnx" at minimum; report
            # bad input as a usage error rather than through ``assert``.
            if len(path) < 3 or not path[-1].endswith(".onnx"):
                parser.error("--debug expects a repo id (owner/name) or a path to an .onnx file inside a repo (owner/name/path/to/model.onnx)")
            onnx_model = path[-1]
            repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
            root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
            snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
            config = get_config(root_path)
            rtol, atol = get_tolerances(onnx_model)
            print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
            debug_run(root_path / relative_path, args.truncate, config, rtol, atol)