diff --git a/src/bin/download_sysext.rs b/src/bin/download_sysext.rs index 08556e3..1766935 100644 --- a/src/bin/download_sysext.rs +++ b/src/bin/download_sysext.rs @@ -6,14 +6,17 @@ use std::fs; use std::io; use std::io::{Read, Seek, SeekFrom}; use std::io::BufReader; +use std::str::FromStr; #[macro_use] extern crate log; use anyhow::{Context, Result, bail}; +use argh::FromArgs; use globset::{Glob, GlobSet, GlobSetBuilder}; use hard_xml::XmlRead; -use argh::FromArgs; +use omaha::FileSize; +use reqwest::Client; use reqwest::redirect::Policy; use url::Url; @@ -45,51 +48,7 @@ impl<'a> Package<'a> { // If maxlen is None, a simple read to the end of the file. // If maxlen is Some, read only until the given length. fn hash_on_disk(&mut self, path: &Path, maxlen: Option) -> Result> { - use sha2::{Sha256, Digest}; - - let file = File::open(path).context({ - format!("failed to open path({:?})", path.display()) - })?; - let mut hasher = Sha256::new(); - - let filelen = file.metadata().unwrap().len() as usize; - - let mut maxlen_to_read: usize = match maxlen { - Some(len) => { - if filelen < len { - filelen - } else { - len - } - } - None => filelen, - }; - - const CHUNKLEN: usize = 10485760; // 10M - - let mut freader = BufReader::new(file); - let mut chunklen: usize; - - freader.seek(SeekFrom::Start(0)).context("failed to seek(0)".to_string())?; - while maxlen_to_read > 0 { - if maxlen_to_read < CHUNKLEN { - chunklen = maxlen_to_read; - } else { - chunklen = CHUNKLEN; - } - - let mut databuf = vec![0u8; chunklen]; - - freader.read_exact(&mut databuf).context(format!("failed to read_exact(chunklen {:?})", chunklen))?; - - maxlen_to_read -= chunklen; - - hasher.update(&databuf); - } - - Ok(omaha::Hash::from_bytes( - hasher.finalize().into() - )) + hash_on_disk_sha256(path, maxlen) } #[rustfmt::skip] @@ -147,7 +106,7 @@ impl<'a> Package<'a> { let path = into_dir.join(&*self.name); let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?; - let res = match ue_rs::download_and_hash(&client, self.url.clone(), &mut file).await { + let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file).await { Ok(ok) => ok, Err(err) => { error!("Downloading failed with error {}", err); @@ -237,6 +196,50 @@ impl<'a> Package<'a> { } } +fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result> { + use sha2::{Sha256, Digest}; + + let file = File::open(path).context(format!("failed to open path({:?})", path.display()))?; + let mut hasher = Sha256::new(); + + let filelen = file.metadata().unwrap().len() as usize; + + let mut maxlen_to_read: usize = match maxlen { + Some(len) => { + if filelen < len { + filelen + } else { + len + } + } + None => filelen, + }; + + const CHUNKLEN: usize = 10485760; // 10M + + let mut freader = BufReader::new(file); + let mut chunklen: usize; + + freader.seek(SeekFrom::Start(0)).context("failed to seek(0)".to_string())?; + while maxlen_to_read > 0 { + if maxlen_to_read < CHUNKLEN { + chunklen = maxlen_to_read; + } else { + chunklen = CHUNKLEN; + } + + let mut databuf = vec![0u8; chunklen]; + + freader.read_exact(&mut databuf).context(format!("failed to read_exact(chunklen {:?})", chunklen))?; + + maxlen_to_read -= chunklen; + + hasher.update(&databuf); + } + + Ok(omaha::Hash::from_bytes(hasher.finalize().into())) +} + #[rustfmt::skip] fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet) -> Result>> { @@ -282,6 +285,55 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet) Ok(to_download) } +// Read data from remote URL into File +async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client) -> Result> +where + U: reqwest::IntoUrl + From + std::clone::Clone + std::fmt::Debug, + Url: From, +{ + let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; + + let _ = match ue_rs::download_and_hash(client, input_url.clone(), &mut file).await { + Ok(ok) => ok, + Err(err) => { + bail!("unable to download data(url {:?}), err {:?}", input_url, err); + } + }; + + Ok(Package { + name: Cow::Borrowed(path.file_name().unwrap().to_str().unwrap()), + hash: hash_on_disk_sha256(path, None)?, + size: FileSize::from_bytes(file.metadata().unwrap().len() as usize), + url: input_url.into(), + status: PackageStatus::Unverified, + }) +} + +async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: Client) -> Result<()> { + pkg.check_download(unverified_dir)?; + + match pkg.download(unverified_dir, &client).await { + Ok(_) => (), + _ => bail!("unable to download \"{}\"", pkg.name), + }; + + // Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz". + // Verified payload is stored in e.g. "output_dir/oem.raw". + let pkg_unverified = unverified_dir.join(&*pkg.name); + let pkg_verified = output_dir.join(pkg_unverified.with_extension("raw").file_name().unwrap_or_default()); + + match pkg.verify_signature_on_disk(&pkg_unverified, pubkey_file) { + Ok(datablobspath) => { + // write extracted data into the final data. + fs::rename(datablobspath, pkg_verified.clone())?; + debug!("data blobs written into file {:?}", pkg_verified); + } + _ => bail!("unable to verify signature \"{}\"", pkg.name), + }; + + Ok(()) +} + #[derive(FromArgs, Debug)] /// Parse an update-engine Omaha XML response to extract sysext images, then download and verify /// their signatures. @@ -292,7 +344,11 @@ struct Args { /// path to the Omaha XML file, or - to read from stdin #[argh(option, short = 'i')] - input_xml: String, + input_xml: Option, + + /// URL to fetch remote update payload + #[argh(option, short = 'u')] + payload_url: Option, /// path to the public key file #[argh(option, short = 'p')] @@ -325,14 +381,6 @@ async fn main() -> Result<(), Box> { let glob_set = args.image_match_glob_set()?; - let response_text = match &*args.input_xml { - "-" => io::read_to_string(io::stdin())?, - path => { - let file = File::open(path)?; - io::read_to_string(file)? - } - }; - let output_dir = Path::new(&*args.output_dir); if !output_dir.try_exists()? { return Err(format!("output directory `{}` does not exist", args.output_dir).into()); @@ -343,6 +391,59 @@ async fn main() -> Result<(), Box> { fs::create_dir_all(&unverified_dir)?; fs::create_dir_all(&temp_dir)?; + // The default policy of reqwest Client supports max 10 attempts on HTTP redirect. + let client = Client::builder().redirect(Policy::default()).build()?; + + // If input_xml exists, simply read it. + // If not, try to read from payload_url. + let res_local = match args.input_xml { + Some(name) => { + if name == "-".to_string() { + Some(io::read_to_string(io::stdin())?) + } else { + let file = File::open(name)?; + Some(io::read_to_string(file)?) + } + } + None => None, + }; + + let mut pkg_fake: Package; + + let response_text = match res_local { + Some(res) => { + if args.payload_url.is_some() { + return Err(format!("Only one of the options can be given, --input-xml or --payload-url.").into()); + } + res + } + None => { + match args.payload_url { + Some(url) => { + let u = Url::parse(&url)?; + let fname = u.path_segments().unwrap().next_back().unwrap(); + + let temp_payload_path = unverified_dir.join(fname); + pkg_fake = fetch_url_to_file(&temp_payload_path, Url::from_str(url.as_str()).unwrap(), &client).await?; + do_download_verify( + &mut pkg_fake, + output_dir, + unverified_dir.as_path(), + args.pubkey_file.as_str(), + client.clone(), + ) + .await?; + + // verify only a fake package, early exit and skip the rest. + return Ok(()); + } + None => return Err(format!("Either --input-xml or --payload-url must be given.").into()), + } + } + }; + + debug!("response_text: {:?}", response_text); + //// // parse response //// @@ -356,30 +457,16 @@ async fn main() -> Result<(), Box> { //// // download //// - // The default policy of reqwest Client supports max 10 attempts on HTTP redirect. - let client = reqwest::Client::builder().redirect(Policy::default()).build()?; for pkg in pkgs_to_dl.iter_mut() { - pkg.check_download(&unverified_dir)?; - - match pkg.download(&unverified_dir, &client).await { - Ok(_) => (), - _ => return Err(format!("unable to download \"{}\"", pkg.name).into()), - }; - - // Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz". - // Verified payload is stored in e.g. "output_dir/oem.raw". - let pkg_unverified = unverified_dir.join(&*pkg.name); - let pkg_verified = output_dir.join(pkg_unverified.with_extension("raw").file_name().unwrap_or_default()); - - match pkg.verify_signature_on_disk(&pkg_unverified, &args.pubkey_file) { - Ok(datablobspath) => { - // write extracted data into the final data. - fs::rename(datablobspath, pkg_verified.clone())?; - debug!("data blobs written into file {:?}", pkg_verified); - } - _ => return Err(format!("unable to verify signature \"{}\"", pkg.name).into()), - }; + do_download_verify( + pkg, + output_dir, + unverified_dir.as_path(), + args.pubkey_file.as_str(), + client.clone(), + ) + .await?; } // clean up data diff --git a/src/lib.rs b/src/lib.rs index 72dbc06..8734a59 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod download; pub use download::download_and_hash; +pub use download::DownloadResult; pub mod request;