You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
126 lines
9.3 KiB
126 lines
9.3 KiB
import ctypes.util, importlib.metadata, itertools, re, functools, os
|
|
from tinygrad.helpers import flatten, unwrap
|
|
from clang.cindex import Config, Index, CursorKind as CK, TranslationUnit as TU, LinkageKind as LK, TokenKind as ToK, TypeKind as TK
|
|
from clang.cindex import PrintingPolicy as PP, PrintingPolicyProperty as PPP, SourceRange
|
|
|
|
# The generator is written against the libclang 20 Python bindings.
assert importlib.metadata.version('clang').startswith("20")
# Configure libclang once: use LIBCLANG_PATH when it is set, otherwise whatever ctypes.util.find_library finds for clang-20.
if not Config.loaded:
  Config.set_library_file(os.getenv("LIBCLANG_PATH", default=ctypes.util.find_library("clang-20")))
|
|
|
|
def fst(c):
  """Return the first child cursor of `c` (StopIteration propagates when it has none)."""
  children = c.get_children()
  return next(children)
|
|
def last(c):
  """Return the last child cursor of `c`; raises IndexError when `c` has no children."""
  children = [*c.get_children()]
  return children[-1]
|
|
def readext(f, fst, snd=None):
  """Return the text of file `f` that lies in a byte range.

  The range is either an extent object exposing `fst.start.offset` / `fst.end.offset` (a libclang SourceRange), or,
  when `fst` is an int, the half-open byte range [fst, snd).

  libclang offsets count bytes. The file is therefore read in binary and only the selected bytes are decoded as UTF-8.
  Text mode would count characters for `read()` and over-read whenever the range contains multi-byte characters.
  """
  start, end = (fst, snd) if isinstance(fst, int) else (fst.start.offset, fst.end.offset)
  with open(f, "rb") as fh:
    fh.seek(start)
    # errors="replace": a stray non-UTF-8 byte should not abort the whole generation run
    return fh.read(end - start).decode("utf-8", errors="replace")
|
|
def attrs(c):
  """Return the kinds of `c`'s children whose numeric value lies in [400, 500).

  This is libclang's attribute cursor-kind block (CXCursor_FirstAttr is 400). gen() uses it to spot CK.PACKED_ATTR on
  record declarations.
  """
  child_kinds = (child.kind for child in c.get_children())
  return [kind for kind in child_kinds if 400 <= kind.value < 500]
