#!/usr/bin/env python3
from functools import wraps
import http . server
import os
import threading
import time
import unittest
from parameterized import parameterized
from openpilot . tools . lib . url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
  """Serve a fixed 4-byte file whose existence is toggled by ``FILE_EXISTS``.

  Every path is treated as the same file. When ``FILE_EXISTS`` is true,
  GET returns 200 with the body ``FILE_CONTENT`` and HEAD returns the same
  headers without a body. Otherwise both answer 404.
  """
  # Class-level switch: tests flip it on the class so every handler instance sees it.
  FILE_EXISTS = True
  # Body served for GET; its length is what HEAD advertises as Content-Length.
  FILE_CONTENT = b'1234'

  def _send_head(self):
    """Send status line and headers; return whether the file currently exists."""
    # Read the flag once so headers and body agree even if a test toggles it mid-request.
    exists = self.FILE_EXISTS
    if exists:
      self.send_response(200)
      self.send_header("Content-Length", str(len(self.FILE_CONTENT)))
    else:
      self.send_response(404)
    self.end_headers()
    return exists

  def do_GET(self):
    # send_response()'s second argument is the reason phrase, not the body,
    # so the payload has to be written to wfile after the headers.
    if self._send_head():
      self.wfile.write(self.FILE_CONTENT)

  def do_HEAD(self):
    self._send_head()
class CachingTestServer(threading.Thread):
  """Background thread serving CachingTestRequestHandler on an ephemeral localhost port.

  The HTTP server is created (socket bound and listening) in the constructor,
  so ``port`` is available as soon as the object exists, even before
  ``start()``. ``run()`` only drives the request loop.
  """

  def __init__(self):
    # Daemon thread: a test that forgets stop() cannot keep the interpreter alive.
    super().__init__(daemon=True)
    # Port 0 lets the OS pick a free port; the actual port is read back below.
    self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
    self.port = self.server.server_port

  def run(self):
    self.server.serve_forever()

  def stop(self):
    """Stop serving, wait for the thread to exit, then release the socket."""
    # shutdown() ends the serve_forever() loop and blocks until it has exited.
    # Only then is it safe to close the listening socket. Closing first (as
    # before) pulls the socket out from under a loop that is still polling it.
    # shutdown() would block forever if serve_forever() never ran, hence the
    # is_alive() guard.
    if self.is_alive():
      self.server.shutdown()
      self.join()
    self.server.server_close()
def with_caching_server(func):
  """Decorator: run ``func`` while a CachingTestServer is serving.

  The wrapped callable receives the server's port as the ``port`` keyword
  argument. The server is always stopped afterwards, even if ``func`` raises,
  and ``func``'s return value is passed through.

  Raises:
    RuntimeError: if the server thread dies or does not expose its port
      within a few seconds.
  """
  @wraps(func)
  def wrapper(*args, **kwargs):
    server = CachingTestServer()
    server.start()
    # Depending on the server implementation, ``port`` may only be set from
    # inside the server thread. Wait for it explicitly instead of sleeping a
    # fixed amount and hoping that was long enough.
    deadline = time.monotonic() + 5.0
    while not hasattr(server, "port"):
      if not server.is_alive() or time.monotonic() > deadline:
        raise RuntimeError("caching test server did not start")
      time.sleep(0.01)
    try:
      return func(*args, **kwargs, port=server.port)
    finally:
      server.stop()
  return wrapper
class TestFileDownload(unittest.TestCase):
  """Checks that URLFile returns the same data with and without its cache.

  test_small_file and test_large_file fetch real remote files;
  test_recover_from_missing_file talks to a local CachingTestServer.
  """

  def _set_cache_env(self, value):
    """Set FILEREADER_CACHE for the current test only; the old value is restored on cleanup."""
    previous = os.environ.get("FILEREADER_CACHE")
    os.environ["FILEREADER_CACHE"] = value
    if previous is None:
      self.addCleanup(os.environ.pop, "FILEREADER_CACHE", None)
    else:
      self.addCleanup(os.environ.__setitem__, "FILEREADER_CACHE", previous)

  def compare_loads(self, url, start=0, length=None):
    """Compare reading ``length`` bytes at ``start`` between cached and non-cached URLFiles.

    The cached variant is read twice. The second read uses a fresh URLFile,
    i.e. with the cache already in place. Both reads are compared against an
    uncached download. ``length=None`` is passed through to ``read`` as ``ll``.
    """
    file_cached = URLFile(url, cache=True)
    file_downloaded = URLFile(url, cache=False)
    file_cached.seek(start)
    file_downloaded.seek(start)

    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
    # The requested range must lie within the file.
    end = start + length if length is not None else start
    self.assertLessEqual(end, file_downloaded.get_length())

    response_cached = file_cached.read(ll=length)
    response_downloaded = file_downloaded.read(ll=length)
    self.assertEqual(response_cached, response_downloaded)

    # Now test with cache in place
    file_cached = URLFile(url, cache=True)
    file_cached.seek(start)
    response_cached = file_cached.read(ll=length)
    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
    self.assertEqual(response_cached, response_downloaded)

  def test_small_file(self):
    # Make sure we don't force cache (scoped to this test, restored afterwards)
    self._set_cache_env("0")
    small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/docs/SAFETY.md"
    # If you want large file to be larger than a chunk
    # large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"

    # Load full small file
    self.compare_loads(small_file_url)

    file_small = URLFile(small_file_url)
    length = file_small.get_length()

    self.compare_loads(small_file_url, length - 100, 100)
    self.compare_loads(small_file_url, 50, 100)

    # Load small file 100 bytes at a time
    for i in range(length // 100):
      self.compare_loads(small_file_url, 100 * i, 100)

  def test_large_file(self):
    large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
    # Load the last 100 bytes of the file, then the whole file
    file_large = URLFile(large_file_url)
    length = file_large.get_length()

    self.compare_loads(large_file_url, length - 100, 100)
    self.compare_loads(large_file_url)

  @parameterized.expand([(True,), (False,)])
  @with_caching_server
  def test_recover_from_missing_file(self, cache_enabled, port):
    """A missing file reports length -1; once it appears, a new URLFile sees length 4."""
    self._set_cache_env("1" if cache_enabled else "0")
    # FILE_EXISTS is shared class state: always restore it, so a failed
    # assertion cannot leave the handler answering 404 for other tests.
    self.addCleanup(setattr, CachingTestRequestHandler, "FILE_EXISTS", True)

    file_url = f"http://localhost:{port}/test.png"

    CachingTestRequestHandler.FILE_EXISTS = False
    length = URLFile(file_url).get_length()
    self.assertEqual(length, -1)

    CachingTestRequestHandler.FILE_EXISTS = True
    length = URLFile(file_url).get_length()
    self.assertEqual(length, 4)
if __name__ == "__main__":
  # Executed directly: hand over to unittest's CLI runner for this module's test cases.
  unittest.main()