Cleanup calibration code (#25119)

* First attempt

* worksish

* tests pass

* cleanup

* get rid of garbahe

* fix name

* Still used in xx

* add debug functions

* used

* Revert "used"

This reverts commit 276e8ebab06d2d4f0e9927ba32b7d8aca2bf88c3.

* Update ref
pull/25321/head
HaraldSchafer 3 years ago committed by GitHub
parent ee6570da4a
commit 772b748689
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 15
      common/transformations/camera.py
  2. 147
      common/transformations/model.py
  3. 6
      selfdrive/locationd/calibrationd.py
  4. 4
      selfdrive/modeld/SConscript
  5. 56
      selfdrive/modeld/modeld.cc
  6. 4
      selfdrive/test/process_replay/model_replay.py
  7. 2
      selfdrive/test/process_replay/ref_commit

@ -77,6 +77,7 @@ def get_view_frame_from_road_frame(roll, pitch, yaw, height):
return np.hstack((view_from_road, [[0], [height], [0]])) return np.hstack((view_from_road, [[0], [height], [0]]))
# aka 'extrinsic_matrix' # aka 'extrinsic_matrix'
def get_view_frame_from_calib_frame(roll, pitch, yaw, height): def get_view_frame_from_calib_frame(roll, pitch, yaw, height):
device_from_calib= orient.rot_from_euler([roll, pitch, yaw]) device_from_calib= orient.rot_from_euler([roll, pitch, yaw])
@ -94,12 +95,6 @@ def vp_from_ke(m):
return (m[0, 0]/m[2, 0], m[1, 0]/m[2, 0]) return (m[0, 0]/m[2, 0], m[1, 0]/m[2, 0])
def vp_from_rpy(rpy, intrinsics=fcam_intrinsics):
e = get_view_frame_from_road_frame(rpy[0], rpy[1], rpy[2], 1.22)
ke = np.dot(intrinsics, e)
return vp_from_ke(ke)
def roll_from_ke(m): def roll_from_ke(m):
# note: different from calibration.h/RollAnglefromKE: i think that one's just wrong # note: different from calibration.h/RollAnglefromKE: i think that one's just wrong
return np.arctan2(-(m[1, 0] - m[1, 1] * m[2, 0] / m[2, 1]), return np.arctan2(-(m[1, 0] - m[1, 1] * m[2, 0] / m[2, 1]),
@ -163,11 +158,3 @@ def img_from_device(pt_device):
pt_img = pt_view/pt_view[:, 2:3] pt_img = pt_view/pt_view[:, 2:3]
return pt_img.reshape(input_shape)[:, :2] return pt_img.reshape(input_shape)[:, :2]
def get_camera_frame_from_calib_frame(camera_frame_from_road_frame, intrinsics=fcam_intrinsics):
camera_frame_from_ground = camera_frame_from_road_frame[:, (0, 1, 3)]
calib_frame_from_ground = np.dot(intrinsics,
get_view_frame_from_road_frame(0, 0, 0, 1.22))[:, (0, 1, 3)]
ground_from_calib_frame = np.linalg.inv(calib_frame_from_ground)
camera_frame_from_calib_frame = np.dot(camera_frame_from_ground, ground_from_calib_frame)
return camera_frame_from_calib_frame

@ -1,10 +1,7 @@
import numpy as np import numpy as np
from common.transformations.camera import (FULL_FRAME_SIZE, from common.transformations.camera import (FULL_FRAME_SIZE,
FOCAL, get_view_frame_from_calib_frame)
get_view_frame_from_road_frame,
get_view_frame_from_calib_frame,
vp_from_ke)
# segnet # segnet
SEGNET_SIZE = (512, 384) SEGNET_SIZE = (512, 384)
@ -14,21 +11,6 @@ def get_segnet_frame_from_camera_frame(segnet_size=SEGNET_SIZE, full_frame_size=
[0.0, float(segnet_size[1]) / full_frame_size[1]]]) [0.0, float(segnet_size[1]) / full_frame_size[1]]])
segnet_frame_from_camera_frame = get_segnet_frame_from_camera_frame() # xx segnet_frame_from_camera_frame = get_segnet_frame_from_camera_frame() # xx
# model
MODEL_INPUT_SIZE = (320, 160)
MODEL_YUV_SIZE = (MODEL_INPUT_SIZE[0], MODEL_INPUT_SIZE[1] * 3 // 2)
MODEL_CX = MODEL_INPUT_SIZE[0] / 2.
MODEL_CY = 21.
model_fl = 728.0
model_height = 1.22
# canonical model transform
model_intrinsics = np.array([
[model_fl, 0.0, MODEL_CX],
[0.0, model_fl, MODEL_CY],
[0.0, 0.0, 1.0]])
# MED model # MED model
MEDMODEL_INPUT_SIZE = (512, 256) MEDMODEL_INPUT_SIZE = (512, 256)
@ -63,104 +45,73 @@ sbigmodel_intrinsics = np.array([
[0.0, sbigmodel_fl, 0.5 * (256 + MEDMODEL_CY)], [0.0, sbigmodel_fl, 0.5 * (256 + MEDMODEL_CY)],
[0.0, 0.0, 1.0]]) [0.0, 0.0, 1.0]])
model_frame_from_road_frame = np.dot(model_intrinsics,
get_view_frame_from_road_frame(0, 0, 0, model_height))
bigmodel_frame_from_road_frame = np.dot(bigmodel_intrinsics,
get_view_frame_from_road_frame(0, 0, 0, model_height))
bigmodel_frame_from_calib_frame = np.dot(bigmodel_intrinsics, bigmodel_frame_from_calib_frame = np.dot(bigmodel_intrinsics,
get_view_frame_from_calib_frame(0, 0, 0, 0)) get_view_frame_from_calib_frame(0, 0, 0, 0))
sbigmodel_frame_from_road_frame = np.dot(sbigmodel_intrinsics,
get_view_frame_from_road_frame(0, 0, 0, model_height))
sbigmodel_frame_from_calib_frame = np.dot(sbigmodel_intrinsics, sbigmodel_frame_from_calib_frame = np.dot(sbigmodel_intrinsics,
get_view_frame_from_calib_frame(0, 0, 0, 0)) get_view_frame_from_calib_frame(0, 0, 0, 0))
medmodel_frame_from_road_frame = np.dot(medmodel_intrinsics,
get_view_frame_from_road_frame(0, 0, 0, model_height))
medmodel_frame_from_calib_frame = np.dot(medmodel_intrinsics, medmodel_frame_from_calib_frame = np.dot(medmodel_intrinsics,
get_view_frame_from_calib_frame(0, 0, 0, 0)) get_view_frame_from_calib_frame(0, 0, 0, 0))
model_frame_from_bigmodel_frame = np.dot(model_intrinsics, np.linalg.inv(bigmodel_intrinsics))
medmodel_frame_from_bigmodel_frame = np.dot(medmodel_intrinsics, np.linalg.inv(bigmodel_intrinsics)) medmodel_frame_from_bigmodel_frame = np.dot(medmodel_intrinsics, np.linalg.inv(bigmodel_intrinsics))
# 'camera from model camera' ### This function mimics the update_calibration logic in modeld.cc
def get_model_height_transform(camera_frame_from_road_frame, height): ### Manually verified to give similar results to xx.uncommon.utils.transform_img
camera_frame_from_road_ground = np.dot(camera_frame_from_road_frame, np.array([ def get_warp_matrix(rpy_calib, wide_cam=False, big_model=False, tici=True):
[1, 0, 0], from common.transformations.orientation import rot_from_euler
[0, 1, 0], from common.transformations.camera import view_frame_from_device_frame, eon_fcam_intrinsics, tici_ecam_intrinsics, tici_fcam_intrinsics
[0, 0, 0],
[0, 0, 1],
]))
camera_frame_from_road_high = np.dot(camera_frame_from_road_frame, np.array([
[1, 0, 0],
[0, 1, 0],
[0, 0, height - model_height],
[0, 0, 1],
]))
road_high_from_camera_frame = np.linalg.inv(camera_frame_from_road_high)
high_camera_from_low_camera = np.dot(camera_frame_from_road_ground, road_high_from_camera_frame)
return high_camera_from_low_camera
# camera_frame_from_model_frame aka 'warp matrix'
# was: calibration.h/CalibrationTransform
def get_camera_frame_from_model_frame(camera_frame_from_road_frame, height=model_height, camera_fl=FOCAL):
vp = vp_from_ke(camera_frame_from_road_frame)
model_zoom = camera_fl / model_fl
model_camera_from_model_frame = np.array([
[model_zoom, 0.0, vp[0] - MODEL_CX * model_zoom],
[0.0, model_zoom, vp[1] - MODEL_CY * model_zoom],
[0.0, 0.0, 1.0],
])
# This function is super slow, so skip it if height is very close to canonical
# TODO: speed it up!
if abs(height - model_height) > 0.001:
camera_from_model_camera = get_model_height_transform(camera_frame_from_road_frame, height)
else:
camera_from_model_camera = np.eye(3)
return np.dot(camera_from_model_camera, model_camera_from_model_frame)
def get_camera_frame_from_medmodel_frame(camera_frame_from_road_frame):
camera_frame_from_ground = camera_frame_from_road_frame[:, (0, 1, 3)]
medmodel_frame_from_ground = medmodel_frame_from_road_frame[:, (0, 1, 3)]
ground_from_medmodel_frame = np.linalg.inv(medmodel_frame_from_ground) if tici and wide_cam:
camera_frame_from_medmodel_frame = np.dot(camera_frame_from_ground, ground_from_medmodel_frame) intrinsics = tici_ecam_intrinsics
elif tici:
intrinsics = tici_fcam_intrinsics
else:
intrinsics = eon_fcam_intrinsics
return camera_frame_from_medmodel_frame if big_model:
sbigmodel_from_calib = sbigmodel_frame_from_calib_frame[:, (0,1,2)]
calib_from_model = np.linalg.inv(sbigmodel_from_calib)
else:
medmodel_from_calib = medmodel_frame_from_calib_frame[:, (0,1,2)]
calib_from_model = np.linalg.inv(medmodel_from_calib)
device_from_calib = rot_from_euler(rpy_calib)
camera_from_calib = intrinsics.dot(view_frame_from_device_frame.dot(device_from_calib))
warp_matrix = camera_from_calib.dot(calib_from_model)
return warp_matrix
def get_camera_frame_from_bigmodel_frame(camera_frame_from_road_frame): ### This is old, just for debugging
camera_frame_from_ground = camera_frame_from_road_frame[:, (0, 1, 3)] def get_warp_matrix_old(rpy_calib, wide_cam=False, big_model=False, tici=True):
bigmodel_frame_from_ground = bigmodel_frame_from_road_frame[:, (0, 1, 3)] from common.transformations.orientation import rot_from_euler
from common.transformations.camera import view_frame_from_device_frame, eon_fcam_intrinsics, tici_ecam_intrinsics, tici_fcam_intrinsics
ground_from_bigmodel_frame = np.linalg.inv(bigmodel_frame_from_ground)
camera_frame_from_bigmodel_frame = np.dot(camera_frame_from_ground, ground_from_bigmodel_frame)
return camera_frame_from_bigmodel_frame def get_view_frame_from_road_frame(roll, pitch, yaw, height):
device_from_road = rot_from_euler([roll, pitch, yaw]).dot(np.diag([1, -1, -1]))
view_from_road = view_frame_from_device_frame.dot(device_from_road)
return np.hstack((view_from_road, [[0], [height], [0]]))
if tici and wide_cam:
intrinsics = tici_ecam_intrinsics
elif tici:
intrinsics = tici_fcam_intrinsics
else:
intrinsics = eon_fcam_intrinsics
def get_model_frame(snu_full, camera_frame_from_model_frame, size): model_height = 1.22
idxs = camera_frame_from_model_frame.dot(np.column_stack([np.tile(np.arange(size[0]), size[1]), if big_model:
np.tile(np.arange(size[1]), (size[0], 1)).T.flatten(), model_from_road = np.dot(sbigmodel_intrinsics,
np.ones(size[0] * size[1])]).T).T.astype(int) get_view_frame_from_road_frame(0, 0, 0, model_height))
calib_flat = snu_full[idxs[:, 1], idxs[:, 0]]
if len(snu_full.shape) == 3:
calib = calib_flat.reshape((size[1], size[0], 3))
elif len(snu_full.shape) == 2:
calib = calib_flat.reshape((size[1], size[0]))
else: else:
raise ValueError("shape of input img is weird") model_from_road = np.dot(medmodel_intrinsics,
return calib get_view_frame_from_road_frame(0, 0, 0, model_height))
ground_from_model = np.linalg.inv(model_from_road[:, (0, 1, 3)])
E = get_view_frame_from_road_frame(*rpy_calib, 1.22)
camera_frame_from_road_frame = intrinsics.dot(E)
camera_frame_from_ground = camera_frame_from_road_frame[:,(0,1,3)]
warp_matrix = camera_frame_from_ground .dot(ground_from_model)
return warp_matrix

