|
|
@ -2,6 +2,7 @@ import capnp |
|
|
|
import hypothesis.strategies as st |
|
|
|
import hypothesis.strategies as st |
|
|
|
from typing import Any |
|
|
|
from typing import Any |
|
|
|
from collections.abc import Callable |
|
|
|
from collections.abc import Callable |
|
|
|
|
|
|
|
from functools import cache |
|
|
|
|
|
|
|
|
|
|
|
from cereal import log |
|
|
|
from cereal import log |
|
|
|
|
|
|
|
|
|
|
@ -11,67 +12,62 @@ DrawType = Callable[[st.SearchStrategy], Any] |
|
|
|
class FuzzyGenerator: |
|
|
|
class FuzzyGenerator: |
|
|
|
def __init__(self, draw: DrawType, real_floats: bool): |
|
|
|
def __init__(self, draw: DrawType, real_floats: bool): |
|
|
|
self.draw = draw |
|
|
|
self.draw = draw |
|
|
|
self.real_floats = real_floats |
|
|
|
self.native_type_map = FuzzyGenerator._get_native_type_map(real_floats) |
|
|
|
|
|
|
|
|
|
|
|
def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]: |
|
|
|
def generate_native_type(self, field: str) -> st.SearchStrategy[bool | int | float | str | bytes]: |
|
|
|
def floats(**kwargs) -> st.SearchStrategy[float]: |
|
|
|
value_func = self.native_type_map.get(field) |
|
|
|
allow_nan = not self.real_floats |
|
|
|
if value_func: |
|
|
|
allow_infinity = not self.real_floats |
|
|
|
return value_func |
|
|
|
return st.floats(**kwargs, allow_nan=allow_nan, allow_infinity=allow_infinity) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if field == 'bool': |
|
|
|
|
|
|
|
return st.booleans() |
|
|
|
|
|
|
|
elif field == 'int8': |
|
|
|
|
|
|
|
return st.integers(min_value=-2**7, max_value=2**7-1) |
|
|
|
|
|
|
|
elif field == 'int16': |
|
|
|
|
|
|
|
return st.integers(min_value=-2**15, max_value=2**15-1) |
|
|
|
|
|
|
|
elif field == 'int32': |
|
|
|
|
|
|
|
return st.integers(min_value=-2**31, max_value=2**31-1) |
|
|
|
|
|
|
|
elif field == 'int64': |
|
|
|
|
|
|
|
return st.integers(min_value=-2**63, max_value=2**63-1) |
|
|
|
|
|
|
|
elif field == 'uint8': |
|
|
|
|
|
|
|
return st.integers(min_value=0, max_value=2**8-1) |
|
|
|
|
|
|
|
elif field == 'uint16': |
|
|
|
|
|
|
|
return st.integers(min_value=0, max_value=2**16-1) |
|
|
|
|
|
|
|
elif field == 'uint32': |
|
|
|
|
|
|
|
return st.integers(min_value=0, max_value=2**32-1) |
|
|
|
|
|
|
|
elif field == 'uint64': |
|
|
|
|
|
|
|
return st.integers(min_value=0, max_value=2**64-1) |
|
|
|
|
|
|
|
elif field == 'float32': |
|
|
|
|
|
|
|
return floats(width=32) |
|
|
|
|
|
|
|
elif field == 'float64': |
|
|
|
|
|
|
|
return floats(width=64) |
|
|
|
|
|
|
|
elif field == 'text': |
|
|
|
|
|
|
|
return st.text(max_size=1000) |
|
|
|
|
|
|
|
elif field == 'data': |
|
|
|
|
|
|
|
return st.binary(max_size=1000) |
|
|
|
|
|
|
|
elif field == 'anyPointer': |
|
|
|
|
|
|
|
return st.text() |
|
|
|
|
|
|
|
else: |
|
|
|
else: |
|
|
|
raise NotImplementedError(f'Invalid type : {field}') |
|
|
|
raise NotImplementedError(f'Invalid type: {field}') |
|
|
|
|
|
|
|
|
|
|
|
def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy: |
|
|
|
def generate_field(self, field: capnp.lib.capnp._StructSchemaField) -> st.SearchStrategy: |
|
|
|
def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy: |
|
|
|
def rec(field_type: capnp.lib.capnp._DynamicStructReader) -> st.SearchStrategy: |
|
|
|
if field_type.which() == 'struct': |
|
|
|
type_which = field_type.which() |
|
|
|
|
|
|
|
if type_which == 'struct': |
|
|
|
return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema) |
|
|
|
return self.generate_struct(field.schema.elementType if base_type == 'list' else field.schema) |
|
|
|
elif field_type.which() == 'list': |
|
|
|
elif type_which == 'list': |
|
|
|
return st.lists(rec(field_type.list.elementType)) |
|
|
|
return st.lists(rec(field_type.list.elementType)) |
|
|
|
elif field_type.which() == 'enum': |
|
|
|
elif type_which == 'enum': |
|
|
|
schema = field.schema.elementType if base_type == 'list' else field.schema |
|
|
|
schema = field.schema.elementType if base_type == 'list' else field.schema |
|
|
|
return st.sampled_from(list(schema.enumerants.keys())) |
|
|
|
return st.sampled_from(list(schema.enumerants.keys())) |
|
|
|
else: |
|
|
|
else: |
|
|
|
return self.generate_native_type(field_type.which()) |
|
|
|
return self.generate_native_type(type_which) |
|
|
|
|
|
|
|
|
|
|
|
if 'slot' in field.proto.to_dict(): |
|
|
|
try: |
|
|
|
base_type = field.proto.slot.type.which() |
|
|
|
if hasattr(field.proto, 'slot'): |
|
|
|
return rec(field.proto.slot.type) |
|
|
|
slot_type = field.proto.slot.type |
|
|
|
else: |
|
|
|
base_type = slot_type.which() |
|
|
|
|
|
|
|
return rec(slot_type) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return self.generate_struct(field.schema) |
|
|
|
|
|
|
|
except capnp.lib.capnp.KjException: |
|
|
|
return self.generate_struct(field.schema) |
|
|
|
return self.generate_struct(field.schema) |
|
|
|
|
|
|
|
|
|
|
|
def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]: |
|
|
|
def generate_struct(self, schema: capnp.lib.capnp._StructSchema, event: str = None) -> st.SearchStrategy[dict[str, Any]]: |
|
|
|
full_fill: list[str] = list(schema.non_union_fields) |
|
|
|
single_fill: tuple[str, ...] = (event,) if event else (self.draw(st.sampled_from(schema.union_fields)),) if schema.union_fields else () |
|
|
|
single_fill: list[str] = [event] if event else [self.draw(st.sampled_from(schema.union_fields))] if schema.union_fields else [] |
|
|
|
fields_to_generate = schema.non_union_fields + single_fill |
|
|
|
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in full_fill + single_fill}) |
|
|
|
return st.fixed_dictionaries({field: self.generate_field(schema.fields[field]) for field in fields_to_generate if not field.endswith('DEPRECATED')}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
|
|
@cache |
|
|
|
|
|
|
|
def _get_native_type_map(real_floats: bool) -> dict[str, st.SearchStrategy]: |
|
|
|
|
|
|
|
return { |
|
|
|
|
|
|
|
'bool': st.booleans(), |
|
|
|
|
|
|
|
'int8': st.integers(min_value=-2**7, max_value=2**7-1), |
|
|
|
|
|
|
|
'int16': st.integers(min_value=-2**15, max_value=2**15-1), |
|
|
|
|
|
|
|
'int32': st.integers(min_value=-2**31, max_value=2**31-1), |
|
|
|
|
|
|
|
'int64': st.integers(min_value=-2**63, max_value=2**63-1), |
|
|
|
|
|
|
|
'uint8': st.integers(min_value=0, max_value=2**8-1), |
|
|
|
|
|
|
|
'uint16': st.integers(min_value=0, max_value=2**16-1), |
|
|
|
|
|
|
|
'uint32': st.integers(min_value=0, max_value=2**32-1), |
|
|
|
|
|
|
|
'uint64': st.integers(min_value=0, max_value=2**64-1), |
|
|
|
|
|
|
|
'float32': st.floats(width=32, allow_nan=not real_floats, allow_infinity=not real_floats), |
|
|
|
|
|
|
|
'float64': st.floats(width=64, allow_nan=not real_floats, allow_infinity=not real_floats), |
|
|
|
|
|
|
|
'text': st.text(max_size=1000), |
|
|
|
|
|
|
|
'data': st.binary(max_size=1000), |
|
|
|
|
|
|
|
'anyPointer': st.text(), # Note: No need to define a separate function for anyPointer |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
@classmethod |
|
|
|
def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]: |
|
|
|
def get_random_msg(cls, draw: DrawType, struct: capnp.lib.capnp._StructModule, real_floats: bool = False) -> dict[str, Any]: |
|
|
|