fuzzy_generation.py: improve performance (#32591)

* improve performance

* remove DEPRECATED

* formatting

* catch kjException

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
pull/31315/head
Dean Lee 11 months ago committed by GitHub
parent ae375091db
commit 3a43f5d784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 84
      selfdrive/test/fuzzy_generation.py

@ -2,6 +2,7 @@ import capnp
import hypothesis.strategies as st
from typing import Any
from collections.abc import Callable
from functools import cache
from cereal import log
@ -11,67 +12,62 @@ DrawType = Callable[[st.SearchStrategy], Any]
class FuzzyGenerator:
def __init__(self, draw: DrawType, real_floats: bool):
self.draw = draw
self.real_floats = real_floats
self.native_type_map = FuzzyGenerator._get_native_type_map(real_floats)
def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]:
def floats(**kwargs) -> st.SearchStrategy[float]:
allow_nan = not self.real_floats
allow_infinity = not self.real_floats
return st.floats(**kwargs, allow_nan=allow_nan, allow_infinity=allow_infinity)
if field == 'bool':
return st.booleans()
elif field == 'int8':
return st.integers(min_value=-2**7, max_value=2**7-1)
elif field == 'int16':
return st.integers(min_value=-2**15, max_value=2**15-1)
elif field == 'int32':
return st.integers(min_value=-2**31, max_value=2**31-1)
elif field == 'int64':
return st.integers(min_value=-2**63, max_value=2**63-1)
elif field == 'uint8':
return st.integers(min_value=0, max_value=2**8-1)
elif field == 'uint16':
return st.integers(min_value=0, max_value=2**16-1)
elif field == 'uint32':
return st.integers(min_value=0, max_value=2**32-1)
elif field == 'uint64':
return st.integers(min_value=0, max_value=2**64-1)
elif field == 'float32':
return floats(width=32)
elif field == 'float64':
return floats(width=64)
elif field == 'text':
return st.text(max_size=1000)
elif field == 'data':
return st.binary(max_size=1000)
elif field == 'anyPointer':
return st.text()
value_func = self.native_type_map.get(field)
if value_func:
return value_func
else:
raise NotImplementedError(f'Invalid type: {field}')
def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy:
def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy:
if field_type.which() == 'struct':
type_which = field_type.which()
if type_which == 'struct':
return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema)
elif field_type.which() == 'list':
elif type_which == 'list':
return st.lists(rec(field_type.list.elementType))
elif field_type.which() == 'enum':
elif type_which == 'enum':
schema = field.schema.elementType if base_type == 'list' else field.schema
return st.sampled_from(list(schema.enumerants.keys()))
else:
return self.generate_native_type(field_type.which())
return self.generate_native_type(type_which)
if 'slot' in field.proto.to_dict():
base_type = field.proto.slot.type.which()
return rec(field.proto.slot.type)
try:
if hasattr(field.proto, 'slot'):
slot_type = field.proto.slot.type
base_type = slot_type.which()
return rec(slot_type)
else:
return self.generate_struct(field.schema)
except capnp.lib.capnp.KjException:
return self.generate_struct(field.schema)
def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]:
full_fill: list[str] = list(schema.non_union_fields)
single_fill: list[str] = [event] if event else [self.draw(st.sampled_from(schema.union_fields))] if schema.union_fields else []
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in full_fill + single_fill})
single_fill: tuple[str, ...] = (event,) if event else (self.draw(st.sampled_from(schema.union_fields)),) if schema.union_fields else ()
fields_to_generate = schema.non_union_fields + single_fill
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in fields_to_generate if not field.endswith('DEPRECATED')})
@staticmethod
@cache
def _get_native_type_map(real_floats: bool) -> dict[str, st.SearchStrategy]:
return {
'bool': st.booleans(),
'int8': st.integers(min_value=-2**7, max_value=2**7-1),
'int16': st.integers(min_value=-2**15, max_value=2**15-1),
'int32': st.integers(min_value=-2**31, max_value=2**31-1),
'int64': st.integers(min_value=-2**63, max_value=2**63-1),
'uint8': st.integers(min_value=0, max_value=2**8-1),
'uint16': st.integers(min_value=0, max_value=2**16-1),
'uint32': st.integers(min_value=0, max_value=2**32-1),
'uint64': st.integers(min_value=0, max_value=2**64-1),
'float32': st.floats(width=32, allow_nan=not real_floats, allow_infinity=not real_floats),
'float64': st.floats(width=64, allow_nan=not real_floats, allow_infinity=not real_floats),
'text': st.text(max_size=1000),
'data': st.binary(max_size=1000),
'anyPointer': st.text(), # Note: No need to define a separate function for anyPointer
}
@classmethod
def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]:

Loading…
Cancel
Save