import yaml
import time
import requests
import argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi, snapshot_download
from tinygrad.helpers import _ensure_downloads_dir, tqdm

DOWNLOADS_DIR = _ensure_downloads_dir() / "models"

def snapshot_download_with_retry(*, repo_id: str, allow_patterns: list[str]|tuple[str, ...]|None=None, cache_dir: str|Path|None=None,
                                 tries: int=2, **kwargs) -> Path:
  # thin wrapper around huggingface_hub.snapshot_download with a simple retry loop for transient network errors
  for attempt in range(tries):
    try:
      return Path(snapshot_download(
        repo_id=repo_id,
        allow_patterns=allow_patterns,
        cache_dir=str(cache_dir) if cache_dir is not None else None,
        **kwargs
      ))
    except Exception:
      # re-raise on the last attempt, otherwise back off briefly and retry
      if attempt == tries-1: raise
      time.sleep(1)
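
# Example usage (illustrative repo id, not one of the repos handled below): fetch only the ONNX graph
# files of a repository into the local models cache, retrying once on a transient network error:
#   path = snapshot_download_with_retry(repo_id="onnx-community/mobilenetv2-7", allow_patterns=["*.onnx"],
#                                       cache_dir=DOWNLOADS_DIR, tries=2)
#   print(path)  # local snapshot directory containing the matched files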

# Constants for filtering models
HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
  "fp16", "int8", "uint8", "quantized", # numerical accuracy issues
  "avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
  "q4", "q4f16", "bnb4", # unimplemented quantization
  "model_O4", # requires non-CPU ort runner and MemcpyFromHost op
  "merged", # TODO: implement attribute with graph type and Loop op
]
SKIPPED_REPO_PATHS = [
  # Invalid model-index
  "AdamCodd/vit-base-nsfw-detector",
  # TODO: implement attribute with graph type and Loop op
  "minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
  "HuggingFaceTB/SmolLM2-360M-Instruct",
  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
  "HuggingFaceTB/SmolLM2-1.7B-Instruct",
  # TODO: implement RandomNormalLike
  "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", "SimianLuo/LCM_Dreamshaper_v7",
  # TODO: huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
  "mangoapps/fb_zeroshot_mnli_onnx",
  # TODO: implement NonZero
  "briaai/RMBG-2.0",
]
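
# How the filters above are applied (hypothetical filenames): "model_quantized.onnx" and
# "model_q4f16.onnx" would be skipped because they contain a SKIPPED_FILES substring, while
# "model.onnx" and "model.onnx_data" pass through; repos listed in SKIPPED_REPO_PATHS are
# dropped during discovery, before any file-level filtering happens.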

class HuggingFaceONNXManager:
  def __init__(self):
    self.base_dir = Path(__file__).parent
    self.models_dir = DOWNLOADS_DIR
    self.api = HfApi()
  def discover_models(self, limit: int, sort: str = "downloads") -> list[str]:
    print(f"Discovering top {limit} ONNX models sorted by {sort}...")
    repos = []
    i = 0
    for model in list_models(filter="onnx", sort=sort):
      if model.id in SKIPPED_REPO_PATHS:
        continue
      print(f" {i+1}/{limit}: {model.id} ({getattr(model, sort)})")
      repos.append(model.id)
      i += 1
      if i == limit:
        break
    print(f"Found {len(repos)} suitable ONNX models")
    return repos
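  # discover_models returns plain repo ids, e.g. something like
  # ["sentence-transformers/all-MiniLM-L6-v2", "openai/whisper-tiny"] (illustrative values only);
  # these ids are what collect_metadata() and the HfApi calls below expect.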
  def collect_metadata(self, repos: list[str]) -> dict:
    print(f"Collecting metadata for {len(repos)} repositories...")
    metadata = {"repositories": {}}
    total_size = 0
    for repo in tqdm(repos, desc="Collecting metadata"):
      try:
        files_metadata = []
        model_info = self.api.model_info(repo)
        for file in model_info.siblings:
          filename = file.rfilename
          if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')):
            continue
          if any(skip_str in filename for skip_str in SKIPPED_FILES):
            continue
          # Get file size from API or HEAD request
          try:
            head = requests.head(
              f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}",
              allow_redirects=True,
              timeout=10
            )
            file_size = file.size or int(head.headers.get('Content-Length', 0))
          except requests.RequestException:
            file_size = file.size or 0
          files_metadata.append({
            "file": filename,
            "size": f"{file_size/1e6:.2f}MB"
          })
          total_size += file_size
        if files_metadata: # Only add repos with valid ONNX files
          metadata["repositories"][repo] = {
            "url": f"{HUGGINGFACE_URL}/{repo}",
            "download_path": None,
            "files": files_metadata,
          }
      except Exception as e:
        print(f"WARNING: Failed to collect metadata for {repo}: {e}")
        continue
    metadata['total_size'] = f"{total_size/1e9:.2f}GB"
    metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    print(f"Collected metadata for {len(metadata['repositories'])} repositories")
    print(f"Total estimated download size: {metadata['total_size']}")
    return metadata
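  # The metadata dict produced above has roughly this shape (values are illustrative):
  #   {"repositories": {"<repo_id>": {"url": "https://huggingface.co/<repo_id>",
  #                                   "download_path": None,  # filled in by download_models
  #                                   "files": [{"file": "model.onnx", "size": "12.34MB"}]}},
  #    "total_size": "1.23GB", "created_at": "2024-01-01T00:00:00Z"}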
  def download_models(self, metadata: dict) -> dict:
    self.models_dir.mkdir(parents=True, exist_ok=True)
    repos = metadata["repositories"]
    n = len(repos)
    print(f"Downloading {n} repositories to {self.models_dir}...")
    for i, (model_id, model_data) in enumerate(repos.items()):
      print(f" Downloading {i+1}/{n}: {model_id}...")
      try:
        # Download ONNX model files
        allow_patterns = [file_info["file"] for file_info in model_data["files"]]
        root_path = snapshot_download_with_retry(
          repo_id=model_id,
          allow_patterns=allow_patterns,
          cache_dir=str(self.models_dir)
        )
        # Download config files (usually small)
        snapshot_download_with_retry(
          repo_id=model_id,
          allow_patterns=["*config.json"],
          cache_dir=str(self.models_dir)
        )
        model_data["download_path"] = str(root_path)
        print(f" Downloaded to: {root_path}")
      except Exception as e:
        print(f" ERROR: Failed to download {model_id}: {e}")
        model_data["download_path"] = None
        continue
    successful_downloads = sum(1 for repo in repos.values() if repo["download_path"] is not None)
    print(f"Successfully downloaded {successful_downloads}/{n} repositories")
    print(f"All models saved to: {self.models_dir}")
    return metadata
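  # Note: snapshot_download stores files in the standard Hugging Face cache layout, so download_path
  # typically points at something like <models_dir>/models--<org>--<name>/snapshots/<revision>/;
  # the exact layout is an implementation detail of huggingface_hub and may differ between versions.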
  def save_metadata(self, metadata: dict, output_file: str):
    yaml_path = self.base_dir / output_file
    with open(yaml_path, 'w') as f:
      yaml.dump(metadata, f, sort_keys=False)
    print(f"Metadata saved to: {yaml_path}")
  def discover_and_download(self, limit: int, output_file: str = "huggingface_repos.yaml",
                            sort: str = "downloads", download: bool = True):
    print("Starting HuggingFace ONNX workflow...")
    print(f" Limit: {limit} models")
    print(f" Sort by: {sort}")
    print(f" Download: {'Yes' if download else 'No'}")
    print(f" Output: {output_file}")
    print("-" * 50)
    repos = self.discover_models(limit, sort)
    metadata = self.collect_metadata(repos)
    if download:
      metadata = self.download_models(metadata)
    self.save_metadata(metadata, output_file)
    print("-" * 50)
    print("Workflow completed successfully!")
    if download:
      successful = sum(1 for repo in metadata["repositories"].values()
                       if repo["download_path"] is not None)
      print(f"{successful}/{len(metadata['repositories'])} models downloaded")
    return metadata

if __name__ == "__main__":
  parser = argparse.ArgumentParser(
    description="HuggingFace ONNX Model Manager - Discover, collect metadata, and download ONNX models",
  )
  parser.add_argument("--limit", type=int, help="Number of top repositories to process")
  parser.add_argument("--output", type=str, default="huggingface_repos.yaml",
                      help="Output YAML file name (default: huggingface_repos.yaml)")
  parser.add_argument("--sort", type=str, default="downloads",
                      choices=["downloads", "likes", "created", "modified"],
                      help="Sort criteria for model discovery (default: downloads)")
  parser.add_argument("--download", action="store_true", default=False,
                      help="Download models after collecting metadata")
  args = parser.parse_args()
  if not args.limit: parser.error("--limit is required")
  manager = HuggingFaceONNXManager()
  manager.discover_and_download(
    limit=args.limit,
    output_file=args.output,
    sort=args.sort,
    download=args.download
  )
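
# Example invocations (the script filename is assumed here; adjust to the actual file name):
#   python huggingface_manager.py --limit 10                          # metadata only, written to huggingface_repos.yaml
#   python huggingface_manager.py --limit 10 --sort likes --download  # also download the models to DOWNLOADS_DIR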