use dataclass_transform!

pull/33208/head
Shane Smiskol 1 year ago
parent 9b72521ef2
commit 418d1ce230
  1. 40
      selfdrive/car/data_structures.py

@ -1,6 +1,6 @@
from dataclasses import dataclass, field, is_dataclass
from dataclasses import dataclass as _dataclass, field, is_dataclass
from enum import Enum, StrEnum as _StrEnum, auto
from typing import get_origin
from typing import dataclass_transform, get_origin
auto_obj = object()
@ -9,7 +9,8 @@ def auto_field():
return auto_obj
def apply_auto_fields(cls=None, /, **kwargs):
@dataclass_transform()
def auto_dataclass(cls=None, /, **kwargs):
cls_annotations = cls.__dict__.get('__annotations__', {})
for name, typ in cls_annotations.items():
current_value = getattr(cls, name, None)
@ -25,7 +26,8 @@ def apply_auto_fields(cls=None, /, **kwargs):
setattr(cls, name, field(default=next(iter(origin_typ))))
else:
raise TypeError(f"Unsupported type for auto_field: {origin_typ}")
return cls
return _dataclass(cls, **kwargs)
class StrEnum(_StrEnum):
@ -35,8 +37,7 @@ class StrEnum(_StrEnum):
return name
@dataclass
@apply_auto_fields
@auto_dataclass
class RadarData:
errors: list['Error'] = auto_field()
points: list['RadarPoint'] = auto_field()
@ -46,8 +47,7 @@ class RadarData:
fault = auto()
wrongConfig = auto()
@dataclass
@apply_auto_fields
@auto_dataclass
class RadarPoint:
trackId: int = auto_field() # no trackId reuse
@ -64,8 +64,7 @@ class RadarData:
measured: bool = auto_field()
@dataclass
@apply_auto_fields
@auto_dataclass
class CarParams:
carName: str = auto_field()
carFingerprint: str = auto_field()
@ -102,8 +101,7 @@ class CarParams:
lateralParams: 'CarParams.LateralParams' = field(default_factory=lambda: CarParams.LateralParams())
lateralTuning: 'CarParams.LateralTuning' = field(default_factory=lambda: CarParams.LateralTuning())
@dataclass
@apply_auto_fields
@auto_dataclass
class LateralTuning:
def init(self, which: str):
assert which in ('pid', 'torque'), 'Invalid union type'
@ -114,20 +112,17 @@ class CarParams:
pid: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning())
torque: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning())
@dataclass
@apply_auto_fields
@auto_dataclass
class SafetyConfig:
safetyModel: 'CarParams.SafetyModel' = field(default_factory=lambda: CarParams.SafetyModel.silent)
safetyParam: int = auto_field()
@dataclass
@apply_auto_fields
@auto_dataclass
class LateralParams:
torqueBP: list[int] = auto_field()
torqueV: list[int] = auto_field()
@dataclass
@apply_auto_fields
@auto_dataclass
class LateralPIDTuning:
kpBP: list[float] = auto_field()
kpV: list[float] = auto_field()
@ -135,8 +130,7 @@ class CarParams:
kiV: list[float] = auto_field()
kf: float = auto_field()
@dataclass
@apply_auto_fields
@auto_dataclass
class LateralTorqueTuning:
useSteeringAngle: bool = auto_field()
kp: float = auto_field()
@ -176,8 +170,7 @@ class CarParams:
wheelSpeedFactor: float = auto_field() # Multiplier on wheels speeds to computer actual speeds
@dataclass
@apply_auto_fields
@auto_dataclass
class LongitudinalPIDTuning:
kpBP: list[float] = auto_field()
kpV: list[float] = auto_field()
@ -230,8 +223,7 @@ class CarParams:
direct = auto() # Electric vehicle or other direct drive
cvt = auto()
@dataclass
@apply_auto_fields
@auto_dataclass
class CarFw:
ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown)
fwVersion: bytes = auto_field()

Loading…
Cancel
Save