use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine};
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use pin_project::pin_project;
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
/// An I/O stream wrapper that may still be waiting for the remote to confirm
/// a negotiated protocol.
///
/// A value is created either already negotiated (`Negotiated::completed`) or
/// still expecting a confirmation (`Negotiated::expecting`). Reads first drive
/// any outstanding negotiation to completion, while writes are forwarded to the
/// underlying stream in either state.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
    /// Current negotiation state; it owns the underlying stream.
    #[pin]
    state: State<TInner>
}
/// A future that resolves to a `Negotiated` stream once its protocol
/// negotiation has completed. Created by `Negotiated::complete`.
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
    /// The stream being driven. It is taken out when the future resolves
    /// successfully and put back while pending or after an error.
    inner: Option<Negotiated<TInner>>,
}
impl<TInner> Future for NegotiatedComplete<TInner>
where
    TInner: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<Negotiated<TInner>, NegotiationError>;

    /// Drives the protocol negotiation of the wrapped stream.
    ///
    /// Resolves to the `Negotiated` stream once negotiation succeeds. While the
    /// negotiation is pending, and also when it fails, the stream is stored
    /// back in `self.inner`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has resolved successfully.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut negotiated = self
            .inner
            .take()
            .expect("NegotiatedFuture called after completion.");

        let status = Negotiated::poll(Pin::new(&mut negotiated), cx);
        if let Poll::Ready(Ok(())) = status {
            return Poll::Ready(Ok(negotiated));
        }

        // Not (yet) successful: the stream stays owned by this future.
        self.inner = Some(negotiated);
        match status {
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            _ => Poll::Pending,
        }
    }
}
impl<TInner> Negotiated<TInner> {
    /// Creates a `Negotiated` that is already in the `Completed` state. It
    /// performs no further negotiation and reads and writes `io` directly.
    pub(crate) fn completed(io: TInner) -> Self {
        Negotiated { state: State::Completed { io } }
    }

    /// Creates a `Negotiated` in the `Expecting` state.
    ///
    /// Negotiation messages are read from `io` until the remote confirms
    /// `protocol`. If `header` is `Some`, a matching header message may be
    /// received (at most once) before the confirmation. It is not required.
    pub(crate) fn expecting(
        io: MessageReader<TInner>,
        protocol: Protocol,
        header: Option<HeaderLine>
    ) -> Self {
        Negotiated { state: State::Expecting { io, protocol, header } }
    }

