Skip to content

Commit

Permalink
Remove Checksum_mismatch exception, use result
Browse files Browse the repository at this point in the history
  • Loading branch information
robur-team committed Jan 3, 2024
1 parent cf20a20 commit 81943c3
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 44 deletions.
6 changes: 5 additions & 1 deletion bin/otar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ let list filename =
let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in
Tar_gz.skip ic to_skip ;
go global ()
| Error `Eof -> () in
| Error `Eof -> ()
| Error e ->
Format.eprintf "Error listing archive: %a\n%!" Tar.pp_error e;
exit 2
in
go None ()

let () = match Sys.argv with
Expand Down
53 changes: 30 additions & 23 deletions lib/tar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

(** Errors that can be reported while reading tar headers.
    - [`Eof]: the header reader reached the end of the archive (for instance
      after reading two consecutive all-zero blocks).
    - [`Checksum_mismatch]: the checksum stored in a header block differs from
      the one computed over that block.
    - [`Corrupt_pax_header]: the header block following a per-file PAX
      extended header could not be unmarshalled.
    - [`Zero_block]: the header block consists only of zero bytes. *)
type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block]

(** [pp_error ppf e] writes a short, human-readable description of the error
    [e] on the formatter [ppf]. No trailing newline is emitted. *)
let pp_error ppf err =
  (* Pick the message first, then print it once. *)
  let description =
    match err with
    | `Eof -> "end of file"
    | `Checksum_mismatch -> "checksum mismatch"
    | `Corrupt_pax_header -> "corrupt PAX header"
    | `Zero_block -> "zero block"
  in
  Format.pp_print_string ppf description

(** Process and create tar file headers *)
module Header = struct
(** Map of field name -> (start offset, length) taken from wikipedia:
Expand Down Expand Up @@ -500,9 +508,6 @@ module Header = struct
"link_name", x.link_name ] in
"{\n\t" ^ (String.concat "\n\t" (List.map (fun (k, v) -> k ^ ": " ^ v) table)) ^ "}"

(** Thrown when unmarshalling a header if the checksums don't match *)
exception Checksum_mismatch

(** From an already-marshalled block, compute what the checksum should be *)
let checksum (x: Cstruct.t) : int64 =
(* Sum of all the byte values of the header with the checksum field taken
Expand All @@ -523,11 +528,12 @@ module Header = struct
Int64.of_int !result

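(** Unmarshal a header block, returning [Error `Zero_block] if it's all zeroes and [Error `Checksum_mismatch] if the stored checksum doesn't match the computed one *)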
let unmarshal ?(extended = Extended.make ()) (c: Cstruct.t) : t option =
if allzeroes c then None
let unmarshal ?(extended = Extended.make ()) (c: Cstruct.t)
: (t, [>`Zero_block | `Checksum_mismatch]) result =
if allzeroes c then Error `Zero_block
else
let chksum = get_hdr_chksum c in
if checksum c <> chksum then raise Checksum_mismatch
if checksum c <> chksum then Error `Checksum_mismatch
else let ustar =
let magic = get_hdr_magic c in
(* GNU tar and Posix differ in interpretation of the character following ustar. For Posix, it should be '\0' but GNU tar uses ' ' *)
Expand Down Expand Up @@ -566,7 +572,7 @@ module Header = struct
let link_name = match extended.Extended.link_path with
| Some link_path -> link_path
| None -> get_hdr_link_name c in
Some (make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator
Ok (make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator
~link_name ~uname ~gname ~devmajor ~devminor file_name file_size)

(** Marshal a header block, computing and inserting the checksum *)
Expand Down Expand Up @@ -664,7 +670,7 @@ module type HEADERREADER = sig
type in_channel
type 'a io
val read : global:Header.Extended.t option -> in_channel ->
(Header.t * Header.Extended.t option, [ `Eof ]) result io
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io
end

module type HEADERWRITER = sig
Expand Down Expand Up @@ -697,23 +703,21 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
else
x

let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof ]) result t =
let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result t =
(* We might need to read 2 headers at once if we encounter a Pax header *)
let buffer = Cstruct.create Header.length in
let real_header_buf = Cstruct.create Header.length in

let next_block global () =
really_read ifd buffer
>>= fun () ->
match Header.unmarshal ?extended:global buffer with
| None -> return None
| Some hdr -> return (Some hdr)
return (Header.unmarshal ?extended:global buffer)
in

let rec get_hdr ~next_longname ~next_longlink global () : (Header.t * Header.Extended.t option, [> `Eof ]) result t =
let rec get_hdr ~next_longname ~next_longlink global () : (Header.t * Header.Extended.t option, [> `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result t =
next_block global ()
>>= function
| Some x when x.Header.link_indicator = Header.Link.GlobalExtendedHeader ->
| Ok x when x.Header.link_indicator = Header.Link.GlobalExtendedHeader ->
let extra_header_buf = Cstruct.create (Int64.to_int x.Header.file_size) in
really_read ifd extra_header_buf
>>= fun () ->
Expand All @@ -723,7 +727,7 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
discovered global (if any) and returns the new global. *)
let global = Header.Extended.unmarshal ~global extra_header_buf in
get_hdr ~next_longname ~next_longlink (Some global) ()
| Some x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader ->
| Ok x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader ->
let extra_header_buf = Cstruct.create (Int64.to_int x.Header.file_size) in
really_read ifd extra_header_buf
>>= fun () ->
Expand All @@ -733,14 +737,14 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
really_read ifd real_header_buf
>>= fun () ->
begin match Header.unmarshal ~extended real_header_buf with
| None ->
| Error _ ->
(* FIXME: Corrupt pax headers *)
return (Error `Eof)
| Some x ->
return (Error `Corrupt_pax_header)
| Ok x ->
let x = fix_link_indicator x in
return (Ok (x, global))
end
| Some ({ Header.link_indicator = Header.Link.LongLink | Header.Link.LongName; _ } as x) when x.Header.file_name = longlink ->
| Ok ({ Header.link_indicator = Header.Link.LongLink | Header.Link.LongName; _ } as x) when x.Header.file_name = longlink ->
let extra_header_buf = Cstruct.create (Int64.to_int x.Header.file_size) in
really_read ifd extra_header_buf
>>= fun () ->
Expand All @@ -750,7 +754,7 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
let next_longlink = if x.Header.link_indicator = Header.Link.LongLink then Some name else next_longlink in
let next_longname = if x.Header.link_indicator = Header.Link.LongName then Some name else next_longname in
get_hdr ~next_longname ~next_longlink global ()
| Some x ->
| Ok x ->
(* XXX: unclear how/if pax headers should interact with gnu extensions *)
let x = match next_longname with
| None -> x
Expand All @@ -762,13 +766,16 @@ module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) =
in
let x = fix_link_indicator x in
return (Ok (x, global))
| None ->
| Error `Zero_block ->
begin
next_block global ()
>>= function
| Some x -> return (Ok (x, global))
| None -> return (Error `Eof)
| Ok x -> return (Ok (x, global))
| Error `Zero_block -> return (Error `Eof)
| Error `Checksum_mismatch as e -> return e
end
| Error `Checksum_mismatch as e ->
return e
in

get_hdr ~next_longname:None ~next_longlink:None global ()
Expand Down
13 changes: 8 additions & 5 deletions lib/tar.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
{e %%VERSION%% - {{:%%PKG_HOMEPAGE%% }homepage}} *)

(** The type of errors that may occur.
    - [`Eof]: the header reader reached the end of the archive (for instance
      after reading two consecutive all-zero blocks).
    - [`Checksum_mismatch]: the checksum stored in a header block differs from
      the one computed over that block.
    - [`Corrupt_pax_header]: the header block following a per-file PAX
      extended header could not be unmarshalled.
    - [`Zero_block]: the header block consists only of zero bytes. *)
type error = [`Eof | `Checksum_mismatch | `Corrupt_pax_header | `Zero_block]

(** [pp_error ppf e] pretty prints the error [e] on the formatter [ppf] as a
    short human-readable phrase (e.g. ["checksum mismatch"]), without a
    trailing newline. Accepts any subset of {!error}. *)
val pp_error : Format.formatter -> [< error] -> unit

module Header : sig
(** Process and create tar file headers. *)

Expand Down Expand Up @@ -114,13 +120,10 @@ module Header : sig
(** Pretty-print the header record. *)
val to_detailed_string : t -> string

(** Thrown when unmarshalling a header if the checksums don't match. *)
exception Checksum_mismatch

(** Unmarshal a header block, returning [Error `Zero_block] if it's all
zeroes and [Error `Checksum_mismatch] if the stored checksum does not match
the computed one.
This header block may be preceded by an [?extended] block which
will override some fields. *)
val unmarshal : ?extended:Extended.t -> Cstruct.t -> t option
val unmarshal : ?extended:Extended.t -> Cstruct.t -> (t, [`Zero_block | `Checksum_mismatch]) result

(** Marshal a header block, computing and inserting the checksum. *)
val marshal : ?level:compatibility -> Cstruct.t -> t -> unit
Expand Down Expand Up @@ -165,7 +168,7 @@ module type HEADERREADER = sig
@param global Holds the current global pax extended header, if
any. Needs to be given to the next call to [read]. *)
val read : global:Header.Extended.t option -> in_channel ->
(Header.t * Header.Extended.t option, [ `Eof ]) result io
(Header.t * Header.Extended.t option, [ `Eof | `Checksum_mismatch | `Corrupt_pax_header ]) result io
end

module type HEADERWRITER = sig
Expand Down
12 changes: 6 additions & 6 deletions lib_test/global_extended_headers_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,34 @@ let use_global_extended_headers _test_ctxt =
Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id;
let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in
Reader.skip cin to_skip;
| Error `Eof -> failwith "Couldn't read header" );
| Error _ -> failwith "Couldn't read header" );
( match HR.read ~global:!global cin with
| Ok (hdr, global') ->
Alcotest.check header "expected global header" (Some g0) global';
global := global';
Alcotest.(check int) "expected user" 2000 hdr.Tar.Header.user_id;
let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in
Reader.skip cin to_skip;
| Error `Eof -> failwith "Couldn't read header" );
| Error _ -> failwith "Couldn't read header" );
( match HR.read ~global:!global cin with
| Ok (hdr, global') ->
Alcotest.check header "expected global header" (Some g0) global';
global := global';
Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id;
let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in
Reader.skip cin to_skip;
| Error `Eof -> failwith "Couldn't read header" );
| Error _ -> failwith "Couldn't read header" );
( match HR.read ~global:!global cin with
| Ok (hdr, global') ->
Alcotest.check header "expected global header" (Some g1) global';
global := global';
Alcotest.(check int) "expected user" 3000 hdr.Tar.Header.user_id;
let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in
Reader.skip cin to_skip;
| Error `Eof -> failwith "Couldn't read header" );
| Error _ -> failwith "Couldn't read header" );
( match HR.read ~global:!global cin with
| Ok _ -> failwith "Should have found EOF"
| Error `Eof -> () );
| Error `Eof -> ()
| _ -> failwith "Should have found EOF");
()

let () =
Expand Down
17 changes: 8 additions & 9 deletions lib_test/parse_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ let list fd =
loop global (hdr :: acc)
| Error `Eof ->
List.rev acc
| Error e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e
in
let r = loop None [] in
List.iter (fun h -> print_endline h.Tar.Header.file_name) r;
r

let cstruct = Alcotest.testable
(fun f x -> Fmt.pf f "%a" Cstruct.hexdump_pp x)
Cstruct.equal
let cstruct = Alcotest.testable Cstruct.hexdump_pp Cstruct.equal
let pp_header f x = Fmt.pf f "%s" (Tar.Header.to_detailed_string x)
let header =
Alcotest.testable (fun f x -> Fmt.pf f "%a" (Fmt.option pp_header) x) ( = )
let header = Alcotest.testable pp_header ( = )

let error = Alcotest.testable Tar.pp_error ( = )

let link = Alcotest.testable (Fmt.of_to_string Tar.Header.Link.to_string) ( = )

Expand All @@ -65,7 +65,7 @@ let header () =
for i = 0 to Tar.Header.length - 1 do Cstruct.set_uint8 c' i 0 done;
Tar.Header.marshal c' h;
Alcotest.(check cstruct) "marshalled headers" c c';
Alcotest.(check header) "unmarshalled headers" (Some h) (Tar.Header.unmarshal c');
Alcotest.(check (result header error)) "unmarshalled headers" (Ok h) (Tar.Header.unmarshal c');
Alcotest.(check int) "zero padding length" 302 (Tar.Header.compute_zero_padding_length h)

let set_difference a b = List.filter (fun a -> not(List.mem a b)) a
Expand Down Expand Up @@ -180,7 +180,7 @@ let can_list_pax_implicit_dir () =
Fun.protect ~finally:(fun () -> Unix.close fd)
(fun () ->
match Tar_unix.HeaderReader.read ~global:None fd with
| Error `Eof -> Alcotest.fail "unexpected end of file"
| Error e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e
| Ok (hdr, _global) ->
Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.link_indicator;
Alcotest.(check string) "filename is patched" "clearly/a/directory/" hdr.file_name)
Expand All @@ -204,8 +204,7 @@ let can_list_longlink_implicit_dir () =
| Ok (hdr, _global) ->
Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.link_indicator;
Alcotest.(check string) "filename is patched" "some/long/name/for/a/directory/" hdr.file_name
| Error `Eof ->
Alcotest.fail "reached end of file")
| Error e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e)


let starts_with ~prefix s =
Expand Down
2 changes: 2 additions & 0 deletions mirage/tar_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ module Make_KV_RO (BLOCK : Mirage_block.S) = struct
let rec loop ~global map =
HR.read ~global in_channel >>= function
| Error `Eof -> Lwt.return map
| Error e ->
Format.kasprintf failwith "Error reading archive: %a" Tar.pp_error e
| Ok (tar, global) ->
let filename = trim_slash tar.Tar.Header.file_name in
let map =
Expand Down

0 comments on commit 81943c3

Please sign in to comment.