You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
267 lines
14 KiB
267 lines
14 KiB
import ctypes, struct, dataclasses, array, itertools
|
|
from typing import Sequence
|
|
from tinygrad.runtime.autogen import libusb
|
|
from tinygrad.helpers import DEBUG, to_mv, round_up
|
|
from tinygrad.runtime.support.hcq import MMIOInterface
|
|
|
|
class USB3:
|
|
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, max_read_len:int=4096):
|
|
self.vendor, self.dev = vendor, dev
|
|
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
|
|
self.max_streams, self.max_read_len = max_streams, max_read_len
|
|
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
|
|
|
|
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
|
|
if DEBUG >= 6: libusb.libusb_set_option(self.ctx, libusb.LIBUSB_OPTION_LOG_LEVEL, 4)
|
|
|
|
self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev)
|
|
if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. sudo required?")
|
|
|
|
# Detach kernel driver if needed
|
|
if libusb.libusb_kernel_driver_active(self.handle, 0):
|
|
libusb.libusb_detach_kernel_driver(self.handle, 0)
|
|
libusb.libusb_reset_device(self.handle)
|
|
|
|
# Set configuration and claim interface
|
|
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
|
|
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
|
|
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
|
|
|
# Clear any stalled endpoints
|
|
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
|
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
|
|
|
# Allocate streams
|
|
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
|
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
|
raise RuntimeError(f"alloc_streams failed: {rc}")
|
|
|
|
# Base cmd
|
|
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
|
|
|
# Init pools
|
|
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
|
|
|
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
|
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
|
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
|
self.buf_data_out = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
|
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x1000) for i in range(self.max_streams)]
|
|
|
|
def _prep_transfer(self, tr, ep, stream_id, buf, length):
|
|
tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
|
|
tr.contents.status, tr.contents.flags, tr.contents.timeout, tr.contents.num_iso_packets = 0xff, 0, 1000, 0
|
|
tr.contents.type = (libusb.LIBUSB_TRANSFER_TYPE_BULK_STREAM if stream_id is not None else libusb.LIBUSB_TRANSFER_TYPE_BULK)
|
|
if stream_id is not None: libusb.libusb_transfer_set_stream_id(tr, stream_id)
|
|
return tr
|
|
|
|
def _submit_and_wait(self, cmds):
|
|
for tr in cmds: libusb.libusb_submit_transfer(tr)
|
|
|
|
running = len(cmds)
|
|
while running:
|
|
libusb.libusb_handle_events(self.ctx)
|
|
running = len(cmds)
|
|
for tr in cmds:
|
|
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
|
|
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
|
|
|
|
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
|
|
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
|
|
results, tr_window, op_window = [], [], []
|
|
|
|
for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
|
|
# allocate slot and stream. stream is 1-based
|
|
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
|
|
|
|
# build cmd packet
|
|
struct.pack_into(">B", self.buf_cmd[slot], 3, stream)
|
|
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
|
|
|
|
# cmd + stat transfers
|
|
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
|
|
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
|
|
|
|
if rlen:
|
|
if rlen > self.max_read_len: raise ValueError("read length > max_read_len per CDB")
|
|
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
|
|
|
|
if send_data is not None:
|
|
if len(send_data) > len(self.buf_data_out[slot]):
|
|
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
|
|
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
|
|
|
|
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
|
|
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
|
|
|
|
op_window.append((idx, slot, rlen))
|
|
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
|
|
self._submit_and_wait(tr_window)
|
|
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
|
|
tr_window = []
|
|
|
|
return results
|
|
|
|
@dataclasses.dataclass(frozen=True)
class WriteOp:
    """Request to write ``data`` byte by byte starting at controller address ``addr``."""
    addr: int
    data: bytes
    # When False, bytes whose value already matches the controller's write cache are skipped.
    ignore_cache: bool = True
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ReadOp:
    """Request to read ``size`` bytes starting at controller address ``addr``."""
    addr: int
    size: int
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ScsiWriteOp:
    """Request to send ``data`` with a SCSI-style write command addressed to ``lba``."""
    data: bytes
    lba: int = 0
