From 965f692f2e12e322972f1748a25c6bb54a85499d Mon Sep 17 00:00:00 2001 From: Shane Smiskol Date: Thu, 8 Aug 2024 21:48:17 -0700 Subject: [PATCH] epic! --- .../car/data_test_kinda_works_chatgpt.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/selfdrive/car/data_test_kinda_works_chatgpt.py b/selfdrive/car/data_test_kinda_works_chatgpt.py index 9b5d6d8d1a..016cec80c9 100644 --- a/selfdrive/car/data_test_kinda_works_chatgpt.py +++ b/selfdrive/car/data_test_kinda_works_chatgpt.py @@ -1,27 +1,44 @@ # import attr -from dataclasses import dataclass, field +from enum import Enum +from typing import get_origin, get_args, get_type_hints +from dataclasses import dataclass, field, is_dataclass auto_obj = object() -def auto_factory(): +def auto_field(): return auto_obj -def apply_auto_factory(cls): +def apply_auto_fields(cls): cls_annotations = cls.__dict__.get('__annotations__', {}) + # type_hints = get_type_hints(cls) for name, typ in cls_annotations.items(): + # typ = type_hints.get(name, typ) current_value = getattr(cls, name, None) if current_value is auto_obj: - setattr(cls, name, field(default_factory=typ)) + origin_typ = get_origin(typ) or typ + print('name123', name, typ, origin_typ) + if isinstance(origin_typ, str): + raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.") + elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool): + setattr(cls, name, field(default_factory=origin_typ)) + elif origin_typ is None: + setattr(cls, name, field(default=origin_typ)) + elif is_dataclass(origin_typ): + setattr(cls, name, field(default_factory=origin_typ)) + elif issubclass(origin_typ, Enum): # first enum is the default + setattr(cls, name, field(default=next(iter(origin_typ)))) + else: + raise TypeError(f"Unsupported type for auto_field: {origin_typ}") return cls @dataclass -@apply_auto_factory +@apply_auto_fields class CarControl: - enabled: bool = auto_factory() - pts: list[int] = auto_factory() + enabled: bool = auto_field() + pts: list[int] = auto_field() logMonoTime: list[int] = field(default_factory=lambda: [1, 2, 3])