use nom::bits::{bits, streaming::take}; use nom::combinator::map_res; use nom::error::{Error, ErrorKind}; use nom::multi::count; use nom::number::complete::{be_u16, be_u32, u8}; use nom::sequence::tuple; use nom::{Err, IResult}; use std::net::Ipv4Addr; const VRRP_REQUIRED_VERSION: u8 = 2; const VRRP_REQUIRED_TYPE: u8 = 1; // Advertisement #[derive(Debug, Clone, PartialEq)] pub enum VRRPv2Error { VRRPv2ParseError, } type NomError<'a> = nom::Err>; impl From> for VRRPv2Error { fn from(_: NomError) -> Self { Self::VRRPv2ParseError } } #[derive(Debug, PartialEq)] pub enum VRRPVersion { V2, } #[derive(Debug, PartialEq)] pub enum VRRPv2Type { VRRPv2Advertisement, } #[derive(Debug, PartialEq)] pub enum VRRPv2AuthType { VRRPv2AuthNoAuth = 0x00, VRRPv2AuthReserved1 = 0x01, VRRPv2AuthReserved2 = 0x02, } #[derive(Debug, PartialEq)] pub struct VRRPv2 { pub version: VRRPVersion, pub type_: VRRPv2Type, pub virtual_router_id: u8, pub priority: u8, pub count_ip_addrs: u8, pub auth_type: VRRPv2AuthType, pub advertisement_interval: u8, pub checksum: u16, pub ip_addrs: Vec, } fn two_nibbles(input: &[u8]) -> IResult<&[u8], (u8, u8)> { bits::<_, _, Error<(&[u8], usize)>, _, _>(tuple((take(4usize), take(4usize))))(input) } fn parse_version_type(input: &[u8]) -> IResult<&[u8], (VRRPVersion, VRRPv2Type)> { let (input, pair) = two_nibbles(input)?; match pair { (VRRP_REQUIRED_VERSION, VRRP_REQUIRED_TYPE) => { Ok((input, (VRRPVersion::V2, VRRPv2Type::VRRPv2Advertisement))) } _ => Err(Err::Error(Error::new(input, ErrorKind::Alt))), } } fn parse_auth_type(input: &[u8]) -> IResult<&[u8], VRRPv2AuthType> { map_res(u8, |auth_type| { Ok(match auth_type { 0 => VRRPv2AuthType::VRRPv2AuthNoAuth, 1 => VRRPv2AuthType::VRRPv2AuthReserved1, 2 => VRRPv2AuthType::VRRPv2AuthReserved2, _ => return Err(Err::Error(Error::new(input, ErrorKind::Alt))), }) })(input) } // Nightly has a nicer array_chunks API to express it more succinctly. // let mut chunks = bytes.array_chunks(2); // let mut sum = chunks.map(u16::from_ne_bytes).map(|b| b as u32).sum::(); // // handle the remainder // if let Some([b]) = chunks.remainder() { // sum += *b as u32 // } // Shadowing can be used to avoid `mut`... // let sum =...; // let sum = (sum & 0xffff) + (sum >> 16); // let sum = (sum & 0xffff) + (sum >> 16); // manually un-rolling while loop since it's needed atmost twice for an u32. fn validate_checksum(bytes: &[u8]) -> bool { let mut sum: u32 = bytes.chunks(2).fold(0, |acc: u32, x| { acc + u32::from(u16::from_ne_bytes(x.try_into().unwrap())) }); while (sum >> 16) > 0 { sum = (sum & 0xffff) + (sum >> 16); } let checksum = !(sum as u16); checksum == 0 } fn parse(input: &[u8]) -> IResult<&[u8], VRRPv2> { if !validate_checksum(input) { return Err(Err::Error(Error::new(input, ErrorKind::Alt))); } let (input, (version, type_)) = parse_version_type(input)?; let (input, virtual_router_id) = u8(input)?; let (input, priority) = u8(input)?; let (input, count_ip_addrs) = u8(input)?; let (input, auth_type) = parse_auth_type(input)?; let (input, advertisement_interval) = u8(input)?; let (input, checksum) = be_u16(input)?; let (input, xs) = count(be_u32, usize::from(count_ip_addrs))(input)?; let ip_addrs = xs.into_iter().map(Ipv4Addr::from).collect(); Ok(( input, VRRPv2 { version, type_, virtual_router_id, priority, count_ip_addrs, auth_type, advertisement_interval, checksum, ip_addrs, }, )) } pub fn from_bytes(bytes: &[u8]) -> Result { match parse(bytes) { Ok((_, v)) => Ok(v), Err(e) => Err(e.into()), } } #[test] fn test_standard_bytes() { let bytes = [ 0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let got = from_bytes(&bytes).unwrap(); let expected = VRRPv2 { version: VRRPVersion::V2, type_: VRRPv2Type::VRRPv2Advertisement, virtual_router_id: 1, priority: 100, count_ip_addrs: 1, auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth, checksum: 47698, advertisement_interval: 1, ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])], }; assert_eq!(got, expected); } #[test] fn test_incomplete_bytes() { let bytes = [0x21, 0x01]; let got = from_bytes(&bytes); assert_eq!(got.is_err(), true); assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError)); } #[test] fn test_invalid_version_type() { let bytes = [ 0x00, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let got = from_bytes(&bytes); assert_eq!(got.is_err(), true); assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError)); } #[test] fn test_invalid_auth_type() { let bytes = [ 0x21, 0x01, 0x64, 0x01, 0x03, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let got = from_bytes(&bytes); assert_eq!(got.is_err(), true); assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError)); } #[test] fn test_invalid_checksum() { let bytes = [ 0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xbb, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let got = from_bytes(&bytes); assert_eq!(got.is_err(), true); assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError)); }