diff --git a/tools/sim/bridge/metadrive.py b/tools/sim/bridge/metadrive.py index 5814364be1..ec9c22b9be 100644 --- a/tools/sim/bridge/metadrive.py +++ b/tools/sim/bridge/metadrive.py @@ -10,6 +10,7 @@ from openpilot.tools.sim.lib.camerad import W, H def apply_metadrive_patches(): from metadrive.engine.core.engine_core import EngineCore from metadrive.engine.core.image_buffer import ImageBuffer + from metadrive.envs.metadrive_env import MetaDriveEnv from metadrive.obs.image_obs import ImageObservation # By default, metadrive won't try to use cuda images unless it's used as a sensor for vehicles, so patch that in @@ -29,6 +30,11 @@ def apply_metadrive_patches(): ImageObservation.observe = observe_patched + def arrive_destination_patch(self, vehicle): + return False + + MetaDriveEnv._is_arrive_destination = arrive_destination_patch + class MetaDriveWorld(World): def __init__(self, env, ticks_per_frame: float, dual_camera = False): @@ -91,6 +97,24 @@ class MetaDriveWorld(World): pass +def straight_block(length): + return { + "id": "S", + "pre_block_socket_index": 0, + "length": length + } + +def curve_block(length, angle=45, direction=0): + return { + "id": "C", + "pre_block_socket_index": 0, + "length": length, + "radius": length, + "angle": angle, + "dir": direction + } + + class MetaDriveBridge(SimulatorBridge): TICKS_PER_FRAME = 2 @@ -105,6 +129,7 @@ class MetaDriveBridge(SimulatorBridge): print("----------------------------------------------------------") from metadrive.component.sensors.rgb_camera import RGBCamera from metadrive.component.sensors.base_camera import _cuda_enable + from metadrive.component.map.pg_map import MapGenerateMethod from metadrive.envs.metadrive_env import MetaDriveEnv from panda3d.core import Vec3 @@ -151,6 +176,20 @@ class MetaDriveBridge(SimulatorBridge): on_continuous_line_done=False, crash_vehicle_done=False, crash_object_done=False, + map_config=dict( + type=MapGenerateMethod.PG_MAP_FILE, + config=[ + None, + straight_block(120), + curve_block(120, 90), + straight_block(120), + curve_block(120, 90), + straight_block(120), + curve_block(120, 90), + straight_block(120), + curve_block(120, 90), + ] + ) ) )