@ -17,8 +17,6 @@ import cereal.messaging as messaging
from common.conversions import Conversions as CV from common.conversions import Conversions as CV
from common.params import Params, put_nonblocking from common.params import Params, put_nonblocking
from common.realtime import set_realtime_priority from common.realtime import set_realtime_priority
from common.transformations.model import model_height
from common.transformations.camera import get_view_frame_from_road_frame
from common.transformations.orientation import rot_from_euler, euler_from_rot from common.transformations.orientation import rot_from_euler, euler_from_rot
from system.swaglog import cloudlog from system.swaglog import cloudlog
@ -180,7 +178,6 @@ class Calibrator:
def get_msg(self) -> capnp.lib.capnp._DynamicStructBuilder: def get_msg(self) -> capnp.lib.capnp._DynamicStructBuilder:
smooth_rpy = self.get_smooth_rpy() smooth_rpy = self.get_smooth_rpy()
extrinsic_matrix = get_view_frame_from_road_frame(0, smooth_rpy[1], smooth_rpy[2], model_height)
msg = messaging.new_message('liveCalibration') msg = messaging.new_message('liveCalibration')
liveCalibration = msg.liveCalibration liveCalibration = msg.liveCalibration
@ -188,16 +185,13 @@ class Calibrator:
liveCalibration.validBlocks = self.valid_blocks liveCalibration.validBlocks = self.valid_blocks
liveCalibration.calStatus = self.cal_status liveCalibration.calStatus = self.cal_status
liveCalibration.calPerc = min(100 * (self.valid_blocks * BLOCK_SIZE + self.idx) // (INPUTS_NEEDED * BLOCK_SIZE), 100) liveCalibration.calPerc = min(100 * (self.valid_blocks * BLOCK_SIZE + self.idx) // (INPUTS_NEEDED * BLOCK_SIZE), 100)
liveCalibration.extrinsicMatrix = extrinsic_matrix.flatten().tolist()
liveCalibration.rpyCalib = smooth_rpy.tolist() liveCalibration.rpyCalib = smooth_rpy.tolist()
liveCalibration.rpyCalibSpread = self.calib_spread.tolist() liveCalibration.rpyCalibSpread = self.calib_spread.tolist()
if self.not_car: if self.not_car:
extrinsic_matrix = get_view_frame_from_road_frame(0, 0, 0, model_height)
liveCalibration.validBlocks = INPUTS_NEEDED liveCalibration.validBlocks = INPUTS_NEEDED
liveCalibration.calStatus = Calibration.CALIBRATED liveCalibration.calStatus = Calibration.CALIBRATED
liveCalibration.calPerc = 100. liveCalibration.calPerc = 100.
liveCalibration.extrinsicMatrix = extrinsic_matrix.flatten().tolist()
liveCalibration.rpyCalib = [0, 0, 0] liveCalibration.rpyCalib = [0, 0, 0]
liveCalibration.rpyCalibSpread = self.calib_spread.tolist() liveCalibration.rpyCalibSpread = self.calib_spread.tolist()

@ -1,6 +1,6 @@
import os import os
Import('env', 'arch', 'cereal', 'messaging', 'common', 'gpucommon', 'visionipc') Import('env', 'arch', 'cereal', 'messaging', 'common', 'gpucommon', 'visionipc', 'transformations')
lenv = env.Clone() lenv = env.Clone()
libs = [cereal, messaging, common, visionipc, gpucommon, libs = [cereal, messaging, common, visionipc, gpucommon,
@ -82,4 +82,4 @@ lenv.Program('_dmonitoringmodeld', [
lenv.Program('_modeld', [ lenv.Program('_modeld', [
"modeld.cc", "modeld.cc",
"models/driving.cc", "models/driving.cc",
]+common_model, LIBS=libs) ]+common_model, LIBS=libs + transformations)

@ -6,6 +6,8 @@
#include <eigen3/Eigen/Dense> #include <eigen3/Eigen/Dense>
#include "cereal/messaging/messaging.h" #include "cereal/messaging/messaging.h"
#include "common/transformations/orientation.hpp"
#include "cereal/visionipc/visionipc_client.h" #include "cereal/visionipc/visionipc_client.h"
#include "common/clutil.h" #include "common/clutil.h"
#include "common/params.h" #include "common/params.h"
@ -14,40 +16,43 @@
#include "system/hardware/hw.h" #include "system/hardware/hw.h"
#include "selfdrive/modeld/models/driving.h" #include "selfdrive/modeld/models/driving.h"
ExitHandler do_exit; ExitHandler do_exit;
mat3 update_calibration(Eigen::Matrix<float, 3, 4> &extrinsics, bool wide_camera, bool bigmodel_frame) { mat3 update_calibration(Eigen::Vector3d device_from_calib_euler, bool wide_camera, bool bigmodel_frame) {
/* /*
import numpy as np import numpy as np
from common.transformations.model import medmodel_frame_from_road_frame from common.transformations.model import medmodel_frame_from_calib_frame
medmodel_frame_from_ground = medmodel_frame_from_road_frame[:, (0, 1, 3)] medmodel_frame_from_calib_frame = medmodel_frame_from_calib_frame[:, :3]
ground_from_medmodel_frame = np.linalg.inv(medmodel_frame_from_ground) calib_from_smedmodel_frame = np.linalg.inv(medmodel_frame_from_calib_frame)
*/ */
static const auto ground_from_medmodel_frame = (Eigen::Matrix<float, 3, 3>() << static const auto calib_from_medmodel = (Eigen::Matrix<float, 3, 3>() <<
0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
-1.09890110e-03, 0.00000000e+00, 2.81318681e-01, 1.09890110e-03, 0.00000000e+00, -2.81318681e-01,
-1.84808520e-20, 9.00738606e-04, -4.28751576e-02).finished(); -2.25466395e-20, 1.09890110e-03,-5.23076923e-02).finished();
static const auto ground_from_sbigmodel_frame = (Eigen::Matrix<float, 3, 3>() << static const auto calib_from_sbigmodel = (Eigen::Matrix<float, 3, 3>() <<
0.00000000e+00, 7.31372216e-19, 1.00000000e+00, 0.00000000e+00, 7.31372216e-19, 1.00000000e+00,
-2.19780220e-03, 4.11497335e-19, 5.62637363e-01, 2.19780220e-03, 4.11497335e-19, -5.62637363e-01,
-5.46146580e-20, 1.80147721e-03, -2.73464241e-01).finished(); -6.66298828e-20, 2.19780220e-03, -3.33626374e-01).finished();
const auto cam_intrinsics = Eigen::Matrix<float, 3, 3, Eigen::RowMajor>(wide_camera ? ecam_intrinsic_matrix.v : fcam_intrinsic_matrix.v); static const auto view_from_device = (Eigen::Matrix<float, 3, 3>() <<
static const mat3 yuv_transform = get_model_yuv_transform(); 0.0, 1.0, 0.0,
0.0, 0.0, 1.0,
1.0, 0.0, 0.0).finished();
auto ground_from_model_frame = bigmodel_frame ? ground_from_sbigmodel_frame : ground_from_medmodel_frame;
auto camera_frame_from_road_frame = cam_intrinsics * extrinsics;
Eigen::Matrix<float, 3, 3> camera_frame_from_ground;
camera_frame_from_ground.col(0) = camera_frame_from_road_frame.col(0);
camera_frame_from_ground.col(1) = camera_frame_from_road_frame.col(1);
camera_frame_from_ground.col(2) = camera_frame_from_road_frame.col(3);
auto warp_matrix = camera_frame_from_ground * ground_from_model_frame; const auto cam_intrinsics = Eigen::Matrix<float, 3, 3, Eigen::RowMajor>(wide_camera ? ecam_intrinsic_matrix.v : fcam_intrinsic_matrix.v);
Eigen::Matrix<float, 3, 3, Eigen::RowMajor> device_from_calib = euler2rot(device_from_calib_euler).cast <float> ();
auto calib_from_model = bigmodel_frame ? calib_from_sbigmodel : calib_from_medmodel;
auto camera_from_calib = cam_intrinsics * view_from_device * device_from_calib;
auto warp_matrix = camera_from_calib * calib_from_model;
mat3 transform = {}; mat3 transform = {};
for (int i=0; i<3*3; i++) { for (int i=0; i<3*3; i++) {
transform.v[i] = warp_matrix(i / 3, i % 3); transform.v[i] = warp_matrix(i / 3, i % 3);
} }
static const mat3 yuv_transform = get_model_yuv_transform();
return matmul3(yuv_transform, transform); return matmul3(yuv_transform, transform);
} }
@ -114,14 +119,13 @@ void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcCl
bool is_rhd = ((bool)sm["driverMonitoringState"].getDriverMonitoringState().getIsRHD()); bool is_rhd = ((bool)sm["driverMonitoringState"].getDriverMonitoringState().getIsRHD());
frame_id = sm["roadCameraState"].getRoadCameraState().getFrameId(); frame_id = sm["roadCameraState"].getRoadCameraState().getFrameId();
if (sm.updated("liveCalibration")) { if (sm.updated("liveCalibration")) {
auto extrinsic_matrix = sm["liveCalibration"].getLiveCalibration().getExtrinsicMatrix(); auto rpy_calib = sm["liveCalibration"].getLiveCalibration().getRpyCalib();
Eigen::Matrix<float, 3, 4> extrinsic_matrix_eigen; Eigen::Vector3d device_from_calib_euler;
for (int i = 0; i < 4*3; i++) { for (int i=0; i<3; i++) {
extrinsic_matrix_eigen(i / 4, i % 4) = extrinsic_matrix[i]; device_from_calib_euler(i) = rpy_calib[i];
} }
model_transform_main = update_calibration(device_from_calib_euler, main_wide_camera, false);
model_transform_main = update_calibration(extrinsic_matrix_eigen, main_wide_camera, false); model_transform_extra = update_calibration(device_from_calib_euler, true, true);
model_transform_extra = update_calibration(extrinsic_matrix_eigen, true, true);
live_calib_seen = true; live_calib_seen = true;
} }

@ -10,7 +10,7 @@ import cereal.messaging as messaging
from cereal.visionipc import VisionIpcServer, VisionStreamType from cereal.visionipc import VisionIpcServer, VisionStreamType
from common.spinner import Spinner from common.spinner import Spinner
from common.timeout import Timeout from common.timeout import Timeout
from common.transformations.camera import get_view_frame_from_road_frame, tici_f_frame_size, tici_d_frame_size from common.transformations.camera import tici_f_frame_size, tici_d_frame_size
from system.hardware import PC from system.hardware import PC
from selfdrive.manager.process_config import managed_processes from selfdrive.manager.process_config import managed_processes
from selfdrive.test.openpilotci import BASE_URL, get_url from selfdrive.test.openpilotci import BASE_URL, get_url
@ -35,7 +35,7 @@ def get_log_fn(ref_commit, test_route):
def replace_calib(msg, calib): def replace_calib(msg, calib):
msg = msg.as_builder() msg = msg.as_builder()
if calib is not None: if calib is not None:
msg.liveCalibration.extrinsicMatrix = get_view_frame_from_road_frame(*calib, 1.22).flatten().tolist() msg.liveCalibration.rpyCalib = calib.tolist()
return msg return msg

@ -1 +1 @@
20c140b10eef52b6f6d6b9e142ed944264865bac 6d9bd0e80ccdf39827bded1883adbc922224bdf1

Loading…
Cancel
Save