stash data_structures.py

pull/33208/head
Shane Smiskol 11 months ago
parent 965f692f2e
commit bc242022d0
  1. 183
      selfdrive/car/data_structures.py

@ -1,7 +1,8 @@
from dataclasses import dataclass, fields, is_dataclass from dataclasses import dataclass, fields, field, is_dataclass
from enum import Enum, StrEnum as _StrEnum, auto from enum import Enum, StrEnum as _StrEnum, auto
# from typing import Type, TypeVar # from typing import Type, TypeVar
from typing import Type, TypeVar, TYPE_CHECKING, Any, get_type_hints, get_origin from typing import TypeVar, TYPE_CHECKING, Any, get_type_hints, get_origin
from selfdrive.car.data_test_kinda_works_chatgpt import auto_field, apply_auto_fields
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import DataclassInstance from _typeshed import DataclassInstance
@ -10,6 +11,8 @@ if TYPE_CHECKING:
# #
# T = TypeVar('T', bound='Struct') # T = TypeVar('T', bound='Struct')
_FIELDS = '__dataclass_fields__'
class StrEnum(_StrEnum): class StrEnum(_StrEnum):
@staticmethod @staticmethod
@ -30,35 +33,35 @@ class StrEnum(_StrEnum):
T = TypeVar('T', bound='DataclassInstance') T = TypeVar('T', bound='DataclassInstance')
class Struct: # class Struct:
@classmethod # @classmethod
def new_message(cls: Type[T], **kwargs: Any) -> T: # def new_message(cls: type[T], **kwargs: Any) -> T:
if not is_dataclass(cls): # if not is_dataclass(cls):
raise TypeError(f"{cls.__name__} is not a dataclass") # raise TypeError(f"{cls.__name__} is not a dataclass")
#
init_values = {} # init_values = {}
type_hints = get_type_hints(cls) # type_hints = get_type_hints(cls)
print(type_hints) # print(type_hints)
for f in fields(cls): # for f in fields(cls):
field_type = type_hints[f.name] # field_type = type_hints[f.name]
print(f.name, f.type, field_type) # print(f.name, f.type, field_type)
print(issubclass(field_type, Enum)) # print(issubclass(field_type, Enum))
if issubclass(field_type, Enum): # if issubclass(field_type, Enum):
init_values[f.name] = kwargs.get(f.name, list(field_type)[0]) # init_values[f.name] = kwargs.get(f.name, list(field_type)[0])
# TODO: fix this # # TODO: fix this
# assert issubclass(init_values[f.name], type(field_type)), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}" # # assert issubclass(init_values[f.name], type(field_type)), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}"
else: # else:
# FIXME: typing check hack since mypy doesn't catch anything # # FIXME: typing check hack since mypy doesn't catch anything
init_values[f.name] = kwargs.get(f.name, field_type()) # init_values[f.name] = kwargs.get(f.name, field_type())
print('field_type', field_type, f.type) # print('field_type', field_type, f.type)
# TODO: this is so bad # # TODO: this is so bad
assert isinstance(init_values[f.name], get_origin(f.type) or f.type), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}" # assert isinstance(init_values[f.name], get_origin(f.type) or f.type), f"Expected {field_type} for {f.name}, got {type(init_values[f.name])}"
#
return cls(**init_values) # return cls(**init_values)
@dataclass @dataclass
class RadarData(Struct): class RadarData:
errors: list['Error'] errors: list['Error']
points: list['RadarPoint'] points: list['RadarPoint']
@ -68,7 +71,7 @@ class RadarData(Struct):
wrongConfig = auto() wrongConfig = auto()
@dataclass @dataclass
class RadarPoint(Struct): class RadarPoint:
trackId: int # no trackId reuse trackId: int # no trackId reuse
# these 3 are the minimum required # these 3 are the minimum required
@ -85,31 +88,33 @@ class RadarData(Struct):
@dataclass @dataclass
class CarParams(Struct): @apply_auto_fields
carName: str class CarParams:
carFingerprint: str carName: str = auto_field()
fuzzyFingerprint: bool carFingerprint: str = auto_field()
fuzzyFingerprint: bool = auto_field()
notCar: bool # flag for non-car robotics platforms notCar: bool = auto_field() # flag for non-car robotics platforms
carFw: list['CarFw'] carFw: list['CarParams.CarFw'] = auto_field()
class SteerControlType(StrEnum): class SteerControlType(StrEnum):
torque = auto() torque = auto()
angle = auto() angle = auto()
@dataclass @dataclass
class CarFw(Struct): @apply_auto_fields
ecu: 'CarParams.Ecu' class CarFw:
fwVersion: bytes ecu: 'CarParams.Ecu' = field(default_factory=lambda: CarParams.Ecu.unknown)
address: int fwVersion: bytes = auto_field()
subAddress: int address: int = auto_field()
responseAddress: int subAddress: int = auto_field()
request: list[bytes] responseAddress: int = auto_field()
brand: str request: list[bytes] = auto_field()
bus: int brand: str = auto_field()
logging: bool bus: int = auto_field()
obdMultiplexing: bool logging: bool = auto_field()
obdMultiplexing: bool = auto_field()
class Ecu(StrEnum): class Ecu(StrEnum):
eps = auto() eps = auto()
@ -144,10 +149,88 @@ class CarParams(Struct):
debug = auto() debug = auto()
# CP: CarParams = CarParams.new_message(carName='toyota', fuzzyFingerprint=123) # # CP: CarParams = CarParams.new_message(carName='toyota', fuzzyFingerprint=123)
# CP: CarParams = CarParams(carName='toyota', fuzzyFingerprint=123) # # CP: CarParams = CarParams(carName='toyota', fuzzyFingerprint=123)
#
# # import ast
#
#
# # test = ast.literal_eval('CarParams.CarFw')
#
# def mywrapper(cls):
#
# cls_annotations = cls.__dict__.get('__annotations__', {})
# fields = {}
# for name, _type in cls_annotations.items():
# f = field(default_factory=_type)
# setattr(cls, name, f)
# fields[name] = f
#
# setattr(cls, _FIELDS, fields)
#
# print('cls_annotations', cls_annotations)
# # cls.hi = 123
#
# return cls
#
#
# # def mywrapper2(cls):
# # class Test:
# # pass
# # return Test
#
#
# @dataclass
# class CarControl1:
# enabled: bool
#
# @dataclass
# class CarControl2:
# enabled: bool = field(default_factory=bool)
#
#
# # @mywrapper2
# @dataclass()
# @mywrapper
# class CarControl:
# # enabled: bool = field(default_factory=bool)
# enabled: bool = None
# pts: list[int] = None
# logMonoTime: int = None
#
#
# CC = CarControl()
import ast @dataclass
@apply_auto_fields
class CarControl:
enabled: bool = auto_field()
pts: list[int] = auto_field()
logMonoTime: int = auto_field()
test: None = auto_field()
# test = ast.literal_eval('CarParams.CarFw') # testing: if origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool):
@dataclass
@apply_auto_fields
class Test997:
a: int = auto_field()
b: float = auto_field()
c: str = auto_field()
d: bytes = auto_field()
e: list[int] = auto_field()
f: tuple[int] = auto_field()
g: set[int] = auto_field()
h: dict[str, int] = auto_field()
i: bool = auto_field()
ecu: CarParams.Ecu = auto_field()
carFw: CarParams.CarFw = auto_field()
# Out[4]: Test997(a=0, b=0.0, c='', d=b'', e=[], f=(), g=set(), h={}, i=False)
CarControl()
CP = CarParams()
CP.carFw = [CarParams.CarFw()]
CP.carFw = [CarParams.Ecu.eps]

Loading…
Cancel
Save