pull/33208/head
Shane Smiskol 11 months ago
parent 95351b133c
commit a06264a8a4
  1. 166
      selfdrive/car/data_structures.py

@ -1,17 +1,33 @@
from dataclasses import dataclass, fields, field, is_dataclass from dataclasses import dataclass as _dataclass, field, is_dataclass
from enum import Enum, StrEnum as _StrEnum, auto from enum import Enum, StrEnum as _StrEnum, auto
# from typing import Type, TypeVar from typing import get_origin
from typing import TypeVar, TYPE_CHECKING, Any, get_type_hints, get_origin
from selfdrive.car.data_test_kinda_works_chatgpt import auto_field, apply_auto_fields
if TYPE_CHECKING:
from _typeshed import DataclassInstance
#
# DataclassT = TypeVar("DataclassT", bound="DataclassInstance")
#
# T = TypeVar('T', bound='Struct')
_FIELDS = '__dataclass_fields__' auto_obj = object()
def auto_field():
return auto_obj
def auto_dataclass(cls=None, /, **kwargs):
cls_annotations = cls.__dict__.get('__annotations__', {})
for name, typ in cls_annotations.items():
current_value = getattr(cls, name, None)
if current_value is auto_obj:
origin_typ = get_origin(typ) or typ
if isinstance(origin_typ, str):
raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.")
elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ):
setattr(cls, name, field(default_factory=origin_typ))
elif origin_typ is None:
setattr(cls, name, field(default=origin_typ))
elif issubclass(origin_typ, Enum): # first enum is the default
setattr(cls, name, field(default=next(iter(origin_typ))))
else:
raise TypeError(f"Unsupported type for auto_field: {origin_typ}")
return _dataclass(cls, **kwargs)
class StrEnum(_StrEnum): class StrEnum(_StrEnum):
@ -21,74 +37,34 @@ class StrEnum(_StrEnum):
return name return name
# class Struct: @auto_dataclass
# @classmethod
# def new_message(cls, **kwargs):
# init_values = {}
# for f in fields(cls):
# init_values[f.name] = kwargs.get(f.name, f.type())
#
# return cls(**init_values)
T = TypeVar('T', bound='DataclassInstance')
# class Struct:
# @classmethod
# def new_message(cls: type[T], **kwargs: Any) -> T:
# if not is_dataclass(cls):
# raise TypeError(f"{cls.__name__} is not a dataclass")
#
# init_values = {}
# type_hints = get_type_hints(cls)
# print(type_hints)
# for f in fields(cls):
# field_type = type_hints[f.name]
# print(f.name, f.type, field_type)
# print(issubclass(field_type, Enum))
# if issubclass(field_type, Enum):
# init_values[f.name] = kwargs.get(f.name, list(field_type)[0])
# # TODO: fix this
# # assert issubclass(init_values[f.name], type(field_type)), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}"
# else:
# # FIXME: typing check hack since mypy doesn't catch anything
# init_values[f.name] = kwargs.get(f.name, field_type())
# print('field_type', field_type, f.type)
# # TODO: this is so bad
# assert isinstance(init_values[f.name], get_origin(f.type) or f.type), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}"
#
# return cls(**init_values)
@dataclass
class RadarData: class RadarData:
errors: list['Error'] errors: list['Error'] = auto_field()
points: list['RadarPoint'] points: list['RadarPoint'] = auto_field()
class Error(StrEnum): class Error(StrEnum):
canError = auto() canError = auto()
fault = auto() fault = auto()
wrongConfig = auto() wrongConfig = auto()
@dataclass @auto_dataclass
class RadarPoint: class RadarPoint:
trackId: int # no trackId reuse trackId: int = auto_field() # no trackId reuse
# these 3 are the minimum required # these 3 are the minimum required
dRel: float # m from the front bumper of the car dRel: float = auto_field() # m from the front bumper of the car
yRel: float # m yRel: float = auto_field() # m
vRel: float # m/s vRel: float = auto_field() # m/s
# these are optional and valid if they are not NaN # these are optional and valid if they are not NaN
aRel: float # m/s^2 aRel: float = auto_field() # m/s^2
yvRel: float # m/s yvRel: float = auto_field() # m/s
# some radars flag measurements VS estimates # some radars flag measurements VS estimates
measured: bool measured: bool = auto_field()
@dataclass @auto_dataclass
@apply_auto_fields
class CarParams: class CarParams:
carName: str = auto_field() carName: str = auto_field()
carFingerprint: str = auto_field() carFingerprint: str = auto_field()
@ -102,8 +78,7 @@ class CarParams:
torque = auto() torque = auto()
angle = auto() angle = auto()
@dataclass @auto_dataclass
@apply_auto_fields
class CarFw: class CarFw:
ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown) ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown)
fwVersion: bytes = auto_field() fwVersion: bytes = auto_field()
@ -149,61 +124,7 @@ class CarParams:
debug = auto() debug = auto()
# # CP: CarParams = CarParams.new_message(carName='toyota', fuzzyFingerprint=123) @auto_dataclass
# # CP: CarParams = CarParams(carName='toyota', fuzzyFingerprint=123)
#
# # import ast
#
#
# # test = ast.literal_eval('CarParams.CarFw')
#
# def mywrapper(cls):
#
# cls_annotations = cls.__dict__.get('__annotations__', {})
# fields = {}
# for name, _type in cls_annotations.items():
# f = field(default_factory=_type)
# setattr(cls, name, f)
# fields[name] = f
#
# setattr(cls, _FIELDS, fields)
#
# print('cls_annotations', cls_annotations)
# # cls.hi = 123
#
# return cls
#
#
# # def mywrapper2(cls):
# # class Test:
# # pass
# # return Test
#
#
# @dataclass
# class CarControl1:
# enabled: bool
#
# @dataclass
# class CarControl2:
# enabled: bool = field(default_factory=bool)
#
#
# # @mywrapper2
# @dataclass()
# @mywrapper
# class CarControl:
# # enabled: bool = field(default_factory=bool)
# enabled: bool = None
# pts: list[int] = None
# logMonoTime: int = None
#
#
# CC = CarControl()
@dataclass
@apply_auto_fields
class CarControl: class CarControl:
enabled: bool = auto_field() enabled: bool = auto_field()
pts: list[int] = auto_field() pts: list[int] = auto_field()
@ -212,8 +133,7 @@ class CarControl:
# testing: if origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool): # testing: if origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool):
@dataclass @auto_dataclass
@apply_auto_fields
class Test997: class Test997:
a: int = auto_field() a: int = auto_field()
b: float = auto_field() b: float = auto_field()
@ -233,4 +153,4 @@ CarControl()
CP = CarParams() CP = CarParams()
CP.carFw = [CarParams.CarFw()] CP.carFw = [CarParams.CarFw()]
CP.carFw = [CarParams.Ecu.eps] # CP.carFw = [CarParams.Ecu.eps]

Loading…
Cancel
Save