From cb76332176b64736d64a4f18401857597e53d5ea Mon Sep 17 00:00:00 2001 From: Robur Date: Wed, 3 Jan 2024 15:42:52 +0000 Subject: [PATCH] read/unmarshal: avoid exceptions, use Error instead --- lib/tar.ml | 211 ++++++++++++++++++++++++++++++---------------------- lib/tar.mli | 8 +- 2 files changed, 125 insertions(+), 94 deletions(-) diff --git a/lib/tar.ml b/lib/tar.ml index 318f3cd..f5d4bc5 100644 --- a/lib/tar.ml +++ b/lib/tar.ml @@ -15,13 +15,16 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block] +type error = [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string ] let pp_error ppf = function | `Eof -> Format.fprintf ppf "end of file" | `Checksum_mismatch -> Format.fprintf ppf "checksum mismatch" | `Corrupt_pax_header -> Format.fprintf ppf "corrupt PAX header" | `Zero_block -> Format.fprintf ppf "zero block" + | `Unmarshal e -> Format.fprintf ppf "unmarshal %s" e + +let ( let* ) = Result.bind (** Process and create tar file headers *) module Header = struct @@ -32,38 +35,41 @@ module Header = struct String.(trim (map (function '\000' -> ' ' | x -> x) s)) (** Unmarshal an integer field (stored as 0-padded octal) *) - let unmarshal_int (x: string) : int = + let unmarshal_int x = let tmp = "0o0" ^ (trim_numerical x) in try - int_of_string tmp + Ok (int_of_string tmp) with Failure msg -> - failwith (Printf.sprintf "%s: failed to parse integer %S" msg tmp) + Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg tmp)) (** Unmarshal an int64 field (stored as 0-padded octal) *) - let unmarshal_int64 (x: string) : int64 = + let unmarshal_int64 x = let tmp = "0o0" ^ (trim_numerical x) in - Int64.of_string tmp + try + Ok (Int64.of_string tmp) + with Failure msg -> + Error (`Unmarshal (Printf.sprintf "%s: failed to parse int64 %S" msg tmp)) (** Unmarshal a string *) - let unmarshal_string (x: string) : string = + let unmarshal_string x = try let first_0 = String.index x '\000' in - String.sub x 0 first_0 - with - Not_found -> x (* TODO should error *) + Ok (String.sub x 0 first_0) + with Not_found -> + Ok x (** Marshal an integer field of size 'n' *) - let marshal_int (x: int) (n: int) = + let marshal_int x n = let octal = Printf.sprintf "%0*o" (n - 1) x in octal ^ "\000" (* space or NULL allowed *) (** Marshal an int64 field of size 'n' *) - let marshal_int64 (x: int64) (n: int) = + let marshal_int64 x n = let octal = Printf.sprintf "%0*Lo" (n - 1) x in octal ^ "\000" (* space or NULL allowed *) (** Marshal an string field of size 'n' *) - let marshal_string (x: string) (n: int) = + let marshal_string x n = if String.length x < n then let bytes = Bytes.make n '\000' in Bytes.blit_string x 0 bytes 0 (String.length x); @@ -74,11 +80,14 @@ module Header = struct (** Unmarshal a pax Extended Header File time It can contain a ( '.' ) for sub-second granularity, that we ignore. https://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_05 *) - let unmarshal_pax_time (x:string) : int64 = - match String.split_on_char '.' x with - | [seconds] -> Int64.of_string seconds - | [seconds; _subseconds] -> Int64.of_string seconds - | _ -> raise (Failure "Wrong pax Extended Header File Times format") + let unmarshal_pax_time x = + try + match String.split_on_char '.' x with + | [seconds] -> Ok (Int64.of_string seconds) + | [seconds; _subseconds] -> Ok (Int64.of_string seconds) + | _ -> raise (Failure "Wrong pax Extended Header File time format (at most one . allowed)") + with Failure msg -> + Error (`Unmarshal (Printf.sprintf "Failed to parse pax time %S (%s)" x msg)) let hdr_file_name_off = 0 let sizeof_hdr_file_name = 100 @@ -387,7 +396,19 @@ module Header = struct user_id; uname } | None -> extended - let unmarshal ~(global: t option) (c: Cstruct.t) : t = + let decode_int x = + try + Ok (int_of_string x) + with Failure msg -> + Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg x)) + + let decode_int64 x = + try + Ok (Int64.of_string x) + with Failure msg -> + Error (`Unmarshal (Printf.sprintf "%s: failed to parse integer %S" msg x)) + + let unmarshal ~(global: t option) c = (* "%d %s=%s\n", , , with constraints that - the cannot contain an equals sign - the is the number of octets of the record, including \n @@ -398,48 +419,59 @@ module Header = struct then None else if Cstruct.get_char buffer i = char then Some i - else loop (i + 1) in - loop 0 in + else loop (i + 1) + in + loop 0 + in let rec loop remaining = if Cstruct.length remaining = 0 - then [] + then Ok [] else begin (* Find the space, then decode the length *) match find remaining ' ' with - | None -> failwith "Failed to decode pax extended header record" + | None -> Error (`Unmarshal "Failed to decode pax extended header record") | Some i -> let length = int_of_string @@ Cstruct.to_string @@ Cstruct.sub remaining 0 i in let record = Cstruct.sub remaining 0 length in let remaining = Cstruct.shift remaining length in begin match find record '=' with - | None -> failwith "Failed to decode pax extended header record" + | None -> Error (`Unmarshal "Failed to decode pax extended header record") | Some j -> let keyword = Cstruct.to_string @@ Cstruct.sub record (i + 1) (j - i - 1) in let v = Cstruct.to_string @@ Cstruct.sub record (j + 1) (Cstruct.length record - j - 2) in - (keyword, v) :: (loop remaining) + let* rem = loop remaining in + Ok ((keyword, v) :: rem) end - end in - let pairs = loop c in + end + in + let* pairs = loop c in let option name f = if List.mem_assoc name pairs - then Some (f (List.assoc name pairs)) - else None in + then + let* v = f (List.assoc name pairs) in + Ok (Some v) + else + Ok None + in (* integers are stored as decimal, not octal here *) - let access_time = option "atime" unmarshal_pax_time in - let charset = option "charset" unmarshal_string in - let comment = option "comment" unmarshal_string in - let group_id = option "gid" int_of_string in - let gname = option "group_name" unmarshal_string in - let header_charset = option "hdrcharset" unmarshal_string in - let link_path = option "linkpath" unmarshal_string in - let mod_time = option "mtime" unmarshal_pax_time in - let path = option "path" unmarshal_string in - let file_size = option "size" Int64.of_string in - let user_id = option "uid" int_of_string in - let uname = option "uname" unmarshal_string in - { access_time; charset; comment; group_id; gname; - header_charset; link_path; mod_time; path; file_size; - user_id; uname } |> merge global + let* access_time = option "atime" unmarshal_pax_time in + let* charset = option "charset" unmarshal_string in + let* comment = option "comment" unmarshal_string in + let* group_id = option "gid" decode_int in + let* gname = option "group_name" unmarshal_string in + let* header_charset = option "hdrcharset" unmarshal_string in + let* link_path = option "linkpath" unmarshal_string in + let* mod_time = option "mtime" unmarshal_pax_time in + let* path = option "path" unmarshal_string in + let* file_size = option "size" decode_int64 in + let* user_id = option "uid" decode_int in + let* uname = option "uname" unmarshal_string in + let g = + { access_time; charset; comment; group_id; gname; + header_charset; link_path; mod_time; path; file_size; + user_id; uname } + in + Ok (merge global g) end @@ -490,12 +522,6 @@ module Header = struct (** A blank header block (two of these in series mark the end of the tar) *) let zero_block = Cstruct.create length - (** [allzeroes buf] is true if [buf] contains only zero bytes *) - let allzeroes buf = - let rec loop i = - (i >= Cstruct.length buf) || (Cstruct.get_uint8 buf i = 0 && (loop (i + 1))) in - loop 0 - (** Pretty-print the header record *) let to_detailed_string (x: t) = let table = [ "file_name", x.file_name; @@ -530,47 +556,48 @@ module Header = struct (** Unmarshal a header block, returning None if it's all zeroes *) let unmarshal ?(extended = Extended.make ()) (c: Cstruct.t) : (t, [>`Zero_block | `Checksum_mismatch]) result = - if allzeroes c then Error `Zero_block + if Cstruct.length c <> length then Error (`Unmarshal "buffer is not of block size") + else if Cstruct.equal zero_block c then Error `Zero_block else - let chksum = get_hdr_chksum c in + let* chksum = get_hdr_chksum c in if checksum c <> chksum then Error `Checksum_mismatch - else let ustar = - let magic = get_hdr_magic c in + else let* ustar = + let* magic = get_hdr_magic c in (* GNU tar and Posix differ in interpretation of the character following ustar. For Posix, it should be '\0' but GNU tar uses ' ' *) - String.length magic >= 5 && (String.sub magic 0 5 = "ustar") in - let prefix = if ustar then get_hdr_prefix c else "" in - let file_name = match extended.Extended.path with - | Some path -> path + Ok (String.length magic >= 5 && (String.sub magic 0 5 = "ustar")) in + let* prefix = if ustar then get_hdr_prefix c else Ok "" in + let* file_name = match extended.Extended.path with + | Some path -> Ok path | None -> - let file_name = get_hdr_file_name c in - if file_name = "" then prefix - else if prefix = "" then file_name - else Filename.concat prefix file_name in - let file_mode = get_hdr_file_mode c in - let user_id = match extended.Extended.user_id with + let* file_name = get_hdr_file_name c in + if file_name = "" then Ok prefix + else if prefix = "" then Ok file_name + else Ok (Filename.concat prefix file_name) in + let* file_mode = get_hdr_file_mode c in + let* user_id = match extended.Extended.user_id with | None -> get_hdr_user_id c - | Some x -> x in - let group_id = match extended.Extended.group_id with + | Some x -> Ok x in + let* group_id = match extended.Extended.group_id with | None -> get_hdr_group_id c - | Some x -> x in - let file_size = match extended.Extended.file_size with + | Some x -> Ok x in + let* file_size = match extended.Extended.file_size with | None -> get_hdr_file_size c - | Some x -> x in - let mod_time = match extended.Extended.mod_time with + | Some x -> Ok x in + let* mod_time = match extended.Extended.mod_time with | None -> get_hdr_mod_time c - | Some x -> x in + | Some x -> Ok x in let link_indicator = Link.of_char (get_hdr_link_indicator c) in - let uname = match extended.Extended.uname with - | None -> if ustar then get_hdr_uname c else "" - | Some x -> x in - let gname = match extended.Extended.gname with - | None -> if ustar then get_hdr_gname c else "" - | Some x -> x in - let devmajor = if ustar then get_hdr_devmajor c else 0 in - let devminor = if ustar then get_hdr_devminor c else 0 in - - let link_name = match extended.Extended.link_path with - | Some link_path -> link_path + let* uname = match extended.Extended.uname with + | None -> if ustar then get_hdr_uname c else Ok "" + | Some x -> Ok x in + let* gname = match extended.Extended.gname with + | None -> if ustar then get_hdr_gname c else Ok "" + | Some x -> Ok x in + let* devmajor = if ustar then get_hdr_devmajor c else Ok 0 in + let* devminor = if ustar then get_hdr_devminor c else Ok 0 in + + let* link_name = match extended.Extended.link_path with + | Some link_path -> Ok link_path | None -> get_hdr_link_name c in Ok (make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name ~uname ~gname ~devmajor ~devminor file_name file_size) @@ -667,7 +694,7 @@ module type HEADERREADER = sig type in_channel type 'a io val read : global:Header.Extended.t option -> in_channel -> - (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io + (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result io end module type HEADERWRITER = sig @@ -684,7 +711,12 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = open Reader type in_channel = Reader.in_channel - type 'a io = 'a Async.t + type 'a io = 'a t + + let ( let* ) x f = + match x with + | Ok x -> f x + | Error y -> return (Error y) let fix_link_indicator x = (* For backward compatibility we treat normal files ending in slash as @@ -700,7 +732,7 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = else x - let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result t = + let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result t = (* We might need to read 2 headers at once if we encounter a Pax header *) let buffer = Cstruct.create Header.length in let real_header_buf = Cstruct.create Header.length in @@ -722,7 +754,7 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = >>= fun () -> (* unmarshal merges the previous global (if any) with the discovered global (if any) and returns the new global. *) - let global = Header.Extended.unmarshal ~global extra_header_buf in + let* global = Header.Extended.unmarshal ~global extra_header_buf in get_hdr ~next_longname ~next_longlink (Some global) () | Ok x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader -> let extra_header_buf = Cstruct.create (Int64.to_int x.Header.file_size) in @@ -730,12 +762,11 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = >>= fun () -> skip ifd (Header.compute_zero_padding_length x) >>= fun () -> - let extended = Header.Extended.unmarshal ~global extra_header_buf in + let* extended = Header.Extended.unmarshal ~global extra_header_buf in really_read ifd real_header_buf >>= fun () -> begin match Header.unmarshal ~extended real_header_buf with | Error _ -> - (* FIXME: Corrupt pax headers *) return (Error `Corrupt_pax_header) | Ok x -> let x = fix_link_indicator x in @@ -769,9 +800,9 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = >>= function | Ok x -> return (Ok (x, global)) | Error `Zero_block -> return (Error `Eof) - | Error `Checksum_mismatch as e -> return e + | Error (`Checksum_mismatch | `Unmarshal _) as e -> return e end - | Error `Checksum_mismatch as e -> + | Error (`Checksum_mismatch | `Unmarshal _) as e -> return e in diff --git a/lib/tar.mli b/lib/tar.mli index 306227d..500bed3 100644 --- a/lib/tar.mli +++ b/lib/tar.mli @@ -19,7 +19,7 @@ {e %%VERSION%% - {{:%%PKG_HOMEPAGE%% }homepage}} *) (** The type of errors that may occur. *) -type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block] +type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string] (** [pp_error ppf e] pretty prints the error [e] on the formatter [ppf]. *) val pp_error : Format.formatter -> [< error] -> unit @@ -82,7 +82,7 @@ module Header : sig (** Unmarshal a pax Extended Header block. This header block may be preceded by [global] blocks which will override some fields. *) - val unmarshal : global:t option -> Cstruct.t -> t + val unmarshal : global:t option -> Cstruct.t -> (t, [> error ]) result end (** Represents a standard archive (note checksum not stored). *) @@ -123,7 +123,7 @@ module Header : sig (** Unmarshal a header block, returning [None] if it's all zeroes. This header block may be preceded by an [?extended] block which will override some fields. *) - val unmarshal : ?extended:Extended.t -> Cstruct.t -> (t, [`Zero_block | `Checksum_mismatch]) result + val unmarshal : ?extended:Extended.t -> Cstruct.t -> (t, [`Zero_block | `Checksum_mismatch | `Unmarshal of string]) result (** Marshal a header block, computing and inserting the checksum. *) val marshal : ?level:compatibility -> Cstruct.t -> t -> unit @@ -168,7 +168,7 @@ module type HEADERREADER = sig @param global Holds the current global pax extended header, if any. Needs to be given to the next call to [read]. *) val read : global:Header.Extended.t option -> in_channel -> - (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io + (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ]) result io end module type HEADERWRITER = sig