fuzzy_generation.py: improve performance (#32591)

* improve performance

* remove DEPRECATED

* formatting

* catch kjException

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
old-commit-hash: 3a43f5d784
097
Dean Lee 11 months ago committed by GitHub
parent 9778515344
commit 6f9e28d8f2
  1. 86
      selfdrive/test/fuzzy_generation.py

@ -2,6 +2,7 @@ import capnp
import hypothesis.strategies as st import hypothesis.strategies as st
from typing import Any from typing import Any
from collections.abc import Callable from collections.abc import Callable
from functools import cache
from cereal import log from cereal import log
@ -11,67 +12,62 @@ DrawType = Callable[[st.SearchStrategy], Any]
class FuzzyGenerator: class FuzzyGenerator:
def __init__(self, draw: DrawType, real_floats: bool): def __init__(self, draw: DrawType, real_floats: bool):
self.draw = draw self.draw = draw
self.real_floats = real_floats self.native_type_map = FuzzyGenerator._get_native_type_map(real_floats)
def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]: def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]:
def floats(**kwargs) -> st.SearchStrategy[float]: value_func = self.native_type_map.get(field)
allow_nan = not self.real_floats if value_func:
allow_infinity = not self.real_floats return value_func
return st.floats(**kwargs, allow_nan=allow_nan, allow_infinity=allow_infinity)
if field == 'bool':
return st.booleans()
elif field == 'int8':
return st.integers(min_value=-2**7, max_value=2**7-1)
elif field == 'int16':
return st.integers(min_value=-2**15, max_value=2**15-1)
elif field == 'int32':
return st.integers(min_value=-2**31, max_value=2**31-1)
elif field == 'int64':
return st.integers(min_value=-2**63, max_value=2**63-1)
elif field == 'uint8':
return st.integers(min_value=0, max_value=2**8-1)
elif field == 'uint16':
return st.integers(min_value=0, max_value=2**16-1)
elif field == 'uint32':
return st.integers(min_value=0, max_value=2**32-1)
elif field == 'uint64':
return st.integers(min_value=0, max_value=2**64-1)
elif field == 'float32':
return floats(width=32)
elif field == 'float64':
return floats(width=64)
elif field == 'text':
return st.text(max_size=1000)
elif field == 'data':
return st.binary(max_size=1000)
elif field == 'anyPointer':
return st.text()
else: else:
raise NotImplementedError(f'Invalid type : {field}') raise NotImplementedError(f'Invalid type: {field}')
def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy: def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy:
def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy: def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy:
if field_type.which() == 'struct': type_which = field_type.which()
if type_which == 'struct':
return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema) return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema)
elif field_type.which() == 'list': elif type_which == 'list':
return st.lists(rec(field_type.list.elementType)) return st.lists(rec(field_type.list.elementType))
elif field_type.which() == 'enum': elif type_which == 'enum':
schema = field.schema.elementType if base_type == 'list' else field.schema schema = field.schema.elementType if base_type == 'list' else field.schema
return st.sampled_from(list(schema.enumerants.keys())) return st.sampled_from(list(schema.enumerants.keys()))
else: else:
return self.generate_native_type(field_type.which()) return self.generate_native_type(type_which)
if 'slot' in field.proto.to_dict(): try:
base_type = field.proto.slot.type.which() if hasattr(field.proto, 'slot'):
return rec(field.proto.slot.type) slot_type = field.proto.slot.type
base_type = slot_type.which()
return rec(slot_type)
else: else:
return self.generate_struct(field.schema) return self.generate_struct(field.schema)
except capnp.lib.capnp.KjException:
return self.generate_struct(field.schema)
def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]: def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]:
full_fill: list[str] = list(schema.non_union_fields) single_fill: tuple[str, ...] = (event,) if event else (self.draw(st.sampled_from(schema.union_fields)),) if schema.union_fields else ()
single_fill: list[str] = [event] if event else [self.draw(st.sampled_from(schema.union_fields))] if schema.union_fields else [] fields_to_generate = schema.non_union_fields + single_fill
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in full_fill + single_fill}) return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in fields_to_generate if not field.endswith('DEPRECATED')})
@staticmethod
@cache
def _get_native_type_map(real_floats: bool) -> dict[str, st.SearchStrategy]:
return {
'bool': st.booleans(),
'int8': st.integers(min_value=-2**7, max_value=2**7-1),
'int16': st.integers(min_value=-2**15, max_value=2**15-1),
'int32': st.integers(min_value=-2**31, max_value=2**31-1),
'int64': st.integers(min_value=-2**63, max_value=2**63-1),
'uint8': st.integers(min_value=0, max_value=2**8-1),
'uint16': st.integers(min_value=0, max_value=2**16-1),
'uint32': st.integers(min_value=0, max_value=2**32-1),
'uint64': st.integers(min_value=0, max_value=2**64-1),
'float32': st.floats(width=32, allow_nan=not real_floats, allow_infinity=not real_floats),
'float64': st.floats(width=64, allow_nan=not real_floats, allow_infinity=not real_floats),
'text': st.text(max_size=1000),
'data': st.binary(max_size=1000),
'anyPointer': st.text(), # Note: No need to define a separate function for anyPointer
}
@classmethod @classmethod
def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]: def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]:

Loading…
Cancel
Save