use std::io::{self, Cursor, Read, Write};
use std::net::Ipv4Addr;

/// A VRRP version 2 advertisement packet (RFC 3768, section 5).
///
/// Packet format
///
///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |Version| Type  | Virtual Rtr ID|   Priority    | Count IP Addrs|
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |   Auth Type   |   Adver Int   |          Checksum             |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (1)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                            .                                  |
/// |                            .                                  |
/// |                            .                                  |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (n)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (1)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (2)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
///
/// The first octet packs the version (high nibble, always 2) and the packet
/// type (low nibble, always 1 for an advertisement). It is not stored in
/// this struct.
#[derive(Clone, Debug, PartialEq)]
pub struct VRRPv2 {
    /// Virtual Router Identifier (VRID).
    pub virtual_router_id: u8,
    /// Priority of the sending router for this virtual router.
    pub priority: u8,
    /// Number of entries in the "IP Address" list. It should equal
    /// `ip_addrs.len()`.
    pub count_ip_addrs: u8,
    /// Value of the "Auth Type" octet.
    pub auth_type: VRRPv2AuthType,
    /// Advertisement interval ("Adver Int"), in seconds per RFC 3768.
    pub advertisement_interval: u8,
    /// Checksum as read from the wire. Encoding recomputes it rather than
    /// using this value.
    pub checksum: u16,
    /// Virtual router IPv4 addresses, in packet order.
    pub ip_addrs: Vec<Ipv4Addr>,
}

/// Reasons a byte slice can be rejected as a VRRPv2 packet.
#[derive(Debug, PartialEq)]
pub enum VRRPv2Error {
    /// The "Auth Type" octet holds a value with no `VRRPv2AuthType` variant.
    InvalidAuthType,
    /// The Internet checksum over the packet does not verify.
    InvalidChecksum,
    /// The low nibble of the first octet is not the advertisement type (1).
    InvalidType,
    /// The high nibble of the first octet is not version 2.
    InvalidVersion,
    /// The input ended early (any underlying I/O error is mapped here).
    ParseError,
}

impl std::fmt::Display for VRRPv2Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match *self {
            VRRPv2Error::InvalidAuthType => "Invalid Auth Type",
            VRRPv2Error::InvalidChecksum => "Invalid Checksum",
            VRRPv2Error::InvalidType => "Invalid Type",
            VRRPv2Error::InvalidVersion => "Invalid Version",
            VRRPv2Error::ParseError => "Parse Error",
        };
        f.write_str(description)
    }
}

impl std::error::Error for VRRPv2Error {}

/// Collapses every I/O failure into `ParseError`. The readers only work on
/// in-memory buffers, so the original error carries no extra detail.
impl From<io::Error> for VRRPv2Error {
    fn from(_: io::Error) -> Self {
        Self::ParseError
    }
}

/// Authentication method carried in the "Auth Type" octet of a VRRPv2 packet.
///
/// RFC 3768 defines only "No Authentication" (0). Values 1 and 2 are
/// reserved; RFC 2338 had assigned them to simple-text and IP-AH
/// authentication. `#[repr(u8)]` makes the enum a single byte whose
/// discriminant is the wire value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum VRRPv2AuthType {
    /// No authentication (wire value 0).
    VRRPv2AuthNoAuth = 0x00,
    /// Reserved (wire value 1).
    VRRPv2AuthReserved1 = 0x01,
    /// Reserved (wire value 2).
    VRRPv2AuthReserved2 = 0x02,
}

/// Reads fixed-width unsigned integers in network (big-endian) byte order.
///
/// Each method advances the reader by the width of the integer. It fails
/// with the `read_exact` error if too few bytes remain.
trait BytesReader {
    /// Reads a single octet.
    fn read_u8(&mut self) -> io::Result<u8>;
    /// Reads two octets as a big-endian `u16`.
    fn read_u16(&mut self) -> io::Result<u16>;
    /// Reads four octets as a big-endian `u32`.
    fn read_u32(&mut self) -> io::Result<u32>;
}

