add casync support to agnos updater (#23654)

* add casync option to agnos updater

* open if necessary

* add python implementation

* last chunk can be small

* check flags

* cleaner check

* add remote and file stores

* remote caibx file

* print stats

* use python implementation

* clean up imports

* add progress

* fix logging

* fix duplicate chunks

* add comments

* json stats

* cleanup tmp

* normal image is still sparse

* Update system/hardware/tici/agnos.py

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>

* Update system/hardware/tici/agnos.py

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>

* add some types

* remove comment

* create Chunk type

* make readers a class

* try agnos 5.2

* add download retries

* catch all exceptions

* sleep between retry

* revert agnos.json changes

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>
old-commit-hash: 3900781092
taco
Willem Melching 3 years ago committed by GitHub
parent 297a0bd65b
commit 3230474724
  1. 95
      system/hardware/tici/agnos.py
  2. 192
      system/hardware/tici/casync.py

@ -1,13 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import hashlib
import json import json
import lzma import lzma
import hashlib import os
import requests
import struct import struct
import subprocess import subprocess
import time import time
import os from typing import Dict, Generator, List, Tuple, Union
from typing import Dict, Generator, Union
import requests
import system.hardware.tici.casync as casync
SPARSE_CHUNK_FMT = struct.Struct('H2xI4x') SPARSE_CHUNK_FMT = struct.Struct('H2xI4x')
@ -74,6 +77,7 @@ def unsparsify(f: StreamingDecompressor) -> Generator[bytes, None, None]:
else: else:
raise Exception("Unhandled sparse chunk type") raise Exception("Unhandled sparse chunk type")
# noop wrapper with same API as unsparsify() for non sparse images # noop wrapper with same API as unsparsify() for non sparse images
def noop(f: StreamingDecompressor) -> Generator[bytes, None, None]: def noop(f: StreamingDecompressor) -> Generator[bytes, None, None]:
while not f.eof: while not f.eof:
@ -99,8 +103,8 @@ def get_partition_path(target_slot_number: int, partition: dict) -> str:
return path return path
def verify_partition(target_slot_number: int, partition: Dict[str, Union[str, int]]) -> bool: def verify_partition(target_slot_number: int, partition: Dict[str, Union[str, int]], force_full_check: bool = False) -> bool:
full_check = partition['full_check'] full_check = partition['full_check'] or force_full_check
path = get_partition_path(target_slot_number, partition) path = get_partition_path(target_slot_number, partition)
if not isinstance(partition['size'], int): if not isinstance(partition['size'], int):
return False return False
@ -135,21 +139,10 @@ def clear_partition_hash(target_slot_number: int, partition: dict) -> None:
os.sync() os.sync()
def flash_partition(target_slot_number: int, partition: dict, cloudlog): def extract_compressed_image(target_slot_number: int, partition: dict, cloudlog):
cloudlog.info(f"Downloading and writing {partition['name']}") path = get_partition_path(target_slot_number, partition)
if verify_partition(target_slot_number, partition):
cloudlog.info(f"Already flashed {partition['name']}")
return
downloader = StreamingDecompressor(partition['url']) downloader = StreamingDecompressor(partition['url'])
# Clear hash before flashing in case we get interrupted
full_check = partition['full_check']
if not full_check:
clear_partition_hash(target_slot_number, partition)
path = get_partition_path(target_slot_number, partition)
with open(path, 'wb+') as out: with open(path, 'wb+') as out:
# Flash partition # Flash partition
last_p = 0 last_p = 0
@ -172,9 +165,67 @@ def flash_partition(target_slot_number: int, partition: dict, cloudlog):
if out.tell() != partition['size']: if out.tell() != partition['size']:
raise Exception("Uncompressed size mismatch") raise Exception("Uncompressed size mismatch")
# Write hash after successful flash
os.sync() os.sync()
if not full_check:
def extract_casync_image(target_slot_number: int, partition: dict, cloudlog):
path = get_partition_path(target_slot_number, partition)
seed_path = path[:-1] + ('b' if path[-1] == 'a' else 'a')
target = casync.parse_caibx(partition['casync_caibx'])
sources: List[Tuple[str, casync.ChunkReader, casync.ChunkDict]] = []
# First source is the current partition. Index file for current version is provided in the manifest
if 'casync_seed_caibx' in partition:
sources += [('seed', casync.FileChunkReader(seed_path), casync.build_chunk_dict(casync.parse_caibx(partition['casync_seed_caibx'])))]
# Second source is the target partition, this allows for resuming
sources += [('target', casync.FileChunkReader(path), casync.build_chunk_dict(target))]
# Finally we add the remote source to download any missing chunks
sources += [('remote', casync.RemoteChunkReader(partition['casync_store']), casync.build_chunk_dict(target))]
last_p = 0
def progress(cur):
nonlocal last_p
p = int(cur / partition['size'] * 100)
if p != last_p:
last_p = p
print(f"Installing {partition['name']}: {p}", flush=True)
stats = casync.extract(target, sources, path, progress)
cloudlog.error(f'casync done {json.dumps(stats)}')
os.sync()
if not verify_partition(target_slot_number, partition, force_full_check=True):
raise Exception(f"Raw hash mismatch '{partition['hash_raw'].lower()}'")
def flash_partition(target_slot_number: int, partition: dict, cloudlog):
cloudlog.info(f"Downloading and writing {partition['name']}")
if verify_partition(target_slot_number, partition):
cloudlog.info(f"Already flashed {partition['name']}")
return
# Clear hash before flashing in case we get interrupted
full_check = partition['full_check']
if not full_check:
clear_partition_hash(target_slot_number, partition)
path = get_partition_path(target_slot_number, partition)
if 'casync_caibx' in partition:
extract_casync_image(target_slot_number, partition, cloudlog)
else:
extract_compressed_image(target_slot_number, partition, cloudlog)
# Write hash after successful flash
if not full_check:
with open(path, 'wb+') as out:
out.seek(partition['size'])
out.write(partition['hash_raw'].lower().encode()) out.write(partition['hash_raw'].lower().encode())
@ -228,8 +279,8 @@ def verify_agnos_update(manifest_path: str, target_slot_number: int) -> bool:
if __name__ == "__main__": if __name__ == "__main__":
import logging
import argparse import argparse
import logging
parser = argparse.ArgumentParser(description="Flash and verify AGNOS update", parser = argparse.ArgumentParser(description="Flash and verify AGNOS update",
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)

@ -0,0 +1,192 @@
#!/usr/bin/env python3
import io
import lzma
import os
import struct
import sys
import time
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from typing import Callable, Dict, List, Optional, Tuple
import requests
from Crypto.Hash import SHA512
# Magic numbers identifying sections of the casync index (.caibx) file format
# (presumably matching systemd/casync's caformat.h — TODO confirm against upstream).
CA_FORMAT_INDEX = 0x96824d9c7b129ff9
CA_FORMAT_TABLE = 0xe75b9e112f17417d
# NOTE(review): unused in this file, and one nibble shorter than CA_FORMAT_TABLE — verify value
CA_FORMAT_TABLE_TAIL_MARKER = 0xe75b9e112f17417
# Expected feature-flags value in the index header — TODO confirm; note that
# parse_caibx as written compares the parsed flags to themselves, not to this constant.
FLAGS = 0xb000000000000000

# Fixed on-disk sizes (bytes) of the caibx header and chunk-table records
CA_HEADER_LEN = 48
CA_TABLE_HEADER_LEN = 16
CA_TABLE_ENTRY_LEN = 40
# Smallest possible table: header plus a single entry
CA_TABLE_MIN_LEN = CA_TABLE_HEADER_LEN + CA_TABLE_ENTRY_LEN

# Timeouts (seconds) and retry count for remote chunk/index downloads
CHUNK_DOWNLOAD_TIMEOUT = 10
CHUNK_DOWNLOAD_RETRIES = 3
CAIBX_DOWNLOAD_TIMEOUT = 120

# One entry from the index: chunk hash (bytes), plus offset/length within the assembled image
Chunk = namedtuple('Chunk', ['sha', 'offset', 'length'])
# Hash -> Chunk lookup table for fast membership tests
ChunkDict = Dict[bytes, Chunk]
class ChunkReader(ABC):
  """Abstract source of chunk data (local file, remote store, ...)."""

  @abstractmethod
  def read(self, chunk: Chunk) -> bytes:
    """Return the raw (uncompressed) bytes for *chunk*."""
class FileChunkReader(ChunkReader):
  """Reads chunks from a local file"""
  def __init__(self, fn: str) -> None:
    super().__init__()
    self.f = open(fn, 'rb')

  def close(self) -> None:
    # Release the underlying file handle; safe to call more than once.
    if not self.f.closed:
      self.f.close()

  def __del__(self):
    # Fix: the original never closed the handle, leaking the file descriptor
    # for the lifetime of the process. Close it when the reader is collected.
    self.close()

  def read(self, chunk: Chunk) -> bytes:
    # Offsets/lengths come from the caibx index built for this same file.
    self.f.seek(chunk.offset)
    return self.f.read(chunk.length)
class RemoteChunkReader(ChunkReader):
  """Reads lzma compressed chunks from a remote store"""
  def __init__(self, url: str) -> None:
    super().__init__()
    self.url = url

  def read(self, chunk: Chunk) -> bytes:
    # Chunks are laid out in the store as <first 4 hex chars>/<full hash>.cacnk
    sha_hex = chunk.sha.hex()
    url = os.path.join(self.url, sha_hex[:4], sha_hex + ".cacnk")

    # Retry transient download failures; re-raise after the final attempt
    attempt = 0
    while True:
      try:
        resp = requests.get(url, timeout=CHUNK_DOWNLOAD_TIMEOUT)
        break
      except Exception:
        attempt += 1
        if attempt == CHUNK_DOWNLOAD_RETRIES:
          raise
        time.sleep(CHUNK_DOWNLOAD_TIMEOUT)

    resp.raise_for_status()
    return lzma.LZMADecompressor(format=lzma.FORMAT_AUTO).decompress(resp.content)
def parse_caibx(caibx_path: str) -> List[Chunk]:
  """Parses the chunks from a caibx file. Can handle both local and remote files.

  Returns a list of chunks with hash, offset and length.
  """
  if os.path.isfile(caibx_path):
    caibx = open(caibx_path, 'rb')
  else:
    resp = requests.get(caibx_path, timeout=CAIBX_DOWNLOAD_TIMEOUT)
    resp.raise_for_status()
    caibx = io.BytesIO(resp.content)

  # Fix: close the file/buffer even when an assertion fails mid-parse
  try:
    caibx.seek(0, os.SEEK_END)
    caibx_len = caibx.tell()
    caibx.seek(0, os.SEEK_SET)

    # Parse header
    length, magic, flags, min_size, _, max_size = struct.unpack("<QQQQQQ", caibx.read(CA_HEADER_LEN))
    assert length == CA_HEADER_LEN
    assert magic == CA_FORMAT_INDEX
    # Fix: the original asserted `flags == flags`, a tautology that never
    # actually validated the header flags against the expected value
    assert flags == FLAGS

    # Parse table header
    length, magic = struct.unpack("<QQ", caibx.read(CA_TABLE_HEADER_LEN))
    assert magic == CA_FORMAT_TABLE

    # Parse chunks. Each table entry is an end-offset followed by a 32-byte hash;
    # chunk lengths are the deltas between consecutive offsets.
    num_chunks = (caibx_len - CA_HEADER_LEN - CA_TABLE_MIN_LEN) // CA_TABLE_ENTRY_LEN
    chunks = []
    offset = 0
    for i in range(num_chunks):
      new_offset = struct.unpack("<Q", caibx.read(8))[0]
      sha = caibx.read(32)

      length = new_offset - offset
      assert length <= max_size

      # Last chunk can be smaller
      if i < num_chunks - 1:
        assert length >= min_size

      chunks.append(Chunk(sha, offset, length))
      offset = new_offset

    return chunks
  finally:
    caibx.close()
def build_chunk_dict(chunks: List[Chunk]) -> ChunkDict:
  """Index a chunk list by hash so membership lookups are O(1)."""
  lookup: ChunkDict = {}
  for chunk in chunks:
    lookup[chunk.sha] = chunk
  return lookup
def extract(target: List[Chunk],
            sources: List[Tuple[str, ChunkReader, ChunkDict]],
            out_path: str,
            progress: Optional[Callable[[int], None]] = None):
  """Assemble the target image at out_path, chunk by chunk.

  Sources are tried in order, so cheap ones (local seed/target partitions)
  should come before the remote store. Returns a dict mapping source name
  to the number of bytes taken from it.
  """
  stats: Dict[str, int] = defaultdict(int)

  with open(out_path, 'wb') as out:
    for want in target:
      satisfied = False

      # Find a source that can supply this chunk
      for name, chunk_reader, store_chunks in sources:
        candidate = store_chunks.get(want.sha)
        if candidate is None:
          continue

        bts = chunk_reader.read(candidate)

        # A source with a stale/corrupt copy is skipped, not fatal:
        # verify length first, then the truncated SHA-512/256 hash
        if len(bts) != want.length:
          continue
        if SHA512.new(bts, truncate="256").digest() != want.sha:
          continue

        out.seek(want.offset)
        out.write(bts)

        stats[name] += want.length
        if progress is not None:
          progress(sum(stats.values()))

        satisfied = True
        break

      if not satisfied:
        raise RuntimeError("Desired chunk not found in provided stores")

  return stats
def print_stats(stats: Dict[str, int]):
  """Print a per-source summary of how many bytes each store contributed."""
  def mb(n: int) -> float:
    return n / 1024 / 1024

  total_bytes = sum(stats.values())
  print(f"Total size: {mb(total_bytes):.2f} MB")
  for name, contributed in stats.items():
    print(f"  {name}: {mb(contributed):.2f} MB ({contributed / total_bytes * 100:.1f}%)")
def extract_simple(caibx_path, out_path, store_path):
  """Extract an image using a single local chunk store."""
  target = parse_caibx(caibx_path)
  sources = [
    (store_path, FileChunkReader(store_path), build_chunk_dict(target)),
  ]
  return extract(target, sources, out_path)
if __name__ == "__main__":
  # Usage: casync.py <caibx> <output image> <chunk store>
  caibx_path, out_path, store_path = sys.argv[1], sys.argv[2], sys.argv[3]
  print_stats(extract_simple(caibx_path, out_path, store_path))
Loading…
Cancel
Save