You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							116 lines
						
					
					
						
							3.1 KiB
						
					
					
				
			
		
		
	
	
							116 lines
						
					
					
						
							3.1 KiB
						
					
					
				import numpy as np
 | 
						|
import numpy.matlib
 | 
						|
import unittest
 | 
						|
import timeit
 | 
						|
 | 
						|
from common.kalman.ekf import EKF, SimpleSensor, FastEKF1D
 | 
						|
 | 
						|
class TestEKF(EKF):
  """Small two-state EKF subclass used as a fixture by the tests in this file.

  The state is a 2x1 zero column and the covariance starts out as ``var_init``
  times the 2x2 identity.
  """

  def __init__(self, var_init, Q):
    """Populate the filter matrices.

    Args:
      var_init: scalar multiplied into the 2x2 identity to form the initial
        covariance.
      Q: values handed to ``numpy.matlib.diag`` to build ``process_noise``.
    """
    # The meaning of the positional ``False`` is defined by the EKF base
    # class, which is not visible from this file.
    super(TestEKF, self).__init__(False)

    eye = numpy.matlib.identity(2)
    self.identity = eye
    self.state = numpy.matlib.zeros((2, 1))
    self.covar = eye * var_init

    self.process_noise = numpy.matlib.diag(Q)

  def calc_transfer_fun(self, dt):
    """Build the 2x2 transfer matrix for a step of length ``dt``.

    The matrix is the identity with ``dt`` written at row 0, column 1.

    Returns:
      A 2-tuple that holds the very same matrix object in both slots.
    """
    transfer = numpy.matlib.identity(2)
    transfer[0, 1] = dt
    # NOTE(review): both slots alias one object; presumably the base class
    # expects (transfer function, Jacobian) -- confirm against EKF.
    return transfer, transfer
 | 
						|
 | 
						|
 | 
						|
class EKFTest(unittest.TestCase):
  """Drives a TestEKF with SimpleSensor readings and checks the estimates."""

  def _feed(self, kalman, sensor, inputs, step):
    """Run one update_scalar/predict cycle on ``kalman`` per sensor reading.

    Readings are produced by mapping ``sensor.read`` over ``inputs``; ``step``
    is the value passed to every ``predict`` call.
    """
    for measurement in map(sensor.read, inputs):
      kalman.update_scalar(measurement)
      kalman.predict(step)

  def test_update_scalar(self):
    """State and covariance after readings for inputs 100..299, dt = 1/100."""
    kalman = TestEKF(1e3, [0.1, 1])
    step = 1. / 100

    self._feed(kalman, SimpleSensor(0, 1, 2), np.arange(100, 300), step)

    np.testing.assert_allclose(kalman.state, [[300], [100]], rtol=1e-4)

    expected_covar = np.asarray([[0.0563, 0.10278], [0.10278, 0.55779]])
    np.testing.assert_allclose(kalman.covar, expected_covar, atol=1e-4)

  def test_unbiased(self):
    """With zero process noise, the state after inputs 0..999 is checked."""
    kalman = TestEKF(1e3, [0., 0.])
    # dt is deliberately a numpy float64 here rather than a Python float.
    step = np.float64(1. / 100)

    self._feed(kalman, SimpleSensor(0, 1, 2), np.arange(1000), step)

    np.testing.assert_allclose(kalman.state, [[1000.], [100.]], rtol=1e-4)
 | 
						|
 | 
						|
 | 
						|
class FastEKF1DTest(unittest.TestCase):
  """Compares FastEKF1D with the generic TestEKF for results and for speed."""

  # Iteration counts for the timing comparison: a short untimed warm-up pass
  # followed by the pass whose duration is actually compared.
  WARMUP_RUNS = 1000
  TIMED_RUNS = 20000

  def _assert_same_estimate(self, ekf, fast_ekf):
    """Assert that both filters hold almost-equal state and covariance.

    ``fast_ekf.covar`` is indexed flat: [0] is compared with ``covar[0, 0]``,
    [1] with ``covar[1, 1]`` and [2] with ``covar[0, 1]``.
    """
    self.assertAlmostEqual(ekf.state[0], fast_ekf.state[0])
    self.assertAlmostEqual(ekf.state[1], fast_ekf.state[1])
    self.assertAlmostEqual(ekf.covar[0, 0], fast_ekf.covar[0])
    self.assertAlmostEqual(ekf.covar[0, 1], fast_ekf.covar[2])
    self.assertAlmostEqual(ekf.covar[1, 1], fast_ekf.covar[1])

  def _time_cycle(self, stmt, setup):
    """Return the ``timeit`` duration of ``stmt`` after an untimed warm-up."""
    timeit.timeit(stmt, setup=setup, number=self.WARMUP_RUNS)
    return timeit.timeit(stmt, setup=setup, number=self.TIMED_RUNS)

  def test_correctness(self):
    """Both filters must agree after one update and again after one predict."""
    dt = 1. / 100
    reading = SimpleSensor(0, 1, 2).read(100)

    ekf = TestEKF(1e3, [0.1, 1])
    fast_ekf = FastEKF1D(dt, 1e3, [0.1, 1])

    ekf.update_scalar(reading)
    fast_ekf.update_scalar(reading)
    self._assert_same_estimate(ekf, fast_ekf)

    ekf.predict(dt)
    fast_ekf.predict(dt)
    self._assert_same_estimate(ekf, fast_ekf)

  def test_speed(self):
    """FastEKF1D must complete update+predict cycles over 4x faster than TestEKF.

    NOTE(review): this is a wall-clock comparison, so it may be sensitive to
    machine load.
    """
    setup = """
import numpy as np
from common.kalman.tests.test_ekf import TestEKF
from common.kalman.ekf import SimpleSensor, FastEKF1D

dt = 1. / 100
reading = SimpleSensor(0, 1, 2).read(100)

var_init, Q = 1e3, [0.1, 1]
ekf = TestEKF(var_init, Q)
fast_ekf = FastEKF1D(dt, var_init, Q)
    """

    ekf_stmt = """
ekf.update_scalar(reading)
ekf.predict(dt)
    """

    fast_ekf_stmt = """
fast_ekf.update_scalar(reading)
fast_ekf.predict(dt)
    """

    ekf_speed = self._time_cycle(ekf_stmt, setup)
    fast_ekf_speed = self._time_cycle(fast_ekf_stmt, setup)

    # assertLess instead of a bare ``assert``: it still runs under
    # ``python -O`` and reports both timings when the check fails.
    self.assertLess(fast_ekf_speed, ekf_speed / 4)
 | 
						|
 | 
						|
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
 | 
						|
 |