From 418d1ce230686b623015498eb26ad61b9e0e42e7 Mon Sep 17 00:00:00 2001 From: Shane Smiskol Date: Fri, 9 Aug 2024 16:21:18 -0700 Subject: [PATCH] use dataclass_transform! --- selfdrive/car/data_structures.py | 40 +++++++++++++------------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/selfdrive/car/data_structures.py b/selfdrive/car/data_structures.py index 58ac750c7d..61895fab46 100644 --- a/selfdrive/car/data_structures.py +++ b/selfdrive/car/data_structures.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass, field, is_dataclass +from dataclasses import dataclass as _dataclass, field, is_dataclass from enum import Enum, StrEnum as _StrEnum, auto -from typing import get_origin +from typing import dataclass_transform, get_origin auto_obj = object() @@ -9,7 +9,8 @@ def auto_field(): return auto_obj -def apply_auto_fields(cls=None, /, **kwargs): +@dataclass_transform() +def auto_dataclass(cls=None, /, **kwargs): cls_annotations = cls.__dict__.get('__annotations__', {}) for name, typ in cls_annotations.items(): current_value = getattr(cls, name, None) @@ -25,7 +26,8 @@ def apply_auto_fields(cls=None, /, **kwargs): setattr(cls, name, field(default=next(iter(origin_typ)))) else: raise TypeError(f"Unsupported type for auto_field: {origin_typ}") - return cls + + return _dataclass(cls, **kwargs) class StrEnum(_StrEnum): @@ -35,8 +37,7 @@ class StrEnum(_StrEnum): return name -@dataclass -@apply_auto_fields +@auto_dataclass class RadarData: errors: list['Error'] = auto_field() points: list['RadarPoint'] = auto_field() @@ -46,8 +47,7 @@ class RadarData: fault = auto() wrongConfig = auto() - @dataclass - @apply_auto_fields + @auto_dataclass class RadarPoint: trackId: int = auto_field() # no trackId reuse @@ -64,8 +64,7 @@ class RadarData: measured: bool = auto_field() -@dataclass -@apply_auto_fields +@auto_dataclass class CarParams: carName: str = auto_field() carFingerprint: str = auto_field() @@ -102,8 +101,7 @@ class CarParams: lateralParams: 'CarParams.LateralParams' = field(default_factory=lambda: CarParams.LateralParams()) lateralTuning: 'CarParams.LateralTuning' = field(default_factory=lambda: CarParams.LateralTuning()) - @dataclass - @apply_auto_fields + @auto_dataclass class LateralTuning: def init(self, which: str): assert which in ('pid', 'torque'), 'Invalid union type' @@ -114,20 +112,17 @@ class CarParams: pid: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) torque: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning()) - @dataclass - @apply_auto_fields + @auto_dataclass class SafetyConfig: safetyModel: 'CarParams.SafetyModel' = field(default_factory=lambda: CarParams.SafetyModel.silent) safetyParam: int = auto_field() - @dataclass - @apply_auto_fields + @auto_dataclass class LateralParams: torqueBP: list[int] = auto_field() torqueV: list[int] = auto_field() - @dataclass - @apply_auto_fields + @auto_dataclass class LateralPIDTuning: kpBP: list[float] = auto_field() kpV: list[float] = auto_field() @@ -135,8 +130,7 @@ class CarParams: kiV: list[float] = auto_field() kf: float = auto_field() - @dataclass - @apply_auto_fields + @auto_dataclass class LateralTorqueTuning: useSteeringAngle: bool = auto_field() kp: float = auto_field() @@ -176,8 +170,7 @@ class CarParams: wheelSpeedFactor: float = auto_field() # Multiplier on wheels speeds to computer actual speeds - @dataclass - @apply_auto_fields + @auto_dataclass class LongitudinalPIDTuning: kpBP: list[float] = auto_field() kpV: list[float] = auto_field() @@ -230,8 +223,7 @@ class CarParams: direct = auto() # Electric vehicle or other direct drive cvt = auto() - @dataclass - @apply_auto_fields + @auto_dataclass class CarFw: ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown) fwVersion: bytes = auto_field()