    /// Advances the negotiation.
    ///
    /// The method first flushes pending output. It then reads negotiation
    /// messages until the expected protocol is confirmed, at which point the
    /// state becomes `Completed`. It returns `Ready(Ok(()))` as soon as the
    /// state is `Completed`.
    ///
    /// # Errors
    ///
    /// - A flush error, unless its kind is `WriteZero`.
    /// - An error yielded by the `MessageReader` stream, converted via
    ///   `From`. The source error type is presumably `ProtocolError`; confirm
    ///   against `MessageReader`.
    /// - End of stream before the confirmation, reported as
    ///   `ProtocolError::IoError(UnexpectedEof)`.
    /// - Any message that is neither the expected header nor the expected
    ///   protocol, reported as `NegotiationError::Failed`.
    ///
    /// After an error from the read loop, the state is left as `State::Invalid`
    /// and the message reader is dropped.
    ///
    /// # Panics
    ///
    /// Panics with "Negotiated: Invalid state" if the state is `Invalid`. This
    /// is the case on any call after a read-loop error; the panic comes from
    /// `poll_flush`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
    where
        TInner: AsyncRead + AsyncWrite + Unpin
    {
        // Flush any pending negotiation data first.
        match self.as_mut().poll_flush(cx) {
            Poll::Ready(Ok(())) => {},
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => {
                // `WriteZero` is tolerated so that reading can continue.
                // Presumably this lets data the remote already sent be read
                // even though writing is no longer possible; TODO confirm.
                if e.kind() != io::ErrorKind::WriteZero {
                    return Poll::Ready(Err(e.into()))
                }
            }
        }
        let mut this = self.project();
        if let StateProj::Completed { .. } = this.state.as_mut().project() {
             return Poll::Ready(Ok(()));
        }
        // Read outstanding negotiation messages. The state is moved out and
        // replaced with `Invalid`. It is restored explicitly on every path that
        // does not end the negotiation with an error.
        loop {
            match mem::replace(&mut *this.state, State::Invalid) {
                State::Expecting { mut io, header, protocol } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::Expecting { io, header, protocol };
                            return Poll::Pending
                        },
                        Poll::Ready(None) => {
                            return Poll::Ready(Err(ProtocolError::IoError(
                                io::ErrorKind::UnexpectedEof.into()).into()));
                        }
                    };
                    // Expected header: accept it once, then stop expecting it.
                    if let Message::Header(h) = &msg {
                        if Some(h) == header.as_ref() {
                            *this.state = State::Expecting { io, protocol, header: None };
                            continue
                        }
                    }
                    // Confirmation of our protocol: unwrap the raw stream.
                    if let Message::Protocol(p) = &msg {
                        if p.as_ref() == protocol.as_ref() {
                            log::debug!("Negotiated: Received confirmation for protocol: {}", p);
                            *this.state = State::Completed { io: io.into_inner() };
                            return Poll::Ready(Ok(()));
                        }
                    }
                    // Anything else fails the negotiation; the state stays `Invalid`.
                    return Poll::Ready(Err(NegotiationError::Failed));
                }
                // `Completed` was handled above, and `Invalid` already panicked in
                // `poll_flush`.
                _ => panic!("Negotiated: Invalid state")
            }
        }
    }

    /// Returns a future that resolves to this `Negotiated` once the protocol
    /// negotiation has completed successfully.
    pub fn complete(self) -> NegotiatedComplete<TInner> {
        NegotiatedComplete { inner: Some(self) }
    }
}
/// The negotiation state of a `Negotiated` stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// Waiting to read the remote's confirmation of `protocol`.
    Expecting {
        /// Reader of negotiation messages that wraps the underlying stream.
        #[pin]
        io: MessageReader<R>,
        /// A header line that may be received (at most once) before the
        /// protocol confirmation. It is set to `None` once it has been seen.
        header: Option<HeaderLine>,
        /// The protocol whose confirmation is awaited.
        protocol: Protocol,
    },

    /// Negotiation has finished. `io` is the underlying stream, unwrapped from
    /// its `MessageReader`.
    Completed { #[pin] io: R },

    /// Placeholder swapped in while moving between states. It is also left in
    /// place when negotiation fails, after which any use of the stream panics.
    Invalid,
}
impl<TInner> AsyncRead for Negotiated<TInner>
where
    TInner: AsyncRead + AsyncWrite + Unpin
{
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
        -> Poll<Result<usize, io::Error>>
    {
        loop {
            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
                
                return io.poll_read(cx, buf);
            }
            
            
            match self.as_mut().poll(cx) {
                Poll::Ready(Ok(())) => {},
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
            }
        }
    }
    
    
    fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>])
        -> Poll<Result<usize, io::Error>>
    {
        loop {
            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
                
                return io.poll_read_vectored(cx, bufs)
            }
            
            
            match self.as_mut().poll(cx) {
                Poll::Ready(Ok(())) => {},
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
            }
        }
    }
}
impl<TInner> AsyncWrite for Negotiated<TInner>
where
    TInner: AsyncWrite + AsyncRead + Unpin
{
    /// Writes `buf` without waiting for negotiation. While negotiation is
    /// outstanding, the write goes through the `MessageReader` wrapper.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
        -> Poll<Result<usize, io::Error>>
    {
        let state = self.project().state;
        match state.project() {
            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
            StateProj::Completed { io } => io.poll_write(cx, buf),
            StateProj::Invalid => panic!("Negotiated: Invalid state"),
        }
    }

    /// Vectored variant of `poll_write`, dispatched the same way.
    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
        -> Poll<Result<usize, io::Error>>
    {
        let state = self.project().state;
        match state.project() {
            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
            StateProj::Invalid => panic!("Negotiated: Invalid state"),
        }
    }

    /// Flushes whichever stream the current state holds.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let state = self.project().state;
        match state.project() {
            StateProj::Expecting { io, .. } => io.poll_flush(cx),
            StateProj::Completed { io } => io.poll_flush(cx),
            StateProj::Invalid => panic!("Negotiated: Invalid state"),
        }
    }

    /// Closes the stream.
    ///
    /// The negotiation is finished first, so expected messages are received.
    /// All output is then flushed. Unlike the flush inside the negotiation,
    /// this flush propagates every error, including `WriteZero`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        if let Err(err) = ready!(self.as_mut().poll(cx)) {
            return Poll::Ready(Err(err.into()));
        }
        ready!(self.as_mut().poll_flush(cx))?;

        // Continue with shutting down the underlying stream.
        let state = self.project().state;
        match state.project() {
            StateProj::Expecting { io, .. } => io.poll_close(cx),
            StateProj::Completed { io, .. } => io.poll_close(cx),
            StateProj::Invalid => panic!("Negotiated: Invalid state"),
        }
    }
}
/// Error that can occur while negotiating a protocol.
#[derive(Debug)]
pub enum NegotiationError {
    /// A protocol-level error. I/O errors are converted into this variant,
    /// being wrapped in a `ProtocolError` first.
    ProtocolError(ProtocolError),

    /// Protocol negotiation failed. Within `Negotiated`, this is produced when
    /// the remote sends a message other than the expected header or protocol
    /// confirmation.
    Failed,
}
impl From<ProtocolError> for NegotiationError {
    fn from(err: ProtocolError) -> NegotiationError {
        NegotiationError::ProtocolError(err)
    }
}
impl From<io::Error> for NegotiationError {
    /// Converts an I/O error by lifting it into a `ProtocolError` first and
    /// then wrapping it as `NegotiationError::ProtocolError`.
    fn from(err: io::Error) -> Self {
        let protocol_err = ProtocolError::from(err);
        Self::from(protocol_err)
    }
}
impl From<NegotiationError> for io::Error {
    fn from(err: NegotiationError) -> io::Error {
        if let NegotiationError::ProtocolError(e) = err {
            return e.into()
        }
        io::Error::new(io::ErrorKind::Other, err)
    }
}
impl Error for NegotiationError {
    /// Returns the wrapped `ProtocolError` for protocol errors, or `None` for a
    /// plain negotiation failure.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            NegotiationError::ProtocolError(err) => Some(err),
            // Listed explicitly rather than with `_`. Adding a new variant
            // (possibly one that wraps another error) then fails to compile
            // here instead of silently reporting no source.
            NegotiationError::Failed => None,
        }
    }
}
impl fmt::Display for NegotiationError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            NegotiationError::ProtocolError(p) =>
                fmt.write_fmt(format_args!("Protocol error: {}", p)),
            NegotiationError::Failed =>
                fmt.write_str("Protocol negotiation failed.")
        }
    }
}