You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
2.3 KiB
65 lines
2.3 KiB
#!/usr/bin/env python3
|
|
|
|
from tinygrad.runtime.support.system import System
|
|
import argparse, glob, os, re, time, subprocess, sys
|
|
|
|
def scan_devs_based_on_lock(prefix:str) -> list[str]:
  """Return the PCI device IDs named by `/tmp/{prefix}_<id>.lock` files.

  Each matching lock file contributes the `<id>` part of its name; an ID is kept only when
  `/sys/bus/pci/devices/<id>` exists.

  Args:
    prefix: Lock-file name prefix, i.e. the text between `/tmp/` and the underscore.

  Returns:
    The surviving IDs in the order `glob.glob` reports the lock files; empty when nothing matches.
  """
  head, tail = f"/tmp/{prefix}_", ".lock"
  # Slice by the real prefix/suffix lengths: a fixed offset of 8 is only right for two-character prefixes.
  candidates = (path[len(head):-len(tail)] for path in glob.glob(f"{head}*{tail}"))
  # An empty ID would resolve to the devices directory itself, which always exists, so skip it.
  return [dev_id for dev_id in candidates if dev_id and os.path.exists(f"/sys/bus/pci/devices/{dev_id}")]
|
|
|
|
def _do_reset_device(pci_bus):
  """Reset the PCI device identified by `pci_bus` by delegating to `System.pci_reset`."""
  System.pci_reset(pci_bus)
|
|
def _is_module_loaded(name: str) -> bool:
  """Report whether `/sys/module/<name>` exists as a directory."""
  sysfs_entry = f"/sys/module/{name}"
  return os.path.isdir(sysfs_entry)
|
|
|
|
def cmd_remove_module(args):
  """Unload whichever NVIDIA kernel modules are currently loaded.

  Checks nvidia_drm, nvidia_modeset, nvidia_uvm and nvidia (in that order) and hands the loaded ones to a
  single `sudo modprobe -r` call. When modprobe exits non-zero, a message is printed to stderr and the
  process exits with modprobe's return code. `args` is not used.
  """
  candidates = ("nvidia_drm", "nvidia_modeset", "nvidia_uvm", "nvidia")
  loaded = list(filter(_is_module_loaded, candidates))

  # Nothing to unload: report and stop before touching modprobe.
  if not loaded:
    print("NVIDIA kernel modules are not loaded")
    return

  print("Removing NVIDIA kernel modules:", ", ".join(loaded))
  try:
    subprocess.run(["sudo", "modprobe", "-r"] + loaded, check=True)
  except subprocess.CalledProcessError as err:
    print("Failed to unload all modules — they may be in use.", file=sys.stderr)
    sys.exit(err.returncode)
|
|
|
|
def cmd_insert_module(args):
  """Unload the NVIDIA modules, reset the devices, then bring the nvidia module back.

  `args` is forwarded unchanged to `cmd_remove_module` and `cmd_reset_devices`. Afterwards, if
  `/sys/module/nvidia` is absent, `nvidia-smi` is run (presumably relying on it to load the module);
  a non-zero exit from it raises `subprocess.CalledProcessError`.
  """
  cmd_remove_module(args)
  cmd_reset_devices(args)

  if os.path.exists("/sys/module/nvidia"):
    print("Nvidia kernel module already loaded")
    return

  print("Inserting nvidia kernel module")
  subprocess.run(["nvidia-smi"], check=True)
|
|
|
|
def cmd_reset_devices(args):
  """Reset the lock-tracked NVIDIA PCI devices whose ID starts with `args.pci_bus`.

  Args:
    args: Parsed CLI namespace. `pci_bus` is optional; when it is missing, empty or None, every
      device found through the `nv` lock files is reset.
  """
  # Not every caller's namespace defines pci_bus (cmd_insert_module forwards its own args here),
  # so fall back to the empty prefix, which matches every device.
  dev_to_reset = getattr(args, "pci_bus", "") or ""

  for dev in scan_devs_based_on_lock("nv"):
    if not dev.startswith(dev_to_reset): continue
    print(f"Resetting device {dev}")
    _do_reset_device(dev)
    time.sleep(0.2)  # brief pause after each reset before moving on to the next device
|
|
|
|
if __name__ == "__main__":
  # Command-line entry point: each subcommand stores its handler in `func`, which is called with the parsed args.
  cli = argparse.ArgumentParser()
  commands = cli.add_subparsers(required=True, dest="cmd")

  commands.add_parser('insmod', help='Insert a nvidia kernel module').set_defaults(func=cmd_insert_module)
  commands.add_parser('rmmod', help='Remove a nvidia kernel module').set_defaults(func=cmd_remove_module)

  reset_cmd = commands.add_parser('reset', help='Reset a nvidia device')
  reset_cmd.add_argument('--pci_bus', type=str, default="", help='PCI bus ID of the device to reset')
  reset_cmd.set_defaults(func=cmd_reset_devices)

  args = cli.parse_args()
  if args.cmd is None:
    cli.print_help(sys.stderr)
    sys.exit(1)

  args.func(args)
|
|
|