impl<T: AsRef<[u8]>> BytesReader for Cursor<T> {
    fn read_u8(&mut self) -> io::Result<u8> {
        let mut octet = [0u8; 1];
        self.read_exact(&mut octet).map(|()| octet[0])
    }

    fn read_u16(&mut self) -> io::Result<u16> {
        let mut raw = [0u8; 2];
        self.read_exact(&mut raw).map(|()| u16::from_be_bytes(raw))
    }

    fn read_u32(&mut self) -> io::Result<u32> {
        let mut raw = [0u8; 4];
        self.read_exact(&mut raw).map(|()| u32::from_be_bytes(raw))
    }
}

impl VRRPv2 {
    pub fn to_bytes(&self) -> Result<Vec<u8>, std::io::Error> {
        let sz = (4 + self.count_ip_addrs) * 4;
        let bytes: Vec<u8> = Vec::with_capacity(sz.into());
        let mut wr = Cursor::new(bytes);
        wr.write_all(
            [
                0x21,
                self.virtual_router_id,
                self.priority,
                self.count_ip_addrs,
                self.auth_type as u8,
                self.advertisement_interval,
                0x0, // checksum bits
                0x0, // checksum bits
            ]
            .as_ref(),
        )?;
        for ip in self.ip_addrs.iter() {
            wr.write_all(&ip.to_bits().to_be_bytes())?;
        }
        wr.write_all([0; 8].as_ref())?; // Authentication Data
        let mut bytes = wr.into_inner();
        // Calculate the checksum and set the respective bits
        let sum = checksum(&bytes);
        [bytes[6], bytes[7]] = sum.to_be_bytes();
        Ok(bytes)
    }
}

fn parse(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
    let mut rdr = Cursor::new(bytes);
    match rdr.read_u8()? {
        i if (i & 0xF) != 1 => return Err(VRRPv2Error::InvalidType),
        i if (i >> 4) != 2 => return Err(VRRPv2Error::InvalidVersion),
        _ => {}
    };
    let virtual_router_id = rdr.read_u8()?;
    let priority = rdr.read_u8()?;
    let count_ip_addrs = rdr.read_u8()?;
    let auth_type = rdr.read_u8()?;
    let auth_type = match auth_type {
        0 => VRRPv2AuthType::VRRPv2AuthNoAuth,
        1 => VRRPv2AuthType::VRRPv2AuthReserved1,
        2 => VRRPv2AuthType::VRRPv2AuthReserved2,
        _ => return Err(VRRPv2Error::InvalidAuthType),
    };
    let advertisement_interval = rdr.read_u8()?;
    let checksum = rdr.read_u16()?;
    let mut ip_addrs = Vec::with_capacity(count_ip_addrs as usize);
    for _i in 0..count_ip_addrs {
        let b = rdr.read_u32()?;
        ip_addrs.push(Ipv4Addr::from(b));
    }
    Ok(VRRPv2 {
        virtual_router_id,
        priority,
        count_ip_addrs,
        auth_type,
        advertisement_interval,
        checksum,
        ip_addrs,
    })
}

/// Decodes a VRRPv2 advertisement from `bytes` and verifies its checksum.
///
/// All fields are decoded before the checksum over the whole slice is
/// checked. A malformed field is therefore reported in preference to a bad
/// checksum.
///
/// # Errors
///
/// * [`VRRPv2Error::ParseError`] if `bytes` is too short for the packet.
/// * [`VRRPv2Error::InvalidVersion`], [`VRRPv2Error::InvalidType`] or
///   [`VRRPv2Error::InvalidAuthType`] if that header field holds an
///   unsupported value.
/// * [`VRRPv2Error::InvalidChecksum`] if the Internet checksum over `bytes`
///   does not verify.
///
/// # Examples
///
/// ```
/// use vrrpd::vrrpv2::VRRPv2;
/// use vrrpd::vrrpv2::VRRPv2AuthType;
/// use vrrpd::vrrpv2::from_bytes;
/// use std::net::Ipv4Addr;
///
/// let bytes = [
///    0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01,
///    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
///    0x00, 0x00,
/// ];
/// let expected = VRRPv2 {
///     virtual_router_id: 1,
///     priority: 100,
///     count_ip_addrs: 1,
///     auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
///     checksum: 47698,
///     advertisement_interval: 1,
///     ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
/// };
/// assert_eq!(from_bytes(&bytes), Ok(expected));
/// ```
pub fn from_bytes(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
    let packet = parse(bytes)?;
    // A packet carrying a correct checksum sums (with the checksum field
    // included) to zero.
    if checksum(bytes) != 0 {
        return Err(VRRPv2Error::InvalidChecksum);
    }
    Ok(packet)
}

