From 497b9b9ccc6c3a4dace85d0febbdc66667e4edcb Mon Sep 17 00:00:00 2001 From: Kai Lueke Date: Thu, 7 Dec 2023 12:52:16 +0100 Subject: [PATCH] download: Make progress reporting opt-in For scripting the current progress report is too verbose. Add a flag to opt-in. --- examples/download_test.rs | 2 +- examples/full_test.rs | 2 +- src/bin/download_sysext.rs | 28 +++++++++++++++++++++------- src/download.rs | 19 ++++++++++--------- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/examples/download_test.rs b/examples/download_test.rs index 24b0575..bdab4f7 100644 --- a/examples/download_test.rs +++ b/examples/download_test.rs @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box> { println!("fetching {}...", url); let data = Vec::new(); - let res = download_and_hash(&client, url, data).await?; + let res = download_and_hash(&client, url, data, false).await?; println!("hash: {}", res.hash); diff --git a/examples/full_test.rs b/examples/full_test.rs index 36c56cc..4a614b2 100644 --- a/examples/full_test.rs +++ b/examples/full_test.rs @@ -83,7 +83,7 @@ async fn main() -> Result<(), Box> { // std::io::BufWriter wrapping an std::fs::File is probably the right choice. 
// std::io::sink() is basically just /dev/null let data = std::io::sink(); - let res = ue_rs::download_and_hash(&client, url.clone(), data).await.context(format!("download_and_hash({url:?}) failed"))?; + let res = ue_rs::download_and_hash(&client, url.clone(), data, false).await.context(format!("download_and_hash({url:?}) failed"))?; println!("\texpected sha256: {}", expected_sha256); println!("\tcalculated sha256: {}", res.hash); diff --git a/src/bin/download_sysext.rs b/src/bin/download_sysext.rs index 905a496..4507837 100644 --- a/src/bin/download_sysext.rs +++ b/src/bin/download_sysext.rs @@ -94,7 +94,7 @@ impl<'a> Package<'a> { Ok(()) } - async fn download(&mut self, into_dir: &Path, client: &reqwest::Client) -> Result<()> { + async fn download(&mut self, into_dir: &Path, client: &reqwest::Client, print_progress: bool) -> Result<()> { // FIXME: use _range_start for completing downloads let _range_start = match self.status { PackageStatus::ToDownload => 0, @@ -107,7 +107,7 @@ impl<'a> Package<'a> { let path = into_dir.join(&*self.name); let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?; - let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file).await { + let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file, print_progress).await { Ok(ok) => ok, Err(err) => { error!("Downloading failed with error {}", err); @@ -243,14 +243,14 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet) } // Read data from remote URL into File -async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client) -> Result> +async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result> where U: reqwest::IntoUrl + From + std::clone::Clone + std::fmt::Debug, Url: From, { let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; - 
ue_rs::download_and_hash(client, input_url.clone(), &mut file).await.context(format!("unable to download data(url {:?})", input_url))?; + ue_rs::download_and_hash(client, input_url.clone(), &mut file, print_progress).await.context(format!("unable to download data(url {:?})", input_url))?; Ok(Package { name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")), @@ -261,10 +261,10 @@ where }) } -async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client) -> Result<()> { +async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> { pkg.check_download(unverified_dir)?; - pkg.download(unverified_dir, client).await.context(format!("unable to download \"{:?}\"", pkg.name))?; + pkg.download(unverified_dir, client, print_progress).await.context(format!("unable to download \"{:?}\"", pkg.name))?; // Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz". // Verified payload is stored in e.g. "output_dir/oem.raw". @@ -304,6 +304,10 @@ struct Args { /// may be specified multiple times. 
#[argh(option, short = 'm')] image_match: Vec, + + /// report download progress + #[argh(switch, short = 'v')] + print_progress: bool, } impl Args { @@ -369,6 +373,7 @@ async fn main() -> Result<(), Box> { &temp_payload_path, Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?, &client, + args.print_progress, ) .await?; do_download_verify( @@ -377,6 +382,7 @@ async fn main() -> Result<(), Box> { unverified_dir.as_path(), args.pubkey_file.as_str(), &client, + args.print_progress, ) .await?; @@ -404,7 +410,15 @@ async fn main() -> Result<(), Box> { //// for pkg in pkgs_to_dl.iter_mut() { - do_download_verify(pkg, output_dir, unverified_dir.as_path(), args.pubkey_file.as_str(), &client).await?; + do_download_verify( + pkg, + output_dir, + unverified_dir.as_path(), + args.pubkey_file.as_str(), + &client, + args.print_progress, + ) + .await?; } // clean up data diff --git a/src/download.rs b/src/download.rs index b42edee..d00b841 100644 --- a/src/download.rs +++ b/src/download.rs @@ -57,7 +57,7 @@ pub fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result(client: &reqwest::Client, url: U, mut data: W) -> Result> +pub async fn download_and_hash(client: &reqwest::Client, url: U, mut data: W, print_progress: bool) -> Result> where U: reqwest::IntoUrl + Clone, W: io::Write, @@ -100,14 +100,15 @@ where hasher.update(&chunk); data.write_all(&chunk).context("failed to write_all chunk")?; - // TODO: better way to report progress? - print!( - "\rread {}/{} ({:3}%)", - bytes_read, - bytes_to_read, - ((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor() - ); - io::stdout().flush().context("failed to flush stdout")?; + if print_progress { + print!( + "\rread {}/{} ({:3}%)", + bytes_read, + bytes_to_read, + ((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor() + ); + io::stdout().flush().context("failed to flush stdout")?; + } } data.flush().context("failed to flush data")?;