137 lines
6.3 KiB
137 lines
6.3 KiB
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
|
|
from pathlib import Path
|
|
from tinygrad.frontend.onnx import OnnxRunner
|
|
from extra.onnx import get_onnx_ops
|
|
from extra.onnx_helpers import validate, get_example_inputs
|
|
|
|
def get_config(root_path: Path) -> dict:
  """Merge every ``*config.json`` file found under ``root_path`` into one flat dict.

  The directory tree is searched recursively. Files whose top-level JSON value is not
  an object are skipped. When a key appears in more than one file, the file yielded
  later by ``rglob`` overwrites the earlier value.

  Args:
    root_path: directory to search.

  Returns:
    The merged configuration; empty when no matching file holds a JSON object.

  Raises:
    json.JSONDecodeError: if a matching file is not valid JSON.
  """
  ret = {}
  for path in root_path.rglob("*config.json"):
    # context manager closes the handle immediately instead of leaving it to the GC;
    # JSON is UTF-8 by specification, so do not depend on the platform default encoding
    with path.open(encoding="utf-8") as f:
      config = json.load(f)
    if isinstance(config, dict):
      ret.update(config)
  return ret
|
|
|
|
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
  """Build example inputs for the model at ``onnx_model_path`` and validate it.

  The model is loaded into an OnnxRunner only to read its graph inputs; those inputs and
  ``config`` are handed to ``get_example_inputs``, and the result is passed to ``validate``
  together with the model path and the ``rtol`` / ``atol`` tolerances.
  """
  runner = OnnxRunner(onnx.load(onnx_model_path))
  example_inputs = get_example_inputs(runner.graph_inputs, config)
  validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
|
|
|
|
def get_tolerances(file_name): # -> rtol, atol
  """Choose ``(rtol, atol)`` from precision markers contained in ``file_name``.

  "fp16" is checked first, then the quantization markers; any other name gets the
  default pair.
  """
  # TODO very high rtol atol
  quantized_markers = ("int8", "uint8", "quantized")
  if "fp16" in file_name:
    tolerances = (9e-2, 9e-2)
  elif any(marker in file_name for marker in quantized_markers):
    tolerances = (4, 4)
  else:
    tolerances = (4e-3, 3e-2)
  return tolerances
