diff --git a/system/hardware/tici/agnos.py b/system/hardware/tici/agnos.py
index e664654c65..527422aca7 100755
--- a/system/hardware/tici/agnos.py
+++ b/system/hardware/tici/agnos.py
@@ -1,13 +1,16 @@
 #!/usr/bin/env python3
+import hashlib
 import json
 import lzma
-import hashlib
-import requests
+import os
 import struct
 import subprocess
 import time
-import os
-from typing import Dict, Generator, Union
+from typing import Dict, Generator, List, Tuple, Union
+
+import requests
+
+import system.hardware.tici.casync as casync

 SPARSE_CHUNK_FMT = struct.Struct('H2xI4x')

@@ -74,6 +77,7 @@ def unsparsify(f: StreamingDecompressor) -> Generator[bytes, None, None]:
     else:
       raise Exception("Unhandled sparse chunk type")

+
 # noop wrapper with same API as unsparsify() for non sparse images
 def noop(f: StreamingDecompressor) -> Generator[bytes, None, None]:
   while not f.eof:
@@ -99,8 +103,8 @@ def get_partition_path(target_slot_number: int, partition: dict) -> str:
   return path


-def verify_partition(target_slot_number: int, partition: Dict[str, Union[str, int]]) -> bool:
-  full_check = partition['full_check']
+def verify_partition(target_slot_number: int, partition: Dict[str, Union[str, int]], force_full_check: bool = False) -> bool:
+  full_check = partition['full_check'] or force_full_check
   path = get_partition_path(target_slot_number, partition)
   if not isinstance(partition['size'], int):
     return False
@@ -135,21 +139,10 @@ def clear_partition_hash(target_slot_number: int, partition: dict) -> None:
   os.sync()


-def flash_partition(target_slot_number: int, partition: dict, cloudlog):
-  cloudlog.info(f"Downloading and writing {partition['name']}")
-
-  if verify_partition(target_slot_number, partition):
-    cloudlog.info(f"Already flashed {partition['name']}")
-    return
-
+def extract_compressed_image(target_slot_number: int, partition: dict, cloudlog):
+  path = get_partition_path(target_slot_number, partition)
   downloader = StreamingDecompressor(partition['url'])

-  # Clear hash before flashing in case we get interrupted
-  full_check = partition['full_check']
-  if not full_check:
-    clear_partition_hash(target_slot_number, partition)
-
-  path = get_partition_path(target_slot_number, partition)
   with open(path, 'wb+') as out:
     # Flash partition
     last_p = 0
@@ -172,9 +165,67 @@ def flash_partition(target_slot_number: int, partition: dict, cloudlog):
     if out.tell() != partition['size']:
       raise Exception("Uncompressed size mismatch")

-    # Write hash after successfull flash
     os.sync()
-    if not full_check:
+
+
+def extract_casync_image(target_slot_number: int, partition: dict, cloudlog):
+  path = get_partition_path(target_slot_number, partition)
+  seed_path = path[:-1] + ('b' if path[-1] == 'a' else 'a')
+
+  target = casync.parse_caibx(partition['casync_caibx'])
+
+  sources: List[Tuple[str, casync.ChunkReader, casync.ChunkDict]] = []
+
+  # First source is the current partition. Index file for current version is provided in the manifest
+  if 'casync_seed_caibx' in partition:
+    sources += [('seed', casync.FileChunkReader(seed_path), casync.build_chunk_dict(casync.parse_caibx(partition['casync_seed_caibx'])))]
+
+  # Second source is the target partition, which allows resuming an interrupted update
+  sources += [('target', casync.FileChunkReader(path), casync.build_chunk_dict(target))]
+
+  # Finally, add the remote source to download any missing chunks
+  sources += [('remote', casync.RemoteChunkReader(partition['casync_store']), casync.build_chunk_dict(target))]
+
+  last_p = 0
+
+  def progress(cur):
+    nonlocal last_p
+    p = int(cur / partition['size'] * 100)
+    if p != last_p:
+      last_p = p
+      print(f"Installing {partition['name']}: {p}", flush=True)
+
+  stats = casync.extract(target, sources, path, progress)
+  cloudlog.error(f'casync done {json.dumps(stats)}')
+
+  os.sync()
+  if not verify_partition(target_slot_number, partition, force_full_check=True):
+    raise Exception(f"Raw hash mismatch '{partition['hash_raw'].lower()}'")
+
+
+def flash_partition(target_slot_number: int, partition: dict, cloudlog):
+  cloudlog.info(f"Downloading and writing {partition['name']}")
+
+  if verify_partition(target_slot_number, partition):
+    cloudlog.info(f"Already flashed {partition['name']}")
+    return
+
+  # Clear hash before flashing in case we get interrupted
+  full_check = partition['full_check']
+  if not full_check:
+    clear_partition_hash(target_slot_number, partition)
+
+  path = get_partition_path(target_slot_number, partition)
+
+  if 'casync_caibx' in partition:
+    extract_casync_image(target_slot_number, partition, cloudlog)
+  else:
+    extract_compressed_image(target_slot_number, partition, cloudlog)
+
+  # Write hash after successful flash
+  if not full_check:
+    with open(path, 'wb+') as out:
       out.seek(partition['size'])
       out.write(partition['hash_raw'].lower().encode())

@@ -228,8 +279,8 @@ def verify_agnos_update(manifest_path: str, target_slot_number: int) -> bool:


 if __name__ == "__main__":
-  import logging
   import argparse
+  import logging

   parser = argparse.ArgumentParser(description="Flash and verify AGNOS update",
                                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
diff --git a/system/hardware/tici/casync.py b/system/hardware/tici/casync.py
new file mode 100755
index 0000000000..e77f473636
--- /dev/null
+++ b/system/hardware/tici/casync.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+import io
+import lzma
+import os
+import struct
+import sys
+import time
+from abc import ABC, abstractmethod
+from collections import defaultdict, namedtuple
+from typing import Callable, Dict, List, Optional, Tuple
+
+import requests
+from Crypto.Hash import SHA512
+
+CA_FORMAT_INDEX = 0x96824d9c7b129ff9
+CA_FORMAT_TABLE = 0xe75b9e112f17417d
+CA_FORMAT_TABLE_TAIL_MARKER = 0xe75b9e112f17417
+FLAGS = 0xb000000000000000
+
+CA_HEADER_LEN = 48
+CA_TABLE_HEADER_LEN = 16
+CA_TABLE_ENTRY_LEN = 40
+CA_TABLE_MIN_LEN = CA_TABLE_HEADER_LEN + CA_TABLE_ENTRY_LEN
+
+CHUNK_DOWNLOAD_TIMEOUT = 10
+CHUNK_DOWNLOAD_RETRIES = 3
+
+CAIBX_DOWNLOAD_TIMEOUT = 120
+
+Chunk = namedtuple('Chunk', ['sha', 'offset', 'length'])
+ChunkDict = Dict[bytes, Chunk]
+
+
+class ChunkReader(ABC):
+  @abstractmethod
+  def read(self, chunk: Chunk) -> bytes:
+    ...
+
+
+class FileChunkReader(ChunkReader):
+  """Reads chunks from a local file"""
+  def __init__(self, fn: str) -> None:
+    super().__init__()
+    self.f = open(fn, 'rb')
+
+  def read(self, chunk: Chunk) -> bytes:
+    self.f.seek(chunk.offset)
+    return self.f.read(chunk.length)
+
+
+class RemoteChunkReader(ChunkReader):
+  """Reads lzma compressed chunks from a remote store"""
+
+  def __init__(self, url: str) -> None:
+    super().__init__()
+    self.url = url
+
+  def read(self, chunk: Chunk) -> bytes:
+    sha_hex = chunk.sha.hex()
+    url = os.path.join(self.url, sha_hex[:4], sha_hex + ".cacnk")
+
+    for i in range(CHUNK_DOWNLOAD_RETRIES):
+      try:
+        resp = requests.get(url, timeout=CHUNK_DOWNLOAD_TIMEOUT)
+        break
+      except Exception:
+        if i == CHUNK_DOWNLOAD_RETRIES - 1:
+          raise
+        time.sleep(CHUNK_DOWNLOAD_TIMEOUT)
+
+    resp.raise_for_status()
+
+    decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_AUTO)
+    return decompressor.decompress(resp.content)
+
+
+def parse_caibx(caibx_path: str) -> List[Chunk]:
+  """Parses the chunks from a caibx file. Can handle both local and remote files.
+  Returns a list of chunks with hash, offset and length"""
+  if os.path.isfile(caibx_path):
+    caibx = open(caibx_path, 'rb')
+  else:
+    resp = requests.get(caibx_path, timeout=CAIBX_DOWNLOAD_TIMEOUT)
+    resp.raise_for_status()
+    caibx = io.BytesIO(resp.content)
+
+  caibx.seek(0, os.SEEK_END)
+  caibx_len = caibx.tell()
+  caibx.seek(0, os.SEEK_SET)
+
+  # Parse header
+  length, magic, flags, min_size, _, max_size = struct.unpack("<QQQQQQ", caibx.read(CA_HEADER_LEN))
+  assert length == CA_HEADER_LEN
+  assert magic == CA_FORMAT_INDEX
+  assert flags == FLAGS
+
+  # Parse table header
+  length, magic = struct.unpack("<QQ", caibx.read(CA_TABLE_HEADER_LEN))
+  assert magic == CA_FORMAT_TABLE
+
+  # Parse chunks; the table is followed by a fixed-size tail entry
+  num_chunks = (caibx_len - CA_HEADER_LEN - CA_TABLE_MIN_LEN) // CA_TABLE_ENTRY_LEN
+  chunks = []
+
+  offset = 0
+  for i in range(num_chunks):
+    new_offset, sha = struct.unpack("<Q32s", caibx.read(CA_TABLE_ENTRY_LEN))
+    length = new_offset - offset
+
+    assert length <= max_size
+
+    # Last chunk can be smaller
+    if i < num_chunks - 1:
+      assert length >= min_size
+
+    chunks.append(Chunk(sha, offset, length))
+    offset = new_offset
+
+  return chunks
+
+
+def build_chunk_dict(chunks: List[Chunk]) -> ChunkDict:
+  """Turn a list of chunks into a dict for faster lookups based on hash"""
+  return {c.sha: c for c in chunks}
+
+
+def extract(target: List[Chunk],
+            sources: List[Tuple[str, ChunkReader, ChunkDict]],
+            out_path: str,
+            progress: Optional[Callable[[int], None]] = None):
+  stats: Dict[str, int] = defaultdict(int)
+
+  with open(out_path, 'wb') as out:
+    for cur_chunk in target:
+
+      # Find a source for the desired chunk
+      for name, chunk_reader, store_chunks in sources:
+        if cur_chunk.sha in store_chunks:
+          bts = chunk_reader.read(store_chunks[cur_chunk.sha])
+
+          # Check length
+          if len(bts) != cur_chunk.length:
+            continue
+
+          # Check hash
+          if SHA512.new(bts, truncate="256").digest() != cur_chunk.sha:
+            continue
+
+          # Write to output
+          out.seek(cur_chunk.offset)
+          out.write(bts)
+
+          stats[name] += cur_chunk.length
+
+          if progress is not None:
+            progress(sum(stats.values()))
+
+          break
+      else:
+        raise RuntimeError("Desired chunk not found in provided stores")
+
+  return stats
+
+
+def print_stats(stats: Dict[str, int]):
+  total_bytes = sum(stats.values())
+  print(f"Total size: {total_bytes / 1024 / 1024:.2f} MB")
+  for name, total in stats.items():
+    print(f"  {name}: {total / 1024 / 1024:.2f} MB ({total / total_bytes * 100:.1f}%)")
+
+
+def extract_simple(caibx_path, out_path, store_path):
+  # (name, callback, chunks)
+  target = parse_caibx(caibx_path)
+  sources = [
+    # (store_path, RemoteChunkReader(store_path), build_chunk_dict(target)),
+    (store_path, FileChunkReader(store_path), build_chunk_dict(target)),
+  ]
+
+  return extract(target, sources, out_path)
+
+
+if __name__ == "__main__":
+  caibx = sys.argv[1]
+  out = sys.argv[2]
+  store = sys.argv[3]
+
+  stats = extract_simple(caibx, out, store)
+  print_stats(stats)
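
For reference, the source layering that extract_casync_image sets up (seed partition first, remote chunk store last) can be reproduced standalone with casync.py's own helpers. Below is a minimal sketch of that wiring; the index URLs, store URL, and file paths are hypothetical placeholders, not values from this change:

  import system.hardware.tici.casync as casync

  # Hypothetical locations, for illustration only
  target_caibx = 'https://example.com/agnos/system-new.caibx'  # index of the desired image
  seed_caibx = 'https://example.com/agnos/system-old.caibx'    # index of the currently installed image
  store_url = 'https://example.com/agnos/store'                # remote chunk store
  seed_path = '/tmp/system_old.img'                            # local copy of the current image
  out_path = '/tmp/system_new.img'                             # where the new image is written

  target = casync.parse_caibx(target_caibx)

  # Sources are tried in order. A chunk is accepted only if its length and
  # truncated SHA-512 match the target index, so a stale or corrupt seed
  # chunk simply falls through to the remote store.
  sources = [
    ('seed', casync.FileChunkReader(seed_path),
     casync.build_chunk_dict(casync.parse_caibx(seed_caibx))),
    ('remote', casync.RemoteChunkReader(store_url),
     casync.build_chunk_dict(target)),
  ]

  stats = casync.extract(target, sources, out_path)
  casync.print_stats(stats)  # shows how many bytes were reused vs. downloaded

Because extract() verifies every chunk against the target index, the source ordering is purely a bandwidth optimization; agnos.py additionally re-reads the whole partition afterwards via verify_partition(..., force_full_check=True).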