diff --git a/selfdrive/car/data_structures.py b/selfdrive/car/data_structures.py index 4417466227..ef443907a7 100644 --- a/selfdrive/car/data_structures.py +++ b/selfdrive/car/data_structures.py @@ -100,7 +100,24 @@ class CarParams: longitudinalTuning: 'CarParams.LongitudinalPIDTuning' = field(default_factory=lambda: CarParams.LongitudinalPIDTuning()) lateralParams: 'CarParams.LateralParams' = field(default_factory=lambda: CarParams.LateralParams()) - lateralTuning: 'CarParams.LateralPIDTuning | CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) + # lateralTuning: 'CarParams.LateralPIDTuning | CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) + # lateralTuningWhich: type['CarParams.LateralPIDTuning | CarParams.LateralTorqueTuning'] = field(default_factory=lambda: CarParams.LateralPIDTuning) + # lateralPIDTuning: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) + # lateralTorqueTuning: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning()) + + lateralTuning: 'CarParams.LateralTuning' = field(default_factory=lambda: CarParams.LateralTuning()) + + @dataclass + @apply_auto_fields + class LateralTuning: + def init(self, which: str): + assert which in ('pid', 'torque'), 'Invalid union type' + self.which = which + + which: str = 'pid' + + pid: 'CarParams.LateralPIDTuning' = field(default_factory=lambda: CarParams.LateralPIDTuning()) + torque: 'CarParams.LateralTorqueTuning' = field(default_factory=lambda: CarParams.LateralTorqueTuning()) @dataclass @apply_auto_fields @@ -215,3 +232,60 @@ class CarParams: programmedFuelInjection = auto() debug = auto() + + +import typing + +# CP = CarParams() +# +# print(CP.lateralTuningWhich is CarParams.LateralPIDTuning) +# print(CP.lateralTuningWhich is CarParams.LateralTorqueTuning) +# print() +# CP.lateralTuningWhich = CarParams.LateralTorqueTuning +# print(CP.lateralTuningWhich is CarParams.LateralPIDTuning) +# print(CP.lateralTuningWhich is CarParams.LateralTorqueTuning) + +# +# +# T = typing.TypeVar('T') +# +# @dataclass +# class UnionField(typing.Generic[T]): +# value: T +# +# def init(self, value: T): +# self.value = value +# +# def which(self): +# return type(self.value).__name__ +# +# +# lateral_tuning: UnionField[CarParams.LateralTorqueTuning | CarParams.LateralPIDTuning] = UnionField(CarParams.LateralPIDTuning) +# +# lateral_tuning.init(CarParams().LateralPIDTuning()) +# if isinstance(lateral_tuning.value, CarParams.LateralTorqueTuning): +# print(lateral_tuning.value.useSteeringAngle) # This will work without mypy errors + +T = typing.TypeVar('T') + + +# def which(value: object, typ: typing.Type[T]) -> bool: +# return isinstance(value, typ) + +which = isinstance + + +lateralTuning: CarParams.LateralPIDTuning | CarParams.LateralTorqueTuning = field(default_factory=lambda: CarParams.LateralPIDTuning()) + +if isinstance(lateralTuning, CarParams.LateralTorqueTuning): + lateralTuning.useSteeringAngle = True + + + +# @dataclass +# @apply_auto_fields +# class UnionDataclass: +# which: str = 'pid' +# +# pid: CarParams.LateralPIDTuning = auto_field() +# torque: CarParams.LateralTorqueTuning = auto_field() diff --git a/selfdrive/car/interfaces.py b/selfdrive/car/interfaces.py index 4b56b26469..53d5100a4f 100644 --- a/selfdrive/car/interfaces.py +++ b/selfdrive/car/interfaces.py @@ -189,7 +189,6 @@ class CarInterfaceBase(ABC): # standard ALC params ret.tireStiffnessFactor = 1.0 ret.steerControlType = CarParams.SteerControlType.torque - ret.lateralParams = CarParams.LateralPIDTuning() ret.minSteerSpeed = 0. ret.wheelSpeedFactor = 1.0