diff --git a/CHANGELOG.md b/CHANGELOG.md index bc90e11..bf84317 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +### Changed + +- Remove `num-traits` dependency +- Update all dependencies +- Use `tokio` in examples + ## [0.3.6] - 2022-12-16 ### Changed diff --git a/Cargo.toml b/Cargo.toml index 775019a..7506397 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ repository = "https://github.com/oblique/async-tftp-rs" [dependencies] bytes = "1.5.0" log = "0.4.20" -nom = "7.1.3" thiserror = "1.0.48" async-executor = "1.5.1" diff --git a/src/error.rs b/src/error.rs index 03b31b9..af3220d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,9 +24,3 @@ pub enum Error { #[error("Max send retries reached (peer: {0}, block id: {1})")] MaxSendRetriesReached(std::net::SocketAddr, u16), } - -impl From>> for Error { - fn from(_error: nom::Err>) -> Error { - Error::InvalidPacket - } -} diff --git a/src/packet.rs b/src/packet.rs index 4cd0fc6..3d679da 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,4 +1,4 @@ -///! Packet definitions. +//! Packet definitions. use bytes::{BufMut, Bytes, BytesMut}; use std::convert::From; use std::io; diff --git a/src/parse.rs b/src/parse.rs index 0bd6f9f..3f14150 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,160 +1,192 @@ -use nom::branch::alt; -use nom::bytes::complete::{tag, tag_no_case, take_till}; -use nom::combinator::{map, map_opt, map_res, rest}; -use nom::multi::many0; -use nom::number::complete::be_u16; -use nom::sequence::tuple; -use nom::IResult; +use std::convert::TryInto; use std::str::{self, FromStr}; -use crate::error::Result; -use crate::packet::{self, *}; - -#[derive(Debug)] -enum Opt<'a> { - BlkSize(u16), - Timeout(u8), - Tsize(u64), - Invalid(&'a str, &'a str), -} +use crate::error::{Error, Result}; +use crate::packet::{ + Error as PacketError, Mode, Opts, Packet, PacketType, RwReq, +}; pub(crate) fn parse_packet(input: &[u8]) -> Result { - let (rest, packet) = match parse_packet_type(input)? { - (data, PacketType::Rrq) => parse_rrq(data)?, - (data, PacketType::Wrq) => parse_wrq(data)?, - (data, PacketType::Data) => parse_data(data)?, - (data, PacketType::Ack) => parse_ack(data)?, - (data, PacketType::Error) => parse_error(data)?, - (data, PacketType::OAck) => parse_oack(data)?, - }; - - if rest.is_empty() { - Ok(packet) - } else { - Err(crate::Error::InvalidPacket) - } -} - -fn nul_str(input: &[u8]) -> IResult<&[u8], &str> { - map_res( - tuple((take_till(|c| c == b'\0'), tag(b"\0"))), - |(s, _): (&[u8], _)| str::from_utf8(s), - )(input) + parse_packet_type(input) + .and_then(|(packet_type, data)| match packet_type { + PacketType::Rrq => parse_rrq(data), + PacketType::Wrq => parse_wrq(data), + PacketType::Data => parse_data(data), + PacketType::Ack => parse_ack(data), + PacketType::Error => parse_error(data), + PacketType::OAck => parse_oack(data), + }) + .ok_or(Error::InvalidPacket) } -fn parse_packet_type(input: &[u8]) -> IResult<&[u8], PacketType> { - map_opt(be_u16, PacketType::from_u16)(input) +fn parse_nul_str(input: &[u8]) -> Option<(&str, &[u8])> { + let pos = input.iter().position(|c| *c == b'\0')?; + let s = str::from_utf8(&input[..pos]).ok()?; + Some((s, &input[pos + 1..])) } -fn parse_mode(input: &[u8]) -> IResult<&[u8], Mode> { - alt(( - map(tag_no_case(b"netascii\0"), |_| Mode::Netascii), - map(tag_no_case(b"octet\0"), |_| Mode::Octet), - map(tag_no_case(b"mail\0"), |_| Mode::Mail), - ))(input) +fn parse_u16_be(input: &[u8]) -> Option<(u16, &[u8])> { + let bytes = input.get(..2)?; + let num = u16::from_be_bytes(bytes.try_into().ok()?); + Some((num, &input[2..])) } -fn parse_opt_blksize(input: &[u8]) -> IResult<&[u8], Opt> { - map_opt(tuple((tag_no_case(b"blksize\0"), nul_str)), |(_, n): (_, &str)| { - u16::from_str(n) - .ok() - .filter(|n| *n >= 8 && *n <= 65464) - .map(Opt::BlkSize) - })(input) +fn parse_packet_type(input: &[u8]) -> Option<(PacketType, &[u8])> { + let (num, rest) = parse_u16_be(input)?; + let val = PacketType::from_u16(num)?; + Some((val, rest)) } -fn parse_opt_timeout(input: &[u8]) -> IResult<&[u8], Opt> { - map_opt(tuple((tag_no_case(b"timeout\0"), nul_str)), |(_, n): (_, &str)| { - u8::from_str(n).ok().filter(|n| *n >= 1).map(Opt::Timeout) - })(input) -} +fn parse_mode(input: &[u8]) -> Option<(Mode, &[u8])> { + let (s, rest) = parse_nul_str(input)?; -fn parse_opt_tsize(input: &[u8]) -> IResult<&[u8], Opt> { - map_opt(tuple((tag_no_case(b"tsize\0"), nul_str)), |(_, n): (_, &str)| { - u64::from_str(n).ok().map(Opt::Tsize) - })(input) -} + let mode = if s.eq_ignore_ascii_case("netascii") { + Mode::Netascii + } else if s.eq_ignore_ascii_case("octet") { + Mode::Octet + } else if s.eq_ignore_ascii_case("mail") { + Mode::Mail + } else { + return None; + }; -pub(crate) fn parse_opts(input: &[u8]) -> IResult<&[u8], Opts> { - many0(alt(( - parse_opt_blksize, - parse_opt_timeout, - parse_opt_tsize, - map(tuple((nul_str, nul_str)), |(k, v)| Opt::Invalid(k, v)), - )))(input) - .map(|(i, opt_vec)| (i, to_opts(opt_vec))) + Some((mode, rest)) } -fn to_opts(opt_vec: Vec) -> Opts { +pub(crate) fn parse_opts(mut input: &[u8]) -> Option { let mut opts = Opts::default(); - for opt in opt_vec { - match opt { - Opt::BlkSize(size) => { - if opts.block_size.is_none() { - opts.block_size.replace(size); + while !input.is_empty() { + let (name, rest) = parse_nul_str(input)?; + let (val, rest) = parse_nul_str(rest)?; + + if name.eq_ignore_ascii_case("blksize") { + if let Ok(val) = u16::from_str(val) { + if val >= 8 && val <= 65464 { + opts.block_size = Some(val); } } - Opt::Timeout(timeout) => { - if opts.timeout.is_none() { - opts.timeout.replace(timeout); + } else if name.eq_ignore_ascii_case("timeout") { + if let Ok(val) = u8::from_str(val) { + if val >= 1 { + opts.timeout = Some(val); } } - Opt::Tsize(size) => { - if opts.transfer_size.is_none() { - opts.transfer_size.replace(size); - } + } else if name.eq_ignore_ascii_case("tsize") { + if let Ok(val) = u64::from_str(val) { + opts.transfer_size = Some(val); } - Opt::Invalid(..) => {} } + + input = rest; } - opts + Some(opts) +} + +fn parse_rrq(input: &[u8]) -> Option { + let (filename, rest) = parse_nul_str(input)?; + let (mode, rest) = parse_mode(rest)?; + let opts = parse_opts(rest)?; + + Some(Packet::Rrq(RwReq { + filename: filename.to_owned(), + mode, + opts, + })) } -fn parse_rrq(input: &[u8]) -> IResult<&[u8], Packet> { - let (input, (filename, mode, opts)) = - tuple((nul_str, parse_mode, parse_opts))(input)?; - - Ok(( - input, - Packet::Rrq(RwReq { - filename: filename.to_owned(), - mode, - opts, - }), - )) +fn parse_wrq(input: &[u8]) -> Option { + let (filename, rest) = parse_nul_str(input)?; + let (mode, rest) = parse_mode(rest)?; + let opts = parse_opts(rest)?; + + Some(Packet::Wrq(RwReq { + filename: filename.to_owned(), + mode, + opts, + })) } -fn parse_wrq(input: &[u8]) -> IResult<&[u8], Packet> { - let (input, (filename, mode, opts)) = - tuple((nul_str, parse_mode, parse_opts))(input)?; - - Ok(( - input, - Packet::Wrq(RwReq { - filename: filename.to_owned(), - mode, - opts, - }), - )) +fn parse_data(input: &[u8]) -> Option { + let (block_nr, rest) = parse_u16_be(input)?; + Some(Packet::Data(block_nr, rest)) } -fn parse_data(input: &[u8]) -> IResult<&[u8], Packet> { - tuple((be_u16, rest))(input) - .map(|(i, (block_nr, data))| (i, Packet::Data(block_nr, data))) +fn parse_ack(input: &[u8]) -> Option { + let (block_nr, rest) = parse_u16_be(input)?; + + if !rest.is_empty() { + return None; + } + + Some(Packet::Ack(block_nr)) } -fn parse_ack(input: &[u8]) -> IResult<&[u8], Packet> { - be_u16(input).map(|(i, block_nr)| (i, Packet::Ack(block_nr))) +fn parse_error(input: &[u8]) -> Option { + let (code, rest) = parse_u16_be(input)?; + let (msg, rest) = parse_nul_str(rest)?; + + if !rest.is_empty() { + return None; + } + + Some(Packet::Error(PacketError::from_code(code, Some(msg)))) } -fn parse_error(input: &[u8]) -> IResult<&[u8], Packet> { - tuple((be_u16, nul_str))(input).map(|(i, (code, msg))| { - (i, packet::Error::from_code(code, Some(msg)).into()) - }) +fn parse_oack(input: &[u8]) -> Option { + let opts = parse_opts(input)?; + Some(Packet::OAck(opts)) } -fn parse_oack(input: &[u8]) -> IResult<&[u8], Packet> { - parse_opts(input).map(|(i, opts)| (i, Packet::OAck(opts))) +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn nul_str() { + let (s, rest) = parse_nul_str(b"123\0").unwrap(); + assert_eq!(s, "123"); + assert!(rest.is_empty()); + + let (s, rest) = parse_nul_str(b"123\0\0").unwrap(); + assert_eq!(s, "123"); + assert_eq!(rest, b"\0"); + + let (s1, rest) = parse_nul_str(b"123\0abc\0\xff\xff").unwrap(); + let (s2, rest) = parse_nul_str(rest).unwrap(); + assert_eq!(s1, "123"); + assert_eq!(s2, "abc"); + assert_eq!(rest, b"\xff\xff"); + + let (s1, rest) = parse_nul_str(b"\0\0").unwrap(); + let (s2, rest) = parse_nul_str(rest).unwrap(); + assert_eq!(s1, ""); + assert_eq!(s2, ""); + assert!(rest.is_empty()); + + assert!(parse_nul_str(b"").is_none()); + assert!(parse_nul_str(b"123").is_none()); + assert!(parse_nul_str(b"123\xff\xff\0").is_none()); + } + + #[test] + fn u16_be() { + let (n, rest) = parse_u16_be(b"\x11\x22").unwrap(); + assert_eq!(n, 0x1122); + assert!(rest.is_empty()); + + let (n, rest) = parse_u16_be(b"\x11\x22\x33").unwrap(); + assert_eq!(n, 0x1122); + assert_eq!(rest, b"\x33"); + + let (n1, rest) = parse_u16_be(b"\x11\x22\x33\x44\x55").unwrap(); + let (n2, rest) = parse_u16_be(rest).unwrap(); + assert_eq!(n1, 0x1122); + assert_eq!(n2, 0x3344); + assert_eq!(rest, b"\x55"); + + assert!(parse_u16_be(b"").is_none()); + assert!(parse_u16_be(b"\x11").is_none()); + } } diff --git a/src/tests/packet.rs b/src/tests/packet.rs index 68e4153..07d7f0c 100644 --- a/src/tests/packet.rs +++ b/src/tests/packet.rs @@ -266,7 +266,7 @@ fn check_packet() { #[test] fn check_blksize_boundaries() { - let (_, opts) = parse_opts(b"blksize\07\0").unwrap(); + let opts = parse_opts(b"blksize\07\0").unwrap(); assert_eq!( opts, Opts { @@ -275,7 +275,7 @@ fn check_blksize_boundaries() { } ); - let (_, opts) = parse_opts(b"blksize\08\0").unwrap(); + let opts = parse_opts(b"blksize\08\0").unwrap(); assert_eq!( opts, Opts { @@ -284,7 +284,7 @@ fn check_blksize_boundaries() { } ); - let (_, opts) = parse_opts(b"blksize\065464\0").unwrap(); + let opts = parse_opts(b"blksize\065464\0").unwrap(); assert_eq!( opts, Opts { @@ -293,7 +293,7 @@ fn check_blksize_boundaries() { } ); - let (_, opts) = parse_opts(b"blksize\065465\0").unwrap(); + let opts = parse_opts(b"blksize\065465\0").unwrap(); assert_eq!( opts, Opts { @@ -305,7 +305,7 @@ fn check_blksize_boundaries() { #[test] fn check_timeout_boundaries() { - let (_, opts) = parse_opts(b"timeout\00\0").unwrap(); + let opts = parse_opts(b"timeout\00\0").unwrap(); assert_eq!( opts, Opts { @@ -314,7 +314,7 @@ fn check_timeout_boundaries() { } ); - let (_, opts) = parse_opts(b"timeout\01\0").unwrap(); + let opts = parse_opts(b"timeout\01\0").unwrap(); assert_eq!( opts, Opts { @@ -323,7 +323,7 @@ fn check_timeout_boundaries() { } ); - let (_, opts) = parse_opts(b"timeout\0255\0").unwrap(); + let opts = parse_opts(b"timeout\0255\0").unwrap(); assert_eq!( opts, Opts { @@ -332,7 +332,7 @@ fn check_timeout_boundaries() { } ); - let (_, opts) = parse_opts(b"timeout\0256\0").unwrap(); + let opts = parse_opts(b"timeout\0256\0").unwrap(); assert_eq!( opts, Opts {