/// Computes the Internet checksum (RFC 1071) of `bytes`.
///
/// The data is summed as big-endian 16-bit words using one's complement
/// arithmetic, and the complement of the sum is returned. Per RFC 1071, an
/// odd trailing byte is treated as the high-order byte of a final word
/// padded with a zero byte. A buffer whose embedded checksum is correct
/// therefore checksums to `0`.
fn checksum(bytes: &[u8]) -> u16 {
    let mut words = bytes.chunks_exact(2);
    // A u64 accumulator cannot realistically overflow (that would take more
    // than 2^48 words), so carries are folded once at the end.
    let mut sum: u64 = words
        .by_ref()
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = words.remainder() {
        // Zero padding goes *after* the byte, making it the high-order octet.
        sum += u64::from(*last) << 8;
    }
    // End-around carry: fold overflow back into the low 16 bits.
    while sum > 0xffff {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    !(sum as u16)
}

#[test]
fn test_incomplete_bytes() {
    // Only the first two octets of the fixed header are present.
    let truncated = [0x21, 0x01];
    let outcome = from_bytes(&truncated);
    assert_eq!(outcome, Err(VRRPv2Error::ParseError));
}

#[test]
fn test_invalid_version() {
    // First octet 0x31: version 3, type 1.
    let packet = [
        0x31, 0x01, 0x2a, 0x00, 0x00, 0x01, 0xb5, 0xfd,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidVersion));
}

#[test]
fn test_invalid_type() {
    // First octet 0x20: version 2, type 0.
    let packet = [
        0x20, 0x2a, 0x64, 0x01, 0x00, 0x01, 0xaa, 0x29, // header
        0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // auth data
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidType));
}

#[test]
fn test_invalid_auth_type() {
    // Auth Type octet (index 4) is 3, which has no variant.
    let mut packet = vec![
        0x21, 0x01, 0x64, 0x01, 0x03, 0x01, 0xba, 0x52, // header
        0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1
    ];
    packet.resize(26, 0x00);
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidAuthType));
}

#[test]
fn test_invalid_checksum() {
    // Checksum high byte is 0xbb instead of the correct 0xba.
    let mut packet = vec![
        0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xbb, 0x52, // header
        0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1
    ];
    packet.resize(26, 0x00);
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidChecksum));
}

#[test]
fn test_checksum() {
    let words: [u8; 8] = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
    let expected: u16 = 0x220d;
    assert_eq!(checksum(&words), expected);
}

#[test]
fn test_checksum_singlebyte() {
    // A lone zero octet sums to zero, whose complement is all ones.
    let expected: u16 = 0xffff;
    assert_eq!(checksum(&[0x00]), expected);
}

#[test]
fn test_checksum_twobytes() {
    let expected: u16 = 0xff00;
    assert_eq!(checksum(&[0x00, 0xff]), expected);
}

#[test]
fn test_checksum_another() {
    let words: [u8; 8] = [0xe3, 0x4f, 0x23, 0x96, 0x44, 0x27, 0x99, 0xf3];
    let expected: u16 = 0x1aff;
    assert_eq!(checksum(&words), expected);
}

#[test]
fn test_to_bytes() {
    // Round trip: a parsed packet must re-encode to exactly the same bytes.
    let wire: [u8; 20] = [
        0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, // header
        0xc0, 0xa8, 0x00, 0x01, // 192.168.0.1
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // auth data
    ];
    let packet = from_bytes(&wire).expect("parsing failed");
    let encoded = packet.to_bytes().expect("conversion failed");
    assert_eq!(&wire[..], encoded.as_slice());
}