|
|
|
|
def validate_repos(models:dict[str, tuple[Path, Path]]):
  """Validate every model in ``models``, printing progress and per-model wall time.

  Args:
    models: maps a display id to ``(root_path, relative_path)``. The ONNX file is read from
      ``root_path / relative_path``; config files are collected from ``root_path``.

  Nothing is caught here, so an exception from validating one model stops the whole run.
  """
  # count the argument itself; the previous code read a module-level name that only
  # exists when this file is executed as a script
  print(f"** Validating {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"validating model {model_id}")
    model_path = root_path / relative_path
    onnx_file_name = model_path.stem
    config = get_config(root_path)
    # tolerances are keyed on precision markers in the file name (fp16 / int8 / ...)
    rtol, atol = get_tolerances(onnx_file_name)
    st = time.time()
    run_huggingface_validate(model_path, config, rtol, atol)
    et = time.time() - st
    print(f"passed, took {et:.2f}s")
|
|
|
|
def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
  """Count the ops used by ``models`` and report which ones are not in ``get_onnx_ops()``.

  Args:
    models: maps a display id to ``(root_path, relative_path)``; the ONNX file is read from
      ``root_path / relative_path``.

  Returns:
    A dict with two entries:
      "unsupported_ops": op -> sorted list of model ids that use an op missing from ``get_onnx_ops()``.
      "op_counter": list of ``(op, count)`` pairs over all models, most common first.
  """
  ret = {}
  op_counter = collections.Counter()
  unsupported_ops = collections.defaultdict(set)
  supported_ops = get_onnx_ops()
  # count the argument itself; the previous code read a module-level name that only
  # exists when this file is executed as a script
  print(f"** Retrieving stats from {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"examining {model_id}")
    model_path = root_path / relative_path
    onnx_runner = OnnxRunner(onnx.load(model_path))
    for node in onnx_runner.graph_nodes:
      op_counter[node.op] += 1
      if node.op not in supported_ops:
        unsupported_ops[node.op].add(model_id)
    # drop the runner before the next model is loaded
    del onnx_runner
  # sets have no stable order; sort so repeated runs print identical reports
  ret["unsupported_ops"] = {k:sorted(v) for k, v in unsupported_ops.items()}
  ret["op_counter"] = op_counter.most_common()
  return ret
|
|
|
|
def debug_run(model_path, truncate, config, rtol, atol):
  """Validate a model, optionally cut off after the node at index ``truncate``.

  Args:
    model_path: path of the ONNX file. Must provide ``.suffix`` (e.g. a ``Path``) when truncating.
    truncate: index of the last graph node to keep; -1 validates the untouched model.
    config: forwarded to ``run_huggingface_validate``.
    rtol: forwarded to ``run_huggingface_validate``.
    atol: forwarded to ``run_huggingface_validate``.

  Raises:
    ValueError: if truncation would leave the graph without any node.
  """
  if truncate == -1:
    run_huggingface_validate(model_path, config, rtol, atol)
    return

  model = onnx.load(model_path)
  nodes_up_to_limit = list(model.graph.node)[:truncate + 1]
  if not nodes_up_to_limit:
    raise ValueError(f"truncate={truncate} leaves no nodes in {model_path}")
  # the outputs of the last kept node become the outputs of the truncated graph
  new_output_values = [onnx.helper.make_empty_tensor_value_info(output_name) for output_name in nodes_up_to_limit[-1].output]
  model.graph.ClearField("node")
  model.graph.node.extend(nodes_up_to_limit)
  model.graph.ClearField("output")
  model.graph.output.extend(new_output_values)
  # write into a temporary directory rather than a still-open NamedTemporaryFile:
  # reopening such a file by name is not possible on every platform (notably Windows)
  with tempfile.TemporaryDirectory() as tmp_dir:
    truncated_path = str(Path(tmp_dir) / f"truncated{model_path.suffix}")
    onnx.save(model, truncated_path)
    run_huggingface_validate(truncated_path, config, rtol, atol)
|
|
|
|
# Command-line entry point. Three modes, selected by flags:
#   --check_ops / --validate  operate on the models listed in the YAML given as `input`
#   --debug                   downloads a repo (or a single model of it) and validates it directly
if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
  # optional on the command line because --debug does not read a YAML; required below for the other modes
  parser.add_argument("input", type=str, nargs="?", default=None,
                      help="Path to the input YAML configuration file containing model information.")
  parser.add_argument("--check_ops", action="store_true", default=False,
                      help="Check support for ONNX operations in models from the YAML file")
  parser.add_argument("--validate", action="store_true", default=False,
                      help="Validate correctness of models from the YAML file")
  parser.add_argument("--debug", type=str, default="",
                      help="""Validates without explicitly needing a YAML or models pre-installed.
                      provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
                      provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
                      """)
  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
  args = parser.parse_args()

  if not (args.check_ops or args.validate or args.debug):
    parser.error("Please provide either --validate, --check_ops, or --debug.")
  if args.truncate != -1 and not args.debug:
    parser.error("--truncate and --debug should be used together for debugging")

  if args.check_ops or args.validate:
    if args.input is None:
      parser.error("the input YAML path is required with --validate or --check_ops")
    with open(args.input, 'r') as f:
      data = yaml.safe_load(f)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if any(repo.get("download_path") is None for repo in data["repositories"].values()):
      parser.error("please run `download_models.py` for this yaml")
    # key: "<repo id>/<file>", value: (download root of the repo, file path relative to that root)
    model_paths = {
      model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
      for model_id, repo in data["repositories"].items()
      for model in repo["files"]
      if model["file"].endswith(".onnx")
    }

    if args.check_ops:
      pprint.pprint(retrieve_op_stats(model_paths))

    if args.validate:
      validate_repos(model_paths)

  if args.debug:
    parts:list[str] = args.debug.split("/")
    # "<owner>/<repo>" selects a whole repo; anything else must be "<owner>/<repo>/<path>.onnx"
    if len(parts) != 2 and (len(parts) < 3 or not parts[-1].endswith(".onnx")):
      parser.error('--debug expects a repo id ("owner/repo") or a model path ("owner/repo/path/to/model.onnx")')

    from huggingface_hub import snapshot_download
    download_dir = Path(__file__).parent / "models"
    if len(parts) == 2:
      # repo id
      # validates all onnx models inside repo
      repo_id = "/".join(parts)
      # "*.onnx_data" needs the wildcard, otherwise external weight files are never matched
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      for onnx_model in root_path.rglob("*.onnx"):
        rtol, atol = get_tolerances(onnx_model.name)
        print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
        # honour --truncate here as well so the run matches what the line above reports
        debug_run(onnx_model, args.truncate, config, rtol, atol)
    else:
      # model id
      # only validate the specified onnx model
      onnx_model = parts[-1]
      repo_id, relative_path = "/".join(parts[:2]), "/".join(parts[2:])
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      rtol, atol = get_tolerances(onnx_model)
      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)