|
|
|
|
class ASM24Controller:
    """Register, bulk-write and PCIe access to the bridge controller behind a :class:`USB3` link.

    Controller memory is reached with 0xE4 (read) and 0xE5 (single-byte write) command blocks;
    PCIe requests are issued by programming the controller's 0xB2xx registers.
    """

    def __init__(self):
        """Open the USB device 0xADD1:0x0001 and run the controller's initial register writes."""
        self.usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04)
        # Last byte written per controller address (None once a read has made it unknown).
        self._cache: dict[int, int|None] = {}
        # (start, size) address windows whose 4-byte PCIe memory writes may be deduplicated.
        self._pci_cacheable: list[tuple[int, int]] = []
        # Last dword written per PCIe address (None when unknown).
        self._pci_cache: dict[int, int|None] = {}

        # Init controller.
        self.exec_ops([WriteOp(0x54b, b' '), WriteOp(0x54e, b'\x04'), WriteOp(0x5a8, b'\x02'), WriteOp(0x5f8, b'\x04'),
                       WriteOp(0x7ec, b'\x01\x00\x00\x00'), WriteOp(0xc422, b'\x02'), WriteOp(0x0, b'\x33')])

    def exec_ops(self, ops:Sequence[WriteOp|ReadOp|ScsiWriteOp]):
        """Translate ``ops`` into command blocks and run them as one ``USB3.send_batch`` call.

        * ``WriteOp`` becomes one 0xE5 command per byte; with ``ignore_cache=False`` bytes that
          already match the write cache are skipped.
        * ``ReadOp`` becomes one 0xE4 command (at most 0xff bytes) and invalidates the cache for
          the bytes it covers.
        * ``ScsiWriteOp`` becomes one 0x8A command whose payload is zero-padded to 512-byte sectors.

        Returns:
            The list returned by ``send_batch``: one entry per generated command, holding the
            bytes read for read commands and None otherwise.
        """
        cdbs:list[bytes] = []
        idata:list[int] = []
        odata:list[bytes|None] = []

        def _add_req(cdb:bytes, i:int, o:bytes|None):
            # Append in place: a write op emits one request per byte, so rebuilding the three
            # lists for every request would be quadratic in the payload size.
            cdbs.append(cdb)
            idata.append(i)
            odata.append(o)

        for op in ops:
            if isinstance(op, WriteOp):
                for off, value in enumerate(op.data):
                    addr = ((op.addr + off) & 0x1FFFF) | 0x500000
                    if not op.ignore_cache and self._cache.get(addr) == value:
                        continue
                    _add_req(struct.pack('>BBBHB', 0xE5, value, addr >> 16, addr & 0xFFFF, 0), 0, None)
                    self._cache[addr] = value
            elif isinstance(op, ReadOp):
                assert op.size <= 0xff
                addr = (op.addr & 0x1FFFF) | 0x500000
                _add_req(struct.pack('>BBBHB', 0xE4, op.size, addr >> 16, addr & 0xFFFF, 0), op.size, None)
                for i in range(op.size):
                    self._cache[addr + i] = None
            elif isinstance(op, ScsiWriteOp):
                sectors = round_up(len(op.data), 512) // 512
                padding = b'\x00' * ((sectors * 512) - len(op.data))
                _add_req(struct.pack('>BBQIBB', 0x8A, 0, op.lba, sectors, 0, 0), 0, op.data + padding)

        return self.usb.send_batch(cdbs, idata, odata)

    def write(self, base_addr:int, data:bytes, ignore_cache:bool=True):
        """Write ``data`` to controller memory starting at ``base_addr``."""
        return self.exec_ops([WriteOp(base_addr, data, ignore_cache)])

    def scsi_write(self, buf:bytes, lba:int=0):
        """Send ``buf`` in chunks of up to 0x10000 bytes using ``ScsiWriteOp``.

        Buffers larger than 0x4000 bytes are zero-padded to a multiple of 0x10000 first. Each
        chunk is followed by writes to 0x171 and 0xce6e; large buffers end with four
        single-byte writes to 0xce40..0xce43.
        """
        is_large = len(buf) > 0x4000
        if is_large:
            buf += b'\x00' * (round_up(len(buf), 0x10000) - len(buf))

        for i in range(0, len(buf), 0x10000):
            self.exec_ops([ScsiWriteOp(buf[i:i+0x10000], lba), WriteOp(0x171, b'\xff\xff\xff', ignore_cache=True)])
            self.exec_ops([WriteOp(0xce6e, b'\x00\x00', ignore_cache=True)])

        if is_large:
            for i in range(4):
                self.exec_ops([WriteOp(0xce40 + i, b'\x00', ignore_cache=True)])

    def read(self, base_addr:int, length:int, stride:int=0xff) -> bytes:
        """Read ``length`` bytes of controller memory from ``base_addr`` in ``stride``-sized requests."""
        reqs = [ReadOp(base_addr + off, min(stride, length - off)) for off in range(0, length, stride)]
        parts = self.exec_ops(reqs)
        return b''.join(p or b'' for p in parts)[:length]

    def _is_pci_cacheable(self, addr:int) -> bool:
        """Return True when ``addr`` lies in one of the registered ``_pci_cacheable`` windows."""
        # NOTE(review): the upper bound is inclusive (start + size itself matches) - confirm this is intended.
        return any(x <= addr <= x + sz for x, sz in self._pci_cacheable)

    def pcie_prep_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4) -> list[WriteOp]:
        """Build the register writes that launch one PCIe request, without sending them.

        Returns an empty list when a 4-byte memory write (``fmt_type`` 0x60) to a cacheable
        address would repeat the value already recorded for it. Otherwise the list programs the
        data (0xB220, only when ``value`` is given), low address (0xB218), high address (0xB21c),
        byte enables (0xB217) and ``fmt_type`` (0xB210), then writes 0x0f to 0xB254 and 0x04 to 0xB296.
        """
        if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value:
            return []

        assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
        if DEBUG >= 3: print("pcie_request", hex(fmt_type), hex(address), value, size)

        masked_address, offset = address & 0xFFFFFFFC, address & 0x3
        assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)
        # NOTE(review): the lookup above keys on the full address while this store keys on the low
        # 32 bits, so the two only agree for addresses below 4 GiB - confirm this is intended.
        self._pci_cache[masked_address] = value if size == 4 and fmt_type == 0x60 else None

        ops: list[WriteOp] = []
        if value is not None:
            ops.append(WriteOp(0xB220, struct.pack('>I', value << (8 * offset)), ignore_cache=False))
        ops += [WriteOp(0xB218, struct.pack('>I', masked_address), ignore_cache=False),
                WriteOp(0xB21c, struct.pack('>I', address>>32), ignore_cache=False),
                WriteOp(0xB217, bytes([((1 << size) - 1) << offset]), ignore_cache=False),
                WriteOp(0xB210, bytes([fmt_type]), ignore_cache=False),
                WriteOp(0xB254, b"\x0f", ignore_cache=True),
                WriteOp(0xB296, b"\x04", ignore_cache=True)]
        return ops

    def pcie_request(self, fmt_type, address, value=None, size=4, cnt=10):
        """Issue one PCIe request and, for requests that complete, validate the completion.

        Args:
            fmt_type: request type byte written to 0xB210.
            address: target address; its low two bits select the byte lanes.
            value: data to write, or None for a read.
            size: access size in bytes (1-4).
            cnt: remaining retries when the controller reports bit 0 of 0xB296.

        Returns:
            The value read when ``value`` is None; otherwise None.

        Raises:
            RuntimeError: if the completion carries a non-zero status, or a configuration
                request's data flag (bit 0 of 0xB284) contradicts the request direction.
        """
        self.exec_ops(self.pcie_prep_request(fmt_type, address, value, size))

        # Fast path for write requests: types 0x40/0x60 and 0x30-0x37/0x70-0x77 are not waited on.
        if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000):
            return

        # Poll 0xB296 until bit 1 is set; bit 0 is acknowledged by writing 1 back and the whole
        # request is retried while retries remain.
        while (stat:=self.read(0xB296, 1)[0]) & 2 == 0:
            if stat & 1:
                self.write(0xB296, bytes([0x01]))
                if cnt > 0:
                    return self.pcie_request(fmt_type, address, value, size, cnt=cnt-1)
        assert stat == 2, f"stat read 2 was {stat}"

        # Retrieve completion data from Link Status (0xB22A, 0xB22B)
        b284 = self.read(0xB284, 1)[0]
        completion = struct.unpack('>H', self.read(0xB22A, 2))

        # Validate completion status based on PCIe request type
        # Completion TLPs for configuration requests always have a byte count of 4.
        is_cfg = (fmt_type & 0xbe) == 0x04
        assert completion[0] & 0xfff == (4 if is_cfg else size)

        # Extract completion status field
        status = (completion[0] >> 13) & 0x7

        # Handle completion errors or inconsistencies: a configuration read must come back with
        # bit 0 of 0xB284 set, a configuration write without it.
        has_data = bool(b284 & 0x01)
        cfg_mismatch = is_cfg and ((value is None and not has_data) or (value is not None and has_data))
        if status or cfg_mismatch:
            status_map = {0b001: f"Unsupported Request: invalid address/function (target might not be reachable): {address:#x}",
                          0b100: "Completer Abort: abort due to internal error",
                          0b010: "Configuration Request Retry Status: configuration space busy"}
            raise RuntimeError(f"TLP status: {status_map.get(status, 'Reserved (0b{:03b})'.format(status))}")

        if value is None:
            dword = struct.unpack('>I', self.read(0xB220, 4))[0]
            return (dword >> (8 * (address & 0x3))) & ((1 << (8 * size)) - 1)

    def pcie_cfg_req(self, byte_addr, bus=1, dev=0, fn=0, value=None, size=4):
        """Read (``value is None``) or write PCI configuration space of ``bus:dev.fn`` at ``byte_addr``."""
        assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0, f"Invalid byte_addr {byte_addr}, bus {bus}, dev {dev}, fn {fn}"

        # 0x04 / 0x44 are used for bus 0; bit 0 is set for any other bus.
        fmt_type = (0x44 if value is not None else 0x4) | int(bus > 0)
        address = (bus << 24) | (dev << 19) | (fn << 16) | (byte_addr & 0xfff)
        return self.pcie_request(fmt_type, address, value, size)

    def pcie_mem_req(self, address, value=None, size=4):
        """Issue a PCIe memory read (``value is None``, type 0x20) or write (type 0x60)."""
        return self.pcie_request(0x60 if value is not None else 0x20, address, value, size)

    def pcie_mem_write(self, address, values, size):
        """Write ``values`` as consecutive ``size``-byte PCIe memory writes starting at ``address``."""
        ops = [self.pcie_prep_request(0x60, address + i * size, value, size) for i, value in enumerate(values)]

        # Send in batches of 4
        for i in range(0, len(ops), 4):
            self.exec_ops(list(itertools.chain.from_iterable(ops[i:i+4])))
