You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					117 lines
				
				3.1 KiB
			
		
		
			
		
	
	
					117 lines
				
				3.1 KiB
			| 
								 
											7 years ago
										 
									 | 
							
								import numpy as np
							 | 
						||
| 
								 | 
							
								import numpy.matlib
							 | 
						||
| 
								 | 
							
								import unittest
							 | 
						||
| 
								 | 
							
								import timeit
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from common.kalman.ekf import EKF, SimpleSensor, FastEKF1D
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TestEKF(EKF):
							 | 
						||
| 
								 | 
							
								  def __init__(self, var_init, Q):
							 | 
						||
| 
								 | 
							
								    super(TestEKF, self).__init__(False)
							 | 
						||
| 
								 | 
							
								    self.identity = numpy.matlib.identity(2)
							 | 
						||
| 
								 | 
							
								    self.state = numpy.matlib.zeros((2, 1))
							 | 
						||
| 
								 | 
							
								    self.covar = self.identity * var_init
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    self.process_noise = numpy.matlib.diag(Q)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  def calc_transfer_fun(self, dt):
							 | 
						||
| 
								 | 
							
								    tf = numpy.matlib.identity(2)
							 | 
						||
| 
								 | 
							
								    tf[0, 1] = dt
							 | 
						||
| 
								 | 
							
								    return tf, tf
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class EKFTest(unittest.TestCase):
							 | 
						||
| 
								 | 
							
								  def test_update_scalar(self):
							 | 
						||
| 
								 | 
							
								    ekf = TestEKF(1e3, [0.1, 1])
							 | 
						||
| 
								 | 
							
								    dt = 1. / 100
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    sensor = SimpleSensor(0, 1, 2)
							 | 
						||
| 
								 | 
							
								    readings = map(sensor.read, np.arange(100, 300))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    for reading in readings:
							 | 
						||
| 
								 | 
							
								      ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								      ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    np.testing.assert_allclose(ekf.state, [[300], [100]], 1e-4)
							 | 
						||
| 
								 | 
							
								    np.testing.assert_allclose(
							 | 
						||
| 
								 | 
							
								      ekf.covar,
							 | 
						||
| 
								 | 
							
								      np.asarray([[0.0563, 0.10278], [0.10278, 0.55779]]),
							 | 
						||
| 
								 | 
							
								      atol=1e-4)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  def test_unbiased(self):
							 | 
						||
| 
								 | 
							
								    ekf = TestEKF(1e3, [0., 0.])
							 | 
						||
| 
								 | 
							
								    dt = np.float64(1. / 100)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    sensor = SimpleSensor(0, 1, 2)
							 | 
						||
| 
								 | 
							
								    readings = map(sensor.read, np.arange(1000))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    for reading in readings:
							 | 
						||
| 
								 | 
							
								      ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								      ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    np.testing.assert_allclose(ekf.state, [[1000.], [100.]], 1e-4)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class FastEKF1DTest(unittest.TestCase):
							 | 
						||
| 
								 | 
							
								  def test_correctness(self):
							 | 
						||
| 
								 | 
							
								    dt = 1. / 100
							 | 
						||
| 
								 | 
							
								    reading = SimpleSensor(0, 1, 2).read(100)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    ekf = TestEKF(1e3, [0.1, 1])
							 | 
						||
| 
								 | 
							
								    fast_ekf = FastEKF1D(dt, 1e3, [0.1, 1])
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								    fast_ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.state[0]   , fast_ekf.state[0])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.state[1]   , fast_ekf.state[1])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[0, 0], fast_ekf.covar[0])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[0, 1], fast_ekf.covar[2])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[1, 1], fast_ekf.covar[1])
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    fast_ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.state[0]   , fast_ekf.state[0])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.state[1]   , fast_ekf.state[1])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[0, 0], fast_ekf.covar[0])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[0, 1], fast_ekf.covar[2])
							 | 
						||
| 
								 | 
							
								    self.assertAlmostEqual(ekf.covar[1, 1], fast_ekf.covar[1])
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  def test_speed(self):
							 | 
						||
| 
								 | 
							
								    setup = """
							 | 
						||
| 
								 | 
							
								import numpy as np
							 | 
						||
| 
								 | 
							
								from common.kalman.tests.test_ekf import TestEKF
							 | 
						||
| 
								 | 
							
								from common.kalman.ekf import SimpleSensor, FastEKF1D
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								dt = 1. / 100
							 | 
						||
| 
								 | 
							
								reading = SimpleSensor(0, 1, 2).read(100)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								var_init, Q = 1e3, [0.1, 1]
							 | 
						||
| 
								 | 
							
								ekf = TestEKF(var_init, Q)
							 | 
						||
| 
								 | 
							
								fast_ekf = FastEKF1D(dt, var_init, Q)
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    timeit.timeit("""
							 | 
						||
| 
								 | 
							
								ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    """, setup=setup, number=1000)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    ekf_speed = timeit.timeit("""
							 | 
						||
| 
								 | 
							
								ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    """, setup=setup, number=20000)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    timeit.timeit("""
							 | 
						||
| 
								 | 
							
								fast_ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								fast_ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    """, setup=setup, number=1000)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    fast_ekf_speed = timeit.timeit("""
							 | 
						||
| 
								 | 
							
								fast_ekf.update_scalar(reading)
							 | 
						||
| 
								 | 
							
								fast_ekf.predict(dt)
							 | 
						||
| 
								 | 
							
								    """, setup=setup, number=20000)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    assert fast_ekf_speed < ekf_speed / 4
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if __name__ == "__main__":
							 | 
						||
| 
								 | 
							
								  unittest.main()
							 |