You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
2.2 KiB
94 lines
2.2 KiB
import numpy as np
|
|
|
|
# Template used by StatTracker.__str__. Positional fields, in order:
# name, sample count, mean, variance, minimum, maximum.
_DESC_FMT = """
{} (n={}):
MEAN={}
VAR={}
MIN={}
MAX={}
"""
|
|
|
|
class StatTracker():
    """Running count, mean, variance, min and max of scalar samples.

    Batches are merged with the parallel variance algorithm, so only summary
    statistics are stored and individual samples are never retained.
    """

    def __init__(self, name):
        """Create an empty tracker.

        Args:
            name: Label printed as the heading by ``str(tracker)``.
        """
        self._name = name
        self._mean = 0.
        # Population (ddof=0) variance of everything seen so far; the ``var``
        # property rescales it to the unbiased sample variance.
        self._var = 0.
        self._n = 0
        # Identity elements for min()/max(), so the first batch always wins.
        self._min = float("inf")
        self._max = -float("inf")

    @property
    def mean(self):
        """Mean of all samples seen so far (0.0 before the first update)."""
        return self._mean

    @property
    def var(self):
        """Unbiased (ddof=1) variance of all samples seen so far.

        Returns NaN (in the same scalar/array shape as the stored variance)
        while fewer than two samples have been seen, because the sample
        variance is undefined there and the rescaling would divide by zero.
        """
        if self._n < 2:
            return self._var * np.nan
        return (self._n * self._var) / (self._n - 1.)

    @property
    def min(self):
        """Smallest sample seen so far (+inf before the first update)."""
        return self._min

    @property
    def max(self):
        """Largest sample seen so far (-inf before the first update)."""
        return self._max

    def update(self, samples):
        """Merge a batch of samples into the running statistics.

        Args:
            samples: Array-like of numbers of any shape; it is flattened
                before use. An empty batch leaves the tracker unchanged.
        """
        # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        data = np.asarray(samples).reshape(-1)
        n_a = data.size
        if n_a == 0:
            return

        mean_a = np.mean(data)
        var_a = np.var(data, ddof=0)

        n_b = self._n
        mean_b = self._mean
        n_ab = n_a + n_b

        delta = mean_b - mean_a
        # Sum of squared deviations = population variance * count. Both var_a
        # (ddof=0) and self._var are population variances, so the factor is
        # n, not n - 1.
        m_a = var_a * n_a
        m_b = self._var * n_b
        m2 = m_a + m_b + delta**2 * n_a * n_b / n_ab

        self._var = m2 / n_ab
        self._mean = (n_a * mean_a + n_b * mean_b) / n_ab
        self._n = n_ab

        self._min = min(self._min, np.min(data))
        self._max = max(self._max, np.max(data))

    def __str__(self):
        """Render name, count, mean, variance, min and max via _DESC_FMT."""
        return _DESC_FMT.format(self._name, self._n, self._mean, self.var,
                                self._min, self._max)
|
|
|
|
# NOTE: update() builds the batch scatter matrix directly instead of calling
# np.cov, whose rowvar=False handling of a single-row input has varied across
# NumPy versions; batches containing a single sample are therefore supported.
class VectorStatTracker(StatTracker):
    """Running mean, covariance, min and max of fixed-length vector samples.

    Mean, min and max are tracked per dimension; the second moment is the
    full (dim, dim) covariance matrix, exposed through ``var`` and ``cov``.
    """

    def __init__(self, name, dim):
        """Create an empty tracker.

        Args:
            name: Label printed as the heading by ``str(tracker)``.
            dim: Length of each sample vector.
        """
        self._name = name
        self._mean = np.zeros((dim, ))
        # Population (ddof=0) covariance of everything seen so far.
        self._var = np.zeros((dim, dim))
        self._n = 0
        # Identity elements for element-wise minimum/maximum.
        self._min = np.full((dim, ), float("inf"))
        self._max = np.full((dim, ), -float("inf"))

    @property
    def cov(self):
        """Alias for ``var``: the covariance matrix of the samples so far."""
        return self.var

    def update(self, samples):
        """Merge a batch of sample vectors into the running statistics.

        Args:
            samples: 2-D array-like of shape (n, dim), one sample per row.
                An empty batch (n == 0) leaves the tracker unchanged.
        """
        samples = np.asarray(samples)
        n_a = samples.shape[0]
        if n_a == 0:
            return

        mean_a = np.mean(samples, axis=0)
        # Scatter matrix of the batch: sum over rows of outer(x - mean,
        # x - mean). This equals the batch's population covariance times n_a.
        centered = samples - mean_a
        m_a = np.dot(centered.T, centered)

        n_b = self._n
        mean_b = self._mean
        n_ab = n_a + n_b

        delta = mean_b - mean_a
        # self._var is a population covariance, so its scatter is _var * n_b.
        m_b = self._var * n_b
        # The between-batch correction is an outer product, not an
        # element-wise square, so off-diagonal covariances are updated too.
        m2 = m_a + m_b + np.outer(delta, delta) * n_a * n_b / n_ab

        self._var = m2 / n_ab
        self._mean = (n_a * mean_a + n_b * mean_b) / n_ab
        self._n = n_ab

        self._min = np.minimum(self._min, np.min(samples, axis=0))
        self._max = np.maximum(self._max, np.max(samples, axis=0))
|
|
|