From f5e48678640bbde0f2616e955ced53f30d61ac6c Mon Sep 17 00:00:00 2001 From: Dean Lee Date: Mon, 29 Nov 2021 18:13:30 +0800 Subject: [PATCH] replay: refactor http download (#23052) * refactor http download * use chunk_size instead of parts testcase:set chunksize to 5mb * use template space * cleanup * remove unused include * check buffer overfllow * simplify print download speed --- selfdrive/ui/replay/filereader.cc | 27 ++---- selfdrive/ui/replay/filereader.h | 6 +- selfdrive/ui/replay/route.cc | 2 +- selfdrive/ui/replay/tests/test_replay.cc | 18 ++-- selfdrive/ui/replay/util.cc | 109 ++++++++++++++--------- selfdrive/ui/replay/util.h | 4 +- 6 files changed, 87 insertions(+), 79 deletions(-) diff --git a/selfdrive/ui/replay/filereader.cc b/selfdrive/ui/replay/filereader.cc index d110255497..84dc76694b 100644 --- a/selfdrive/ui/replay/filereader.cc +++ b/selfdrive/ui/replay/filereader.cc @@ -1,12 +1,7 @@ #include "selfdrive/ui/replay/filereader.h" -#include - -#include -#include #include #include -#include #include "selfdrive/common/util.h" #include "selfdrive/ui/replay/util.h" @@ -31,7 +26,7 @@ std::string FileReader::read(const std::string &file, std::atomic *abort) } else if (is_remote) { result = download(file, abort); if (cache_to_local_ && !result.empty()) { - std::ofstream fs(local_file, fs.binary | fs.out); + std::ofstream fs(local_file, std::ios::binary | std::ios::out); fs.write(result.data(), result.size()); } } @@ -39,23 +34,13 @@ std::string FileReader::read(const std::string &file, std::atomic *abort) } std::string FileReader::download(const std::string &url, std::atomic *abort) { - std::string result; - size_t remote_file_size = 0; for (int i = 0; i <= max_retries_ && !(abort && *abort); ++i) { - if (i > 0) { - std::cout << "download failed, retrying" << i << std::endl; - } - if (remote_file_size <= 0) { - remote_file_size = getRemoteFileSize(url); + std::string result = httpGet(url, chunk_size_, abort); + if (!result.empty()) { + return result; } - if (remote_file_size > 0 && !(abort && *abort)) { - std::ostringstream oss; - result.resize(remote_file_size); - oss.rdbuf()->pubsetbuf(result.data(), result.size()); - int chunks = chunk_size_ > 0 ? std::max(1, (int)std::nearbyint(remote_file_size / (float)chunk_size_)) : 1; - if (httpMultiPartDownload(url, oss, chunks, remote_file_size, abort)) { - return result; - } + if (i != max_retries_) { + std::cout << "download failed, retrying " << i + 1 << std::endl; } } return {}; diff --git a/selfdrive/ui/replay/filereader.h b/selfdrive/ui/replay/filereader.h index 06ce14e9f2..34aa91e858 100644 --- a/selfdrive/ui/replay/filereader.h +++ b/selfdrive/ui/replay/filereader.h @@ -5,14 +5,14 @@ class FileReader { public: - FileReader(bool cache_to_local, int chunk_size = -1, int max_retries = 3) - : cache_to_local_(cache_to_local), chunk_size_(chunk_size), max_retries_(max_retries) {} + FileReader(bool cache_to_local, size_t chunk_size = 0, int retries = 3) + : cache_to_local_(cache_to_local), chunk_size_(chunk_size), max_retries_(retries) {} virtual ~FileReader() {} std::string read(const std::string &file, std::atomic *abort = nullptr); private: std::string download(const std::string &url, std::atomic *abort); - int chunk_size_; + size_t chunk_size_; int max_retries_; bool cache_to_local_; }; diff --git a/selfdrive/ui/replay/route.cc b/selfdrive/ui/replay/route.cc index 9ee1d387c4..caaef4de5c 100644 --- a/selfdrive/ui/replay/route.cc +++ b/selfdrive/ui/replay/route.cc @@ -121,7 +121,7 @@ void Segment::loadFile(int id, const std::string file) { frames[id] = std::make_unique(local_cache, 20 * 1024 * 1024, 3); success = frames[id]->load(file, flags & REPLAY_FLAG_NO_CUDA, &abort_); } else { - log = std::make_unique(local_cache, -1, 3); + log = std::make_unique(local_cache, 0, 3); success = log->load(file, &abort_); } diff --git a/selfdrive/ui/replay/tests/test_replay.cc b/selfdrive/ui/replay/tests/test_replay.cc index f2a2171c82..9efc55a610 100644 --- a/selfdrive/ui/replay/tests/test_replay.cc +++ b/selfdrive/ui/replay/tests/test_replay.cc @@ -1,7 +1,5 @@ #include #include -#include -#include #include "catch2/catch.hpp" #include "selfdrive/common/util.h" @@ -16,19 +14,15 @@ TEST_CASE("httpMultiPartDownload") { char filename[] = "/tmp/XXXXXX"; close(mkstemp(filename)); + const size_t chunk_size = 5 * 1024 * 1024; std::string content; - auto file_size = getRemoteFileSize(TEST_RLOG_URL); - REQUIRE(file_size > 0); - SECTION("5 connections, download to file") { - std::ofstream of(filename, of.binary | of.out); - REQUIRE(httpMultiPartDownload(TEST_RLOG_URL, of, 5, file_size)); + SECTION("download to file") { + REQUIRE(httpDownload(TEST_RLOG_URL, filename, chunk_size)); content = util::read_file(filename); } - SECTION("5 connection, download to buffer") { - std::ostringstream oss; - content.resize(file_size); - oss.rdbuf()->pubsetbuf(content.data(), content.size()); - REQUIRE(httpMultiPartDownload(TEST_RLOG_URL, oss, 5, file_size)); + SECTION("download to buffer") { + content = httpGet(TEST_RLOG_URL, chunk_size); + REQUIRE(!content.empty()); } REQUIRE(content.size() == 9112651); REQUIRE(sha256(content) == TEST_RLOG_CHECKSUM); diff --git a/selfdrive/ui/replay/util.cc b/selfdrive/ui/replay/util.cc index b4fcbf7e6b..27f0bd9cfd 100644 --- a/selfdrive/ui/replay/util.cc +++ b/selfdrive/ui/replay/util.cc @@ -4,8 +4,10 @@ #include #include +#include #include -#include +#include +#include #include #include #include @@ -22,21 +24,34 @@ struct CURLGlobalInitializer { ~CURLGlobalInitializer() { curl_global_cleanup(); } }; +template struct MultiPartWriter { + T *buf; + size_t *total_written; size_t offset; size_t end; - size_t written; - std::ostream *os; + + size_t write(char *data, size_t size, size_t count) { + size_t bytes = size * count; + if ((offset + bytes) > end) return 0; + + if constexpr (std::is_same::value) { + memcpy(buf->data() + offset, data, bytes); + } else if constexpr (std::is_same::value) { + buf->seekp(offset); + buf->write(data, bytes); + } + + offset += bytes; + *total_written += bytes; + return bytes; + } }; +template size_t write_cb(char *data, size_t size, size_t count, void *userp) { - MultiPartWriter *w = (MultiPartWriter *)userp; - w->os->seekp(w->offset); - size_t bytes = size * count; - w->os->write(data, bytes); - w->offset += bytes; - w->written += bytes; - return bytes; + auto w = (MultiPartWriter *)userp; + return w->write(data, size, count); } size_t dumy_write_cb(char *data, size_t size, size_t count, void *userp) { return size * count; } @@ -64,12 +79,12 @@ size_t getRemoteFileSize(const std::string &url) { CURLcode res = curl_easy_perform(curl); double content_length = -1; if (res == CURLE_OK) { - res = curl_easy_getinfo(curl, CURLINFO_CONTENT_LENGTH_DOWNLOAD, &content_length); + curl_easy_getinfo(curl, CURLINFO_CONTENT_LENGTH_DOWNLOAD, &content_length); } else { std::cout << "Download failed: error code: " << res << std::endl; } curl_easy_cleanup(curl); - return content_length > 0 ? content_length : 0; + return content_length > 0 ? (size_t)content_length : 0; } std::string getUrlWithoutQuery(const std::string &url) { @@ -81,30 +96,32 @@ void enableHttpLogging(bool enable) { enable_http_logging = enable; } -bool httpMultiPartDownload(const std::string &url, std::ostream &os, int parts, size_t content_length, std::atomic *abort) { +template +bool httpDownload(const std::string &url, T &buf, size_t chunk_size, size_t content_length, std::atomic *abort) { static CURLGlobalInitializer curl_initializer; - static std::mutex lock; - static uint64_t total_written = 0, prev_total_written = 0; - static double last_print_ts = 0; - os.seekp(content_length - 1); - os.write("\0", 1); + int parts = 1; + if (chunk_size > 0 && content_length > 10 * 1024 * 1024) { + parts = std::nearbyint(content_length / (float)chunk_size); + parts = std::clamp(parts, 1, 5); + } CURLM *cm = curl_multi_init(); - - std::map writers; + size_t written = 0; + std::map> writers; const int part_size = content_length / parts; for (int i = 0; i < parts; ++i) { CURL *eh = curl_easy_init(); writers[eh] = { - .os = &os, + .buf = &buf, + .total_written = &written, .offset = (size_t)(i * part_size), - .end = i == parts - 1 ? content_length - 1 : (i + 1) * part_size - 1, + .end = i == parts - 1 ? content_length : (i + 1) * part_size, }; - curl_easy_setopt(eh, CURLOPT_WRITEFUNCTION, write_cb); + curl_easy_setopt(eh, CURLOPT_WRITEFUNCTION, write_cb); curl_easy_setopt(eh, CURLOPT_WRITEDATA, (void *)(&writers[eh])); curl_easy_setopt(eh, CURLOPT_URL, url.c_str()); - curl_easy_setopt(eh, CURLOPT_RANGE, util::string_format("%d-%d", writers[eh].offset, writers[eh].end).c_str()); + curl_easy_setopt(eh, CURLOPT_RANGE, util::string_format("%d-%d", writers[eh].offset, writers[eh].end - 1).c_str()); curl_easy_setopt(eh, CURLOPT_HTTPGET, 1); curl_easy_setopt(eh, CURLOPT_NOSIGNAL, 1); curl_easy_setopt(eh, CURLOPT_FOLLOWLOCATION, 1); @@ -112,27 +129,22 @@ bool httpMultiPartDownload(const std::string &url, std::ostream &os, int parts, curl_multi_add_handle(cm, eh); } - int still_running = 1; size_t prev_written = 0; + double last_print = millis_since_boot(); + int still_running = 1; while (still_running > 0 && !(abort && *abort)) { curl_multi_wait(cm, nullptr, 0, 1000, nullptr); curl_multi_perform(cm, &still_running); - size_t written = std::accumulate(writers.begin(), writers.end(), 0, [=](int v, auto &w) { return v + w.second.written; }); - int cur_written = written - prev_written; - prev_written = written; - - std::lock_guard lk(lock); - double ts = millis_since_boot(); - total_written += cur_written; - if ((ts - last_print_ts) > 2 * 1000) { - if (enable_http_logging && last_print_ts > 0) { - size_t average = (total_written - prev_total_written) / ((ts - last_print_ts) / 1000.); + if (enable_http_logging) { + if (double ts = millis_since_boot(); (ts - last_print) > 2 * 1000) { + size_t average = (written - prev_written) / ((ts - last_print) / 1000.); int progress = std::min(100, 100.0 * (double)written / (double)content_length); - std::cout << "downloading " << getUrlWithoutQuery(url) << " - " << progress << "% (" << formattedDataSize(average) << "/s)" << std::endl; + std::cout << "downloading " << getUrlWithoutQuery(url) << " - " << progress + << "% (" << formattedDataSize(average) << "/s)" << std::endl; + last_print = ts; + prev_written = written; } - prev_total_written = total_written; - last_print_ts = ts; } } @@ -155,15 +167,32 @@ bool httpMultiPartDownload(const std::string &url, std::ostream &os, int parts, } } - for (auto &[e, w] : writers) { + for (const auto &[e, w] : writers) { curl_multi_remove_handle(cm, e); curl_easy_cleanup(e); } - curl_multi_cleanup(cm); + return complete == parts; } +std::string httpGet(const std::string &url, size_t chunk_size, std::atomic *abort) { + size_t size = getRemoteFileSize(url); + if (size == 0) return {}; + + std::string result(size, '\0'); + return httpDownload(url, result, chunk_size, size, abort) ? result : ""; +} + +bool httpDownload(const std::string &url, const std::string &file, size_t chunk_size, std::atomic *abort) { + size_t size = getRemoteFileSize(url); + if (size == 0) return false; + + std::ofstream of(file, std::ios::binary | std::ios::out); + of.seekp(size - 1).write("\0", 1); + return httpDownload(url, of, chunk_size, size, abort); +} + std::string decompressBZ2(const std::string &in) { if (in.empty()) return {}; diff --git a/selfdrive/ui/replay/util.h b/selfdrive/ui/replay/util.h index 30a26c4314..85d7af0125 100644 --- a/selfdrive/ui/replay/util.h +++ b/selfdrive/ui/replay/util.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include std::string sha256(const std::string &str); @@ -10,4 +9,5 @@ std::string decompressBZ2(const std::string &in); void enableHttpLogging(bool enable); std::string getUrlWithoutQuery(const std::string &url); size_t getRemoteFileSize(const std::string &url); -bool httpMultiPartDownload(const std::string &url, std::ostream &os, int parts, size_t content_length, std::atomic *abort = nullptr); +std::string httpGet(const std::string &url, size_t chunk_size = 0, std::atomic *abort = nullptr); +bool httpDownload(const std::string &url, const std::string &file, size_t chunk_size = 0, std::atomic *abort = nullptr);