|
|
|
|
# (pattern, replacement) regex rules applied in order to every macro's source text to turn C syntax into Python.
# gen() applies caller-supplied rules first, then these; a macro reduced to '' is dropped.
base_rules = [
  (r'\s*\\\n\s*', ' '),  # join backslash-continued lines
  (r'\s*\n\s*', ' '),  # collapse any remaining newlines
  (r'//.*', ''),  # drop line comments
  (r'/\*.*?\*/', ''),  # drop block comments
  (r'\b(0[xX][0-9a-fA-F]+|\d+)[uUlL]+\b', r'\1'),  # strip integer suffixes: 10UL -> 10
  # C octal literal -> Python octal (0100 -> 0o100, i.e. 64, not 100). The lookbehinds skip digits that follow a
  # word character, a '.', or an exponent sign, so 0x0010, 1.05 and 1e-05 are left untouched.
  (r'(?<![\w.])(?<![eE][+-])0+([0-7]+)\b', r'0o\1'),
  (r'\s*&&\s*', r' and '), (r'\s*\|\|\s*', r' or '),
  (r'\s*!(?!=)\s*', ' not '),  # logical not, but leave the != operator alone
  (r'(struct|union|enum)\s*([a-zA-Z_][a-zA-Z0-9_]*\b)', r'\1_\2'),  # struct foo -> struct_foo
  (r'\((unsigned )?(char|uint64_t)\)', ''),  # drop (char) / (unsigned char) / (uint64_t) casts
  (r'^.*\d+:\d+.*$', ''),  # drop macros containing N:M
  (r'^.*\w##\w.*$', ''),  # drop macros that use token pasting
]
|
|
|
|
# Integer TypeKinds. gen() copies internal-linkage constants of these (canonical) types verbatim as Python values, and
# turns designated-initializer arrays of them into dicts.
ints = tuple(getattr(TK, kind) for kind in ("INT", "UINT", "LONG", "ULONG", "LONGLONG", "ULONGLONG"))
|
|
|
|
def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_errno=False, anon_names={}, types={}, parse_macros=True):
|
|
macros, lines, anoncnt, types = [], [], itertools.count().__next__, {k:(v,True) for k,v in types.items()}
|
|
def tname(t, suggested_name=None, typedef=None) -> str:
|
|
suggested_name = anon_names.get(f"{(decl:=t.get_declaration()).location.file}:{decl.location.line}", suggested_name)
|
|
nonlocal lines, types, anoncnt
|
|
tmap = {TK.VOID:"None", TK.CHAR_U:"ctypes.c_ubyte", TK.UCHAR:"ctypes.c_ubyte", TK.CHAR_S:"ctypes.c_char", TK.SCHAR:"ctypes.c_char",
|
|
**{getattr(TK, k):f"ctypes.c_{k.lower()}" for k in ["BOOL", "WCHAR", "FLOAT", "DOUBLE", "LONGDOUBLE"]},
|
|
**{getattr(TK, k):f"ctypes.c_{'u' if 'U' in k else ''}int{sz}" for sz,k in
|
|
[(16, "USHORT"), (16, "SHORT"), (32, "UINT"), (32, "INT"), (64, "ULONG"), (64, "LONG"), (64, "ULONGLONG"), (64, "LONGLONG")]}}
|
|
|
|
if t.kind in tmap: return tmap[t.kind]
|
|
if t.spelling in types and types[t.spelling][1]: return types[t.spelling][0]
|
|
if ((f:=t).kind in (fks:=(TK.FUNCTIONPROTO, TK.FUNCTIONNOPROTO))) or (t.kind == TK.POINTER and (f:=t.get_pointee()).kind in fks):
|
|
return f"ctypes.CFUNCTYPE({tname(f.get_result())}{(', '+', '.join(map(tname, f.argument_types()))) if f.kind==TK.FUNCTIONPROTO else ''})"
|
|
match t.kind:
|
|
case TK.POINTER: return "ctypes.c_void_p" if (ptr:=t.get_pointee()).kind == TK.VOID else f"ctypes.POINTER({tname(ptr)})"
|
|
case TK.ELABORATED: return tname(t.get_named_type(), suggested_name)
|
|
case TK.TYPEDEF if t.spelling == t.get_canonical().spelling: return tname(t.get_canonical())
|
|
case TK.TYPEDEF:
|
|
defined, nm = (canon:=t.get_canonical()).spelling in types, tname(canon, typedef=t.spelling.replace('::', '_'))
|
|
types[t.spelling] = nm if t.spelling.startswith("__") else t.spelling.replace('::', '_'), True
|
|
# RECORDs need to handle typedefs specially to allow for self-reference
|
|
if canon.kind != TK.RECORD or defined: lines.append(f"{t.spelling.replace('::', '_')} = {nm}")
|
|
return types[t.spelling][0]
|
|
case TK.RECORD:
|
|
# TODO: packed unions
|
|
# TODO: pragma pack support
|
|
# check for forward declaration
|
|
if t.spelling in types: types[t.spelling] = (nm:=types[t.spelling][0]), len(list(t.get_fields())) != 0
|
|
else:
|
|
if decl.is_anonymous():
|
|
types[t.spelling] = (nm:=(suggested_name or (f"_anon{'struct' if decl.kind == CK.STRUCT_DECL else 'union'}{anoncnt()}")), True)
|
|
else: types[t.spelling] = (nm:=t.spelling.replace(' ', '_').replace('::', '_')), len(list(t.get_fields())) != 0
|
|
lines.append(f"class {nm}({'Struct' if decl.kind==CK.STRUCT_DECL else 'ctypes.Union'}): pass")
|
|
if typedef: lines.append(f"{typedef} = {nm}")
|
|
acnt = itertools.count().__next__
|
|
ll=[" ("+((fn:=f"'_{acnt()}'")+f", {tname(f.type, nm+fn[1:-1])}" if f.is_anonymous_record_decl() else f"'{f.spelling}', "+
|
|
tname(f.type, f'{nm}_{f.spelling}'))+(f',{f.get_bitfield_width()}' if f.is_bitfield() else '')+")," for f in t.get_fields()]
|
|
lines.extend(([f"{nm}._anonymous_ = ["+", ".join(f"'_{i}'" for i in range(n))+"]"] if (n:=acnt()) else [])+
|
|
([f"{nm}._packed_ = True"] * (CK.PACKED_ATTR in attrs(decl)))+([f"{nm}._fields_ = [",*ll,"]"] if ll else []))
|
|
return nm
|
|
case TK.ENUM:
|
|
# TODO: C++ and GNU C have forward declared enums
|
|
if decl.is_anonymous(): types[t.spelling] = suggested_name or f"_anonenum{anoncnt()}", True
|
|
else: types[t.spelling] = t.spelling.replace(' ', '_').replace('::', '_'), True
|
|
lines.append(f"{types[t.spelling][0]} = CEnum({tname(decl.enum_type)})\n" +
|
|
"\n".join(f"{e.spelling} = {types[t.spelling][0]}.define('{e.spelling}', {e.enum_value})" for e in decl.get_children()
|
|
if e.kind == CK.ENUM_CONSTANT_DECL) + "\n")
|
|
return types[t.spelling][0]
|
|
case TK.CONSTANTARRAY:
|
|
return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * {t.get_array_size()})"
|
|
case TK.INCOMPLETEARRAY: return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * 0)"
|
|
case _: raise NotImplementedError(f"unsupported type {t.kind}")
|
|
|
|
for f in files:
|
|
tu = Index.create().parse(f, args, options=TU.PARSE_DETAILED_PROCESSING_RECORD)
|
|
(pp:=PP.create(tu.cursor)).set_property(PPP.TerseOutput, 1)
|
|
for c in tu.cursor.walk_preorder():
|
|
if str(c.location.file) != str(f) and (not recsym or c.kind not in (CK.FUNCTION_DECL,)): continue
|
|
rollback = lines, types
|
|
try:
|
|
match c.kind:
|
|
case CK.FUNCTION_DECL if c.linkage == LK.EXTERNAL and dll:
|
|
# TODO: we could support name-mangling
|
|
lines.append(f"# {c.pretty_printed(pp)}\ntry: ({c.spelling}:=dll.{c.spelling}).restype, {c.spelling}.argtypes = "
|
|
f"{tname(c.result_type)}, [{', '.join(tname(arg.type) for arg in c.get_arguments())}]\nexcept AttributeError: pass\n")
|
|
case CK.STRUCT_DECL | CK.UNION_DECL | CK.TYPEDEF_DECL | CK.ENUM_DECL: tname(c.type)
|
|
case CK.MACRO_DEFINITION if parse_macros and len(toks:=list(c.get_tokens())) > 1:
|
|
if toks[1].spelling == '(' and toks[0].extent.end.column == toks[1].extent.start.column:
|
|
it = iter(toks[1:])
|
|
_args = [t.spelling for t in itertools.takewhile(lambda t:t.spelling!=')', it) if t.kind == ToK.IDENTIFIER]
|
|
if len(body:=list(it)) == 0: continue
|
|
macros += [f"{c.spelling} = lambda {','.join(_args)}: {readext(f, body[0].location.offset, toks[-1].extent.end.offset)}"]
|
|
else: macros += [f"{c.spelling} = {readext(f, toks[1].location.offset, toks[-1].extent.end.offset)}"]
|
|
case CK.VAR_DECL if c.linkage == LK.INTERNAL:
|
|
if (c.type.kind == TK.CONSTANTARRAY and c.type.get_array_element_type().get_canonical().kind in ints and
|
|
(init:=last(c)).kind == CK.INIT_LIST_EXPR and all(re.match(r"\[.*\].*=", readext(f, c.extent)) for c in init.get_children())):
|
|
cs = init.get_children()
|
|
macros += [f"{c.spelling} = {{{','.join(f'{readext(f,next(it:=c.get_children()).extent)}:{readext(f,next(it).extent)}' for c in cs)}}}"]
|
|
elif c.type.get_canonical().kind in ints: macros += [f"{c.spelling} = {readext(f, last(c).extent)}"]
|
|
else: macros += [f"{c.spelling} = {tname(c.type)}({readext(f, last(c).extent)})"]
|
|
case CK.VAR_DECL if c.linkage == LK.EXTERNAL and dll:
|
|
lines.append(f"try: {c.spelling} = {tname(c.type)}.in_dll(dll, '{c.spelling}')\nexcept (ValueError,AttributeError): pass")
|
|
except NotImplementedError as e:
|
|
print(f"skipping {c.spelling}: {e}")
|
|
lines, types = rollback
|
|
main = (f"# mypy: ignore-errors\nimport ctypes{', os' if any('os' in s for s in dll) else ''}\n"
|
|
"from tinygrad.helpers import unwrap\nfrom tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR\n" + '\n'.join([*prolog,
|
|
*(["from ctypes.util import find_library"]*any('find_library' in s for s in dll)),
|
|
*(["def dll():",*flatten([[f" try: return ctypes.CDLL(unwrap({d}){', use_errno=True' if use_errno else ''})",' except: pass'] for d in dll]),
|
|
" return None", "dll = dll()\n"]*bool(dll)), *lines]) + '\n')
|
|
macros = [r for m in macros if (r:=functools.reduce(lambda s,r:re.sub(r[0], r[1], s), rules + base_rules, m))]
|
|
while True:
|
|
try:
|
|
exec(main + '\n'.join(macros), {})
|
|
break
|
|
except (SyntaxError, NameError, TypeError) as e:
|
|
macrono = unwrap(e.lineno if isinstance(e, SyntaxError) else unwrap(unwrap(e.__traceback__).tb_next).tb_lineno) - main.count('\n') - 1
|
|
assert macrono >= 0 and macrono < len(macros), f"error outside macro range: {e}"
|
|
print(f"skipping {macros[macrono]}: {e}")
|
|
del macros[macrono]
|
|
except Exception as e: raise Exception("parsing failed") from e
|
|
return main + '\n'.join(macros + epilog)
|
|
|