fuzzy_generation.py: improve performance (#32591)

* improve performance

* remove DEPRECATED

* formatting

* catch kjException

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
pull/31315/head
Dean Lee 11 months ago committed by GitHub
parent ae375091db
commit 3a43f5d784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 88
      selfdrive/test/fuzzy_generation.py

@ -2,6 +2,7 @@ import capnp
import hypothesis.strategies as st import hypothesis.strategies as st
from typing import Any from typing import Any
from collections.abc import Callable from collections.abc import Callable
from functools import cache
from cereal import log from cereal import log
@ -11,67 +12,62 @@ DrawType = Callable[[st.SearchStrategy], Any]
class FuzzyGenerator: class FuzzyGenerator:
def __init__(self, draw: DrawType, real_floats: bool): def __init__(self, draw: DrawType, real_floats: bool):
self.draw = draw self.draw = draw
self.real_floats = real_floats self.native_type_map = FuzzyGenerator._get_native_type_map(real_floats)
def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]: def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]:
def floats(**kwargs) -> st.SearchStrategy[float]: value_func = self.native_type_map.get(field)
allow_nan = not self.real_floats if value_func:
allow_infinity = not self.real_floats return value_func
return st.floats(**kwargs, allow_nan=allow_nan, allow_infinity=allow_infinity)
if field == 'bool':
return st.booleans()
elif field == 'int8':
return st.integers(min_value=-2**7, max_value=2**7-1)
elif field == 'int16':
return st.integers(min_value=-2**15, max_value=2**15-1)
elif field == 'int32':
return st.integers(min_value=-2**31, max_value=2**31-1)
elif field == 'int64':
return st.integers(min_value=-2**63, max_value=2**63-1)
elif field == 'uint8':
return st.integers(min_value=0, max_value=2**8-1)
elif field == 'uint16':
return st.integers(min_value=0, max_value=2**16-1)
elif field == 'uint32':
return st.integers(min_value=0, max_value=2**32-1)
elif field == 'uint64':
return st.integers(min_value=0, max_value=2**64-1)
elif field == 'float32':
return floats(width=32)
elif field == 'float64':
return floats(width=64)
elif field == 'text':
return st.text(max_size=1000)
elif field == 'data':
return st.binary(max_size=1000)
elif field == 'anyPointer':
return st.text()
else: else:
raise NotImplementedError(f'Invalid type : {field}') raise NotImplementedError(f'Invalid type: {field}')
def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy: def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy:
def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy: def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy:
if field_type.which() == 'struct': type_which = field_type.which()
if type_which == 'struct':
return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema) return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema)
elif field_type.which() == 'list': elif type_which == 'list':
return st.lists(rec(field_type.list.elementType)) return st.lists(rec(field_type.list.elementType))
elif field_type.which() == 'enum': elif type_which == 'enum':
schema = field.schema.elementType if base_type == 'list' else field.schema schema = field.schema.elementType if base_type == 'list' else field.schema
return st.sampled_from(list(schema.enumerants.keys())) return st.sampled_from(list(schema.enumerants.keys()))
else: else:
return self.generate_native_type(field_type.which()) return self.generate_native_type(type_which)
if 'slot' in field.proto.to_dict(): try:
base_type = field.proto.slot.type.which() if hasattr(field.proto, 'slot'):
return rec(field.proto.slot.type) slot_type = field.proto.slot.type
else: base_type = slot_type.which()
return rec(slot_type)
else:
return self.generate_struct(field.schema)
except capnp.lib.capnp.KjException:
return self.generate_struct(field.schema) return self.generate_struct(field.schema)
def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]: def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]:
full_fill: list[str] = list(schema.non_union_fields) single_fill: tuple[str, ...] = (event,) if event else (self.draw(st.sampled_from(schema.union_fields)),) if schema.union_fields else ()
single_fill: list[str] = [event] if event else [self.draw(st.sampled_from(schema.union_fields))] if schema.union_fields else [] fields_to_generate = schema.non_union_fields + single_fill
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in full_fill + single_fill}) return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in fields_to_generate if not field.endswith('DEPRECATED')})
@staticmethod
@cache
def _get_native_type_map(real_floats: bool) -> dict[str, st.SearchStrategy]:
return {
'bool': st.booleans(),
'int8': st.integers(min_value=-2**7, max_value=2**7-1),
'int16': st.integers(min_value=-2**15, max_value=2**15-1),
'int32': st.integers(min_value=-2**31, max_value=2**31-1),
'int64': st.integers(min_value=-2**63, max_value=2**63-1),
'uint8': st.integers(min_value=0, max_value=2**8-1),
'uint16': st.integers(min_value=0, max_value=2**16-1),
'uint32': st.integers(min_value=0, max_value=2**32-1),
'uint64': st.integers(min_value=0, max_value=2**64-1),
'float32': st.floats(width=32, allow_nan=not real_floats, allow_infinity=not real_floats),
'float64': st.floats(width=64, allow_nan=not real_floats, allow_infinity=not real_floats),
'text': st.text(max_size=1000),
'data': st.binary(max_size=1000),
'anyPointer': st.text(), # Note: No need to define a separate function for anyPointer
}
@classmethod @classmethod
def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]: def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]:

Loading…
Cancel
Save