openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

310 lines
14 KiB

import time, random, unittest, itertools
from unittest.mock import patch
from io import StringIO
from collections import namedtuple
from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
import numpy as np
# Multipliers for the SI suffixes a unit-scaled rate token may end with ("1.00k" -> 1000.0).
_SI_MULTIPLIERS = {"k": 1e3, "M": 1e6, "G": 1e9, "T": 1e12, "P": 1e15, "E": 1e18, "Z": 1e21, "Y": 1e24}

def _get_iter_per_second(raw:str) -> float:
  """Convert a rate token such as "12.5", "1.00k" or "3.2M" into a plain float.

  Args:
    raw: the number printed in front of "it/s", optionally ending in one SI
      suffix (k, M, G, T, P, E, Z, Y) when the bar was rendered with unit scaling.

  Returns:
    The rate in iterations per second.

  Raises:
    ValueError: if the numeric part cannot be parsed by float().
  """
  # raw[-1:] (not raw[-1]) keeps the empty string on the float() path instead of raising IndexError
  multiplier = _SI_MULTIPLIERS.get(raw[-1:])
  if multiplier is None: return float(raw)
  return float(raw[:-1]) * multiplier
# TODO: _get_iter_per_second in test_unit_scale might fail if lower bound is too small
# Terminal widths for the tests are drawn with random.randint(*NCOLS_RANGE), so both ends are inclusive.
NCOLS_RANGE = [80, 240]
class TestProgressBar(unittest.TestCase):
  """Compare bars rendered by tinygrad's tqdm/trange helpers with tqdm.format_meter output.

  The tests patch sys.stderr with a StringIO and shutil.get_terminal_size with a mock,
  then read the text after the last carriage return in the captured stderr as the
  current bar. Elapsed time for the reference bar is backed out of the it/s rate that
  the tinygrad bar printed, so both bars are formatted from the same numbers.
  """

  @staticmethod
  def _set_terminal_width(mock_terminal_size, ncols):
    """Make the patched shutil.get_terminal_size report `ncols` columns."""
    mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)

  @staticmethod
  def _last_bar(stream):
    """Return the text after the last carriage return in `stream`, right-stripped."""
    return stream.getvalue().split("\r")[-1].rstrip()

  @staticmethod
  def _rate_token(bar):
    """Return the token printed right before the last "it/s" in `bar` (it may carry a unit suffix)."""
    return bar.split("it/s")[-2].split(" ")[-1]

  def _compare_bars(self, bar1, bar2):
    """Assert that two rendered bars match, tolerating small timing differences.

    Both bars are split on "|" into prefix, progress and suffix. Overall length and
    prefix must be equal. When neither suffix contains "?", the elapsed and remaining
    timers may differ slightly (atol=5, rtol=1e-2) and the rest of the suffix must be
    equal; otherwise the suffixes must be equal as a whole. The progress section may
    differ in at most one character.
    """
    prefix1, prog1, suffix1 = bar1.split("|")
    prefix2, prog2, suffix2 = bar2.split("|")

    self.assertEqual(len(bar1), len(bar2))
    self.assertEqual(prefix1, prefix2)

    # "SS", "MM:SS" or "HH:MM:SS" -> seconds
    def parse_timer(timer): return sum(int(x) * y for x, y in zip(timer.split(':')[::-1], (1, 60, 3600)))

    if "?" not in suffix1 and "?" not in suffix2:
      # allow for few sec diff in timers (removes flakiness)
      timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
      timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
      np.testing.assert_allclose(timer1, timer2, atol=5, rtol=1e-2)
      np.testing.assert_allclose(rm1, rm2, atol=5, rtol=1e-2)

      # get suffix without timers
      suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
      suffix2 = suffix2.split("[")[0] + suffix2.split(",")[1]
      self.assertEqual(suffix1, suffix2)
    else:
      self.assertEqual(suffix1, suffix2)

    # allow 1 char diff to be less flaky, but it should match
    diff = sum(c1 != c2 for c1, c2 in zip(prog1, prog2))
    self.assertLessEqual(diff, 1, f"{diff=}\n{prog1=}\n{prog2=}")

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
    """Iterating a wrapped range renders the same bars as tqdm, including the final one."""
    for _ in range(10):
      total, ncols = random.randint(5, 30), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in (bar := tinytqdm(range(total), desc="Test")):
        if bar.i % bar.skip != 0: continue
        tinytqdm_output = self._last_bar(mock_stderr)
        iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
        elapsed = n/iters_per_sec if n > 0 else 0
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
        self._compare_bars(tinytqdm_output, tqdm_output)

      # compare final bars
      tinytqdm_output = self._last_bar(mock_stderr)
      iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
      elapsed = total/iters_per_sec if n > 0 else 0
      tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
      self._compare_bars(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_unit_scale(self, mock_terminal_size, mock_stderr):
    """Bars match tqdm with and without unit scaling across many orders of magnitude of `total`."""
    for unit_scale in [True, False]:
      # NOTE: numpy comparison raises TypeError if exponent > 22
      for exponent in range(1, 22, 3):
        low, high = 10 ** exponent, 10 ** (exponent+1)
        for _ in range(5):
          total, ncols = random.randint(low, high), random.randint(*NCOLS_RANGE)
          self._set_terminal_width(mock_terminal_size, ncols)
          mock_stderr.truncate(0)

          # compare bars at each iteration (only when tinytqdm bar has been updated)
          # setting high rate to make sure it does not skip
          for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale, rate=1e9):
            tinytqdm_output = self._last_bar(mock_stderr)
            if n:
              iters_per_sec = _get_iter_per_second(self._rate_token(tinytqdm_output))
              elapsed = n/iters_per_sec
            else:
              elapsed = 0
            tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
            self._compare_bars(tinytqdm_output, tqdm_output)
            # only the first few iterations are checked; total can be astronomically large
            if n > 3: break

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_unit_scale_exact(self, mock_terminal_size, mock_stderr):
    """With time.perf_counter patched to fixed values, unit-scaled bars match tqdm."""
    unit_scale = True
    ncols = 80
    self._set_terminal_width(mock_terminal_size, ncols)
    mock_stderr.truncate(0)

    total = 10
    with patch('time.perf_counter', side_effect=[0]+list(range(100))): # one more 0 for the init call
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale, rate=1e9):
        tinytqdm_output = self._last_bar(mock_stderr)
        elapsed = n
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
        self._compare_bars(tinytqdm_output, tqdm_output)
        if n > 5: break

    total = 10
    k = 0.001000001
    # regression test for
    # E AssertionError: ' 1.00/10.0 1000it/s]' != ' 1.00/10.0 1.00kit/s]'
    # E - 1.00/10.0 1000it/s]
    # E ? ^
    # E + 1.00/10.0 1.00kit/s]
    # E ? + ^
    with patch('time.perf_counter', side_effect=[0, *[i*k for i in range(100)]]): # one more 0 for the init call
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale, rate=1e9):
        tinytqdm_output = self._last_bar(mock_stderr)
        elapsed = n*k
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
        self._compare_bars(tinytqdm_output, tqdm_output)
        if n > 5: break

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_set_description(self, mock_terminal_size, mock_stderr):
    """A description changed via set_description shows up as the prefix of the next bar."""
    for _ in range(10):
      total, ncols = random.randint(5, 30), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      expected_prefix = "Test"
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for i,n in enumerate(bar := tinytqdm(range(total), desc="Test")):
        if bar.i % bar.skip != 0: continue
        tinytqdm_output = self._last_bar(mock_stderr)
        iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
        elapsed = n/iters_per_sec if n > 0 else 0
        # the reference bar is built with the prefix that was active when the captured bar was drawn
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
        expected_prefix = desc = f"Test {i}" if i % 2 == 0 else ""
        bar.set_description(desc)
        self._compare_bars(tinytqdm_output, tqdm_output)

      # compare final bars
      tinytqdm_output = self._last_bar(mock_stderr)
      iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
      elapsed = total/iters_per_sec if n > 0 else 0
      tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
      self._compare_bars(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
    """trange(total) renders the same bars as tqdm, including the final one."""
    for _ in range(5):
      total, ncols = random.randint(5, 30), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in (bar := tinytrange(total, desc="Test")):
        if bar.i % bar.skip != 0: continue
        tiny_output = self._last_bar(mock_stderr)
        iters_per_sec = float(self._rate_token(tiny_output)) if n > 0 else 0
        elapsed = n/iters_per_sec if n > 0 else 0
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
        self._compare_bars(tiny_output, tqdm_output)

      # compare final bars
      tiny_output = self._last_bar(mock_stderr)
      iters_per_sec = float(self._rate_token(tiny_output)) if n > 0 else 0
      elapsed = total/iters_per_sec if n > 0 else 0
      tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
      self._compare_bars(tiny_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
    """Manual update() calls with random increments render the same bars as tqdm."""
    for _ in range(10):
      total, ncols = random.randint(10000, 1000000), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      # compare bars at each iteration (only when tinytqdm bar has been updated)
      bar = tinytqdm(total=total, desc="Test")
      n = 0
      while n < total:
        incr = (total // 100) + random.randint(0, 1000)
        if n + incr > total: incr = total - n  # clamp the last step so n lands exactly on total
        bar.update(incr, close=n+incr==total)
        n += incr
        if bar.i % bar.skip != 0: continue

        tinytqdm_output = self._last_bar(mock_stderr)
        iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
        elapsed = n/iters_per_sec if n > 0 else 0
        tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
        self._compare_bars(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_custom_0_total(self, mock_terminal_size, mock_stderr):
    """A bar created with total=0 and updated manually matches tqdm exactly."""
    for _ in range(10):
      total, ncols = random.randint(10000, 100000), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      # compare bars at each iteration (only when tinytqdm bar has been updated)
      bar = tinytqdm(total=0, desc="Test")
      n = 0
      while n < total:
        incr = (total // 10) + random.randint(0, 100)
        if n + incr > total: incr = total - n  # clamp the last step so n lands exactly on total
        bar.update(incr, close=n+incr==total)
        n += incr
        if bar.i % bar.skip != 0: continue

        tinytqdm_output = self._last_bar(mock_stderr)
        iters_per_sec = float(self._rate_token(tinytqdm_output)) if n > 0 else 0
        elapsed = n/iters_per_sec if n > 0 else 0
        tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test")
        self.assertEqual(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_custom_nolen_total(self, mock_terminal_size, mock_stderr):
    """Wrapping an iterator without a length (itertools.count) matches tqdm with total=0."""
    for unit_scale in [True, False]:
      for _ in range(5):
        gen = itertools.count(0)
        ncols = random.randint(*NCOLS_RANGE)
        self._set_terminal_width(mock_terminal_size, ncols)
        mock_stderr.truncate(0)

        # compare bars at each iteration (only when tinytqdm bar has been updated)
        # setting high rate to make sure it does not skip
        for n,g in enumerate(tinytqdm(gen, desc="Test", unit_scale=unit_scale, rate=1e9)):
          self.assertEqual(g, n)
          tinytqdm_output = self._last_bar(mock_stderr)
          if n:
            iters_per_sec = _get_iter_per_second(self._rate_token(tinytqdm_output))
            elapsed = n/iters_per_sec
          else:
            elapsed = 0
          tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
          self.assertEqual(tinytqdm_output, tqdm_output)
          # the generator is infinite, so stop after a handful of bars
          if n > 5: break

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_write(self, mock_terminal_size, mock_stderr):
    """Text emitted through write() while a bar is active matches what tqdm.write emits."""
    for _ in range(5):
      ncols, tqdm_fp = random.randint(*NCOLS_RANGE), StringIO()
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)
      tqdm_fp.truncate(0)

      for i in tinytqdm(range(10)):
        time.sleep(0.01)
        tinytqdm.write(str(i))
        tqdm.write(str(i), file=tqdm_fp)
        tinytqdm_out, tqdm_out = mock_stderr.getvalue(), tqdm_fp.getvalue()
        # compare only the most recent write: the text after the last "\r\033[K" on stderr
        # versus the text after the previous line in the tqdm buffer
        self.assertEqual(tinytqdm_out.split("\r\033[K")[-1], tqdm_out.split(f"{i-1}\n")[-1])
      # NOTE(review): no whole-stream assertion here. Comparing tinytqdm_out with itself is
      # vacuous, and stderr presumably also holds rendered bars that tqdm_fp lacks - confirm
      # before adding a cross-stream check.

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_context_manager(self, mock_terminal_size, mock_stderr):
    """Used as a context manager with manual updates, the final bar matches tqdm."""
    for _ in range(10):
      total, ncols = random.randint(5, 30), random.randint(*NCOLS_RANGE)
      self._set_terminal_width(mock_terminal_size, ncols)
      mock_stderr.truncate(0)

      with tinytqdm(desc="Test", total=total) as bar:
        for _ in range(total):
          bar.update(1)

      tinytqdm_output = self._last_bar(mock_stderr)
      iters_per_sec = float(self._rate_token(tinytqdm_output))
      elapsed = total/iters_per_sec
      tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
      self._compare_bars(tinytqdm_output, tqdm_output)

  def test_tqdm_perf(self):
    """Iterating 100 items takes less than twice as long as with tqdm."""
    st = time.perf_counter()
    for _ in tqdm(range(100)): pass
    tqdm_time = time.perf_counter() - st

    st = time.perf_counter()
    for _ in tinytqdm(range(100)): pass
    tinytqdm_time = time.perf_counter() - st

    self.assertLess(tinytqdm_time, 2 * tqdm_time)

  def test_tqdm_perf_high_iter(self):
    """Iterating ten million items takes less than five times as long as with tqdm."""
    # `10^7` is bitwise XOR in Python (== 13); the power operator gives the intended high count
    n_iters = 10**7

    st = time.perf_counter()
    for _ in tqdm(range(n_iters)): pass
    tqdm_time = time.perf_counter() - st

    st = time.perf_counter()
    for _ in tinytqdm(range(n_iters)): pass
    tinytqdm_time = time.perf_counter() - st

    self.assertLess(tinytqdm_time, 5 * tqdm_time)
# Run every test in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()