diff --git a/selfdrive/car/structs.py b/selfdrive/car/structs.py index dce60e99b2..f2f1e1be4f 100644 --- a/selfdrive/car/structs.py +++ b/selfdrive/car/structs.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass as _dataclass, field, is_dataclass +import attr +from dataclasses import field from enum import Enum, StrEnum as _StrEnum, auto from typing import dataclass_transform, get_origin @@ -9,6 +10,27 @@ def auto_field(): return auto_obj +# @dataclass_transform() +# def auto_dataclass(cls=None, /, **kwargs): +# cls_annotations = cls.__dict__.get('__annotations__', {}) +# for name, typ in cls_annotations.items(): +# current_value = getattr(cls, name, None) +# if current_value is auto_obj: +# origin_typ = get_origin(typ) or typ +# if isinstance(origin_typ, str): +# raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.") +# elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ): +# setattr(cls, name, field(default_factory=origin_typ)) +# elif origin_typ is None: +# setattr(cls, name, field(default=origin_typ)) +# elif issubclass(origin_typ, Enum): # first enum is the default +# setattr(cls, name, field(default=next(iter(origin_typ)))) +# else: +# raise TypeError(f"Unsupported type for auto_field: {origin_typ}") +# +# return _dataclass(cls, **kwargs) + + @dataclass_transform() def auto_dataclass(cls=None, /, **kwargs): cls_annotations = cls.__dict__.get('__annotations__', {}) @@ -18,16 +40,24 @@ def auto_dataclass(cls=None, /, **kwargs): origin_typ = get_origin(typ) or typ if isinstance(origin_typ, str): raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.") - elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ): - setattr(cls, name, field(default_factory=origin_typ)) + elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool): + setattr(cls, name, attr.attr(factory=origin_typ)) + elif attr.has(origin_typ): + def convert(data, _origin_typ=origin_typ): + print('got data', data) + if attr.has(data):# or (not any(isinstance(k, dict) for k in data) and len(data)): + return data + # print('ret cls **data', cls) + return _origin_typ(**data) + setattr(cls, name, attr.attr(factory=origin_typ, converter=convert)) elif origin_typ is None: - setattr(cls, name, field(default=origin_typ)) + setattr(cls, name, attr.attr(default=origin_typ)) elif issubclass(origin_typ, Enum): # first enum is the default - setattr(cls, name, field(default=next(iter(origin_typ)))) + setattr(cls, name, attr.attr(default=next(iter(origin_typ)))) else: raise TypeError(f"Unsupported type for auto_field: {origin_typ}") - return _dataclass(cls, **kwargs) + return attr.dataclass(cls, slots=True, **kwargs) class StrEnum(_StrEnum): @@ -497,3 +527,15 @@ class CarParams: class NetworkLocation(StrEnum): fwdCamera = auto() # Standard/default integration at LKAS camera gateway = auto() # Integration at vehicle's CAN gateway + + +# @attr.dataclass(slots=True) +@auto_dataclass +class Test: + # actuators: CarControl.Actuators = attr.attr(factory=lambda: CarControl.Actuators(), converter=lambda arg: CarControl.Actuators(**arg)) + actuators: CarControl.Actuators = auto_field() + hudControl: CarControl.HUDControl = auto_field() # attr.attr(factory=lambda: CarControl.HUDControl(), converter=lambda arg: CarControl.HUDControl(**arg)) + + +# Test(**{'actuators': {'gas': 1.0}}) +