use dataclass_transform!

pull/33208/head
Shane Smiskol 1 year ago
parent 9b72521ef2
commit 418d1ce230
  1. 40
      selfdrive/car/data_structures.py

@ -1,6 +1,6 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass as _dataclass, field, is_dataclass
from enum import Enum, StrEnum as _StrEnum, auto from enum import Enum, StrEnum as _StrEnum, auto
from typing import get_origin from typing import dataclass_transform, get_origin
auto_obj = object() auto_obj = object()
@ -9,7 +9,8 @@ def auto_field():
return auto_obj return auto_obj
def apply_auto_fields(cls=None, /, **kwargs): @dataclass_transform()
def auto_dataclass(cls=None, /, **kwargs):
cls_annotations = cls.__dict__.get('__annotations__', {}) cls_annotations = cls.__dict__.get('__annotations__', {})
for name, typ in cls_annotations.items(): for name, typ in cls_annotations.items():
current_value = getattr(cls, name, None) current_value = getattr(cls, name, None)
@ -25,7 +26,8 @@ def apply_auto_fields(cls=None, /, **kwargs):
setattr(cls, name, field(default=next(iter(origin_typ)))) setattr(cls, name, field(default=next(iter(origin_typ))))
else: else:
raise TypeError(f"Unsupported type for auto_field: {origin_typ}") raise TypeError(f"Unsupported type for auto_field: {origin_typ}")
return cls
return _dataclass(cls, **kwargs)
class StrEnum(_StrEnum): class StrEnum(_StrEnum):
@ -35,8 +37,7 @@ class StrEnum(_StrEnum):
return name return name
@dataclass @auto_dataclass
@apply_auto_fields
class RadarData: class RadarData:
errors: list['Error'] = auto_field() errors: list['Error'] = auto_field()
points: list['RadarPoint'] = auto_field() points: list['RadarPoint'] = auto_field()
@ -46,8 +47,7 @@ class RadarData:
fault = auto() fault = auto()
wrongConfig = auto() wrongConfig = auto()
@dataclass @auto_dataclass
@apply_auto_fields
class RadarPoint: class RadarPoint:
trackId: int = auto_field() # no trackId reuse trackId: int = auto_field() # no trackId reuse
@ -64,8 +64,7 @@ class RadarData:
measured: bool = auto_field() measured: bool = auto_field()
@dataclass @auto_dataclass
@apply_auto_fields
class CarParams: class CarParams:
carName: str = auto_field() carName: str = auto_field()
carFingerprint: str = auto_field() carFingerprint: str = auto_field()
@ -102,8 +101,7 @@ class CarParams:
lateralParams: 'CarParams.LateralParams' = field(default_factory=lambda: CarParams.LateralParams()) lateralParams: 'CarParams.LateralParams' = field(default_factory=lambda: CarParams.LateralParams())
lateralTuning: 'CarParams.LateralTuning' = field(default_factory=lambda: CarParams.LateralTuning()) lateralTuning: 'CarParams.LateralTuning' = field(default_factory=lambda: CarParams.LateralTuning())
@dataclass @auto_dataclass
@apply_auto_fields
class LateralTuning: class LateralTuning:
def init(self, which: str): def init(self, which: str):
assert which in ('pid', 'torque'), 'Invalid union type' assert which in ('pid', 'torque'), 'Invalid union type'
@ -114,20 +112,17 @@ class CarParams:
pid: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) pid: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning())
torque: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning()) torque: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning())
@dataclass @auto_dataclass
@apply_auto_fields
class SafetyConfig: class SafetyConfig:
safetyModel: 'CarParams.SafetyModel' = field(default_factory=lambda: CarParams.SafetyModel.silent) safetyModel: 'CarParams.SafetyModel' = field(default_factory=lambda: CarParams.SafetyModel.silent)
safetyParam: int = auto_field() safetyParam: int = auto_field()
@dataclass @auto_dataclass
@apply_auto_fields
class LateralParams: class LateralParams:
torqueBP: list[int] = auto_field() torqueBP: list[int] = auto_field()
torqueV: list[int] = auto_field() torqueV: list[int] = auto_field()
@dataclass @auto_dataclass
@apply_auto_fields
class LateralPIDTuning: class LateralPIDTuning:
kpBP: list[float] = auto_field() kpBP: list[float] = auto_field()
kpV: list[float] = auto_field() kpV: list[float] = auto_field()
@ -135,8 +130,7 @@ class CarParams:
kiV: list[float] = auto_field() kiV: list[float] = auto_field()
kf: float = auto_field() kf: float = auto_field()
@dataclass @auto_dataclass
@apply_auto_fields
class LateralTorqueTuning: class LateralTorqueTuning:
useSteeringAngle: bool = auto_field() useSteeringAngle: bool = auto_field()
kp: float = auto_field() kp: float = auto_field()
@ -176,8 +170,7 @@ class CarParams:
wheelSpeedFactor: float = auto_field() # Multiplier on wheels speeds to computer actual speeds wheelSpeedFactor: float = auto_field() # Multiplier on wheels speeds to computer actual speeds
@dataclass @auto_dataclass
@apply_auto_fields
class LongitudinalPIDTuning: class LongitudinalPIDTuning:
kpBP: list[float] = auto_field() kpBP: list[float] = auto_field()
kpV: list[float] = auto_field() kpV: list[float] = auto_field()
@ -230,8 +223,7 @@ class CarParams:
direct = auto() # Electric vehicle or other direct drive direct = auto() # Electric vehicle or other direct drive
cvt = auto() cvt = auto()
@dataclass @auto_dataclass
@apply_auto_fields
class CarFw: class CarFw:
ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown) ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown)
fwVersion: bytes = auto_field() fwVersion: bytes = auto_field()

Loading…
Cancel
Save