diff --git a/selfdrive/car/data_structures.py b/selfdrive/car/data_structures.py index ddc182ef50..f6ac9ed4ea 100644 --- a/selfdrive/car/data_structures.py +++ b/selfdrive/car/data_structures.py @@ -1,17 +1,33 @@ -from dataclasses import dataclass, fields, field, is_dataclass +from dataclasses import dataclass as _dataclass, field, is_dataclass from enum import Enum, StrEnum as _StrEnum, auto -# from typing import Type, TypeVar -from typing import TypeVar, TYPE_CHECKING, Any, get_type_hints, get_origin -from selfdrive.car.data_test_kinda_works_chatgpt import auto_field, apply_auto_fields +from typing import get_origin -if TYPE_CHECKING: - from _typeshed import DataclassInstance -# -# DataclassT = TypeVar("DataclassT", bound="DataclassInstance") -# -# T = TypeVar('T', bound='Struct') -_FIELDS = '__dataclass_fields__' +auto_obj = object() + + +def auto_field(): + return auto_obj + + +def auto_dataclass(cls=None, /, **kwargs): + cls_annotations = cls.__dict__.get('__annotations__', {}) + for name, typ in cls_annotations.items(): + current_value = getattr(cls, name, None) + if current_value is auto_obj: + origin_typ = get_origin(typ) or typ + if isinstance(origin_typ, str): + raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.") + elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ): + setattr(cls, name, field(default_factory=origin_typ)) + elif origin_typ is None: + setattr(cls, name, field(default=origin_typ)) + elif issubclass(origin_typ, Enum): # first enum is the default + setattr(cls, name, field(default=next(iter(origin_typ)))) + else: + raise TypeError(f"Unsupported type for auto_field: {origin_typ}") + + return _dataclass(cls, **kwargs) class StrEnum(_StrEnum): @@ -21,74 +37,34 @@ class StrEnum(_StrEnum): return name -# class Struct: -# @classmethod -# def new_message(cls, **kwargs): -# init_values = {} -# for f in fields(cls): -# init_values[f.name] = kwargs.get(f.name, f.type()) -# -# return cls(**init_values) - -T = TypeVar('T', bound='DataclassInstance') - - -# class Struct: -# @classmethod -# def new_message(cls: type[T], **kwargs: Any) -> T: -# if not is_dataclass(cls): -# raise TypeError(f"{cls.__name__} is not a dataclass") -# -# init_values = {} -# type_hints = get_type_hints(cls) -# print(type_hints) -# for f in fields(cls): -# field_type = type_hints[f.name] -# print(f.name, f.type, field_type) -# print(issubclass(field_type, Enum)) -# if issubclass(field_type, Enum): -# init_values[f.name] = kwargs.get(f.name, list(field_type)[0]) -# # TODO: fix this -# # assert issubclass(init_values[f.name], type(field_type)), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}" -# else: -# # FIXME: typing check hack since mypy doesn't catch anything -# init_values[f.name] = kwargs.get(f.name, field_type()) -# print('field_type', field_type, f.type) -# # TODO: this is so bad -# assert isinstance(init_values[f.name], get_origin(f.type) or f.type), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}" -# -# return cls(**init_values) - - -@dataclass +@auto_dataclass class RadarData: - errors: list['Error'] - points: list['RadarPoint'] + errors: list['Error'] = auto_field() + points: list['RadarPoint'] = auto_field() class Error(StrEnum): canError = auto() fault = auto() wrongConfig = auto() - @dataclass + @auto_dataclass class RadarPoint: - trackId: int # no trackId reuse + trackId: int = auto_field() # no trackId reuse # these 3 are the minimum required - dRel: float # m from the front bumper of the car - yRel: float # m - vRel: float # m/s + dRel: float = auto_field() # m from the front bumper of the car + yRel: float = auto_field() # m + vRel: float = auto_field() # m/s # these are optional and valid if they are not NaN - aRel: float # m/s^2 - yvRel: float # m/s + aRel: float = auto_field() # m/s^2 + yvRel: float = auto_field() # m/s # some radars flag measurements VS estimates - measured: bool + measured: bool = auto_field() -@dataclass -@apply_auto_fields +@auto_dataclass class CarParams: carName: str = auto_field() carFingerprint: str = auto_field() @@ -102,8 +78,7 @@ class CarParams: torque = auto() angle = auto() - @dataclass - @apply_auto_fields + @auto_dataclass class CarFw: ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown) fwVersion: bytes = auto_field() @@ -149,61 +124,7 @@ class CarParams: debug = auto() -# # CP: CarParams = CarParams.new_message(carName='toyota', fuzzyFingerprint=123) -# # CP: CarParams = CarParams(carName='toyota', fuzzyFingerprint=123) -# -# # import ast -# -# -# # test = ast.literal_eval('CarParams.CarFw') -# -# def mywrapper(cls): -# -# cls_annotations = cls.__dict__.get('__annotations__', {}) -# fields = {} -# for name, _type in cls_annotations.items(): -# f = field(default_factory=_type) -# setattr(cls, name, f) -# fields[name] = f -# -# setattr(cls, _FIELDS, fields) -# -# print('cls_annotations', cls_annotations) -# # cls.hi = 123 -# -# return cls -# -# -# # def mywrapper2(cls): -# # class Test: -# # pass -# # return Test -# -# -# @dataclass -# class CarControl1: -# enabled: bool -# -# @dataclass -# class CarControl2: -# enabled: bool = field(default_factory=bool) -# -# -# # @mywrapper2 -# @dataclass() -# @mywrapper -# class CarControl: -# # enabled: bool = field(default_factory=bool) -# enabled: bool = None -# pts: list[int] = None -# logMonoTime: int = None -# -# -# CC = CarControl() - - -@dataclass -@apply_auto_fields +@auto_dataclass class CarControl: enabled: bool = auto_field() pts: list[int] = auto_field() @@ -212,8 +133,7 @@ class CarControl: # testing: if origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool): -@dataclass -@apply_auto_fields +@auto_dataclass class Test997: a: int = auto_field() b: float = auto_field() @@ -233,4 +153,4 @@ CarControl() CP = CarParams() CP.carFw = [CarParams.CarFw()] -CP.carFw = [CarParams.Ecu.eps] +# CP.carFw = [CarParams.Ecu.eps]