|
|
|
|
class USBMMIOInterface(MMIOInterface):
    """Indexable window onto controller memory or PCIe memory reached through the USB controller.

    With ``pcimem=True`` accesses become PCIe memory requests; otherwise they are plain
    controller reads and writes.
    """

    def __init__(self, usb, addr, size, fmt, pcimem=True):
        """Describe ``size`` bytes at ``addr``, viewed as elements of struct format ``fmt``."""
        self.usb, self.addr, self.nbytes, self.fmt, self.pcimem = usb, addr, size, fmt, pcimem
        self.el_sz = struct.calcsize(fmt)

    def __getitem__(self, index): return self._access_items(index)
    def __setitem__(self, index, val): self._access_items(index, val)

    def _access_items(self, index, val=None):
        """Read (``val is None``) or write the element or slice selected by ``index``."""
        if isinstance(index, slice):
            start = index.start or 0
            return self._acc(start * self.el_sz, ((index.stop or len(self)) - start) * self.el_sz, val)
        if self.pcimem:
            return self._acc_one(index * self.el_sz, self.el_sz, val)
        return self._acc(index * self.el_sz, self.el_sz, val)

    def view(self, offset:int=0, size:int|None=None, fmt=None):
        """Return a new interface over the same device starting ``offset`` bytes further in."""
        return USBMMIOInterface(self.usb, self.addr+offset, size or (self.nbytes - offset), fmt=fmt or self.fmt, pcimem=self.pcimem)

    def _acc_size(self, sz):
        """Return ``(array typecode, width)`` for the widest of 4/2/1 bytes that divides ``sz``."""
        return next(x for x in [('I', 4), ('H', 2), ('B', 1)] if sz % x[1] == 0)

    def _acc_one(self, off, sz, val=None):
        """Read or write one element of up to 8 bytes through PCIe memory requests.

        Elements of 8 bytes or more are split into an upper dword (issued first, at ``off + 4``)
        and a lower dword. Returns the combined value for reads and None for writes.
        """
        upper = 0 if sz < 8 else self.usb.pcie_mem_req(self.addr + off + 4, val if val is None else (val >> 32), 4)
        lower = self.usb.pcie_mem_req(self.addr + off, val if val is None else val & 0xffffffff, min(sz, 4))
        if val is None: return lower | (upper << 32)

    def _acc(self, off, sz, data=None):
        """Read ``sz`` bytes at ``off`` (``data is None``) or write ``data`` there.

        Controller-memory reads return an int when exactly one element is requested and raw
        bytes otherwise; PCIe reads always return bytes. Writes return whatever the underlying
        controller call returns (None for PCIe writes).
        """
        if data is None: # read op
            if not self.pcimem:
                raw = self.usb.read(self.addr + off, sz)
                return int.from_bytes(raw, "little") if sz == self.el_sz else raw

            acc, acc_size = self._acc_size(sz)
            return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)]))

        # write op
        data = struct.pack(self.fmt, data) if isinstance(data, int) else bytes(data)

        if not self.pcimem:
            # Fast path for writing into buffer 0xf000
            if self.addr == 0xf000:
                return self.usb.scsi_write(bytes(data))
            use_cache = 0xa800 <= self.addr <= 0xb000
            return self.usb.write(self.addr + off, bytes(data), ignore_cache=not use_cache)

        # `data` is already bytes, so its length is the byte count. Scaling it by the element
        # size again would always select 4-byte accesses for multi-byte formats and write a
        # short tail chunk as a full dword, past the end of the payload.
        _, acc_sz = self._acc_size(len(data))
        self.usb.pcie_mem_write(self.addr+off, [int.from_bytes(data[i:i+acc_sz], "little") for i in range(0, len(data), acc_sz)], acc_sz)
|
|
|