use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
use futures::{prelude::*, io::IoSlice};
use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
/// Maximum number of bytes the unsigned-varint length prefix of a frame may occupy.
const MAX_LEN_BYTES: u16 = 2;

/// Largest frame payload, in bytes, whose length still fits into a `MAX_LEN_BYTES`-byte
/// unsigned varint.
///
/// Every varint byte carries 7 value bits (the high bit is the continuation flag), so the
/// limit is `2^(7 * MAX_LEN_BYTES) - 1`, i.e. 16383 for a two-byte prefix.
const MAX_FRAME_SIZE: u16 = (1 << (7 * MAX_LEN_BYTES)) - 1;

/// Initial capacity of the read and write buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
/// Frames an underlying I/O resource `R` into unsigned-varint length-delimited messages.
///
/// Reading (`Stream`, when `R: AsyncRead`) yields one `Bytes` per frame; writing
/// (`Sink<Bytes>`, when `R: AsyncWrite`) buffers each item behind its varint length prefix.
/// The prefix may use at most `MAX_LEN_BYTES` bytes, which limits payloads to
/// `MAX_FRAME_SIZE` bytes.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimited<R> {
    /// The underlying I/O resource.
    #[pin]
    inner: R,
    /// Payload of the frame currently being read; resized to the announced frame length.
    read_buffer: BytesMut,
    /// Encoded frames (length prefix followed by payload) not yet written to `inner`.
    write_buffer: BytesMut,
    /// Progress of the frame currently being read.
    read_state: ReadState,
}
/// Progress of reading a single frame in `LengthDelimited`'s `Stream` implementation.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
    /// Reading the varint length prefix one byte at a time.
    /// `buf` holds the prefix bytes read so far and `pos` counts them.
    ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
    /// Reading a payload of `len` bytes into the read buffer.
    /// `pos` counts the payload bytes read so far.
    ReadData { len: u16, pos: usize },
}
impl Default for ReadState {
fn default() -> Self {
ReadState::ReadLength {
buf: [0; MAX_LEN_BYTES as usize],
pos: 0
}
}
}
impl<R> LengthDelimited<R> {
    /// Wraps `inner` with empty read and write buffers, positioned at a frame boundary.
    pub fn new(inner: R) -> LengthDelimited<R> {
        let read_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
        // Room for one default-sized payload plus a maximal length prefix.
        let write_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize);
        LengthDelimited {
            inner,
            read_buffer,
            write_buffer,
            read_state: ReadState::default(),
        }
    }

    /// Consumes the framing wrapper and returns the underlying I/O resource.
    ///
    /// # Panics
    ///
    /// Panics if the read buffer is not empty (a frame payload is partially read) or the
    /// write buffer is not empty (encoded frames have not been written out yet).
    pub fn into_inner(self) -> R {
        assert!(self.read_buffer.is_empty());
        assert!(self.write_buffer.is_empty());
        self.inner
    }

    /// Converts this into a `LengthDelimitedReader`, which reads frames as a `Stream` but
    /// forwards writes to the underlying resource unframed (after writing out any frames
    /// still buffered here).
    pub fn into_reader(self) -> LengthDelimitedReader<R> {
        LengthDelimitedReader { inner: self }
    }

    /// Writes the contents of the write buffer to `inner` until the buffer is empty.
    ///
    /// Returns `Pending` if `inner` is not ready, propagates write errors, and fails with
    /// `WriteZero` if `inner` accepts zero bytes while data remains.
    pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>)
        -> Poll<Result<(), io::Error>>
    where
        R: AsyncWrite
    {
        let mut me = self.project();
        loop {
            if me.write_buffer.is_empty() {
                return Poll::Ready(Ok(()));
            }
            let written = match me.inner.as_mut().poll_write(cx, &me.write_buffer[..]) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                // A zero-length write with data pending would otherwise loop forever.
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "Failed to write buffered frame.")))
                }
                Poll::Ready(Ok(n)) => n,
            };
            me.write_buffer.advance(written);
        }
    }
}
impl<R> Stream for LengthDelimited<R>
where
    R: AsyncRead
{
    type Item = Result<Bytes, io::Error>;

    /// Polls for the next length-delimited frame.
    ///
    /// A frame is an unsigned-varint length prefix of at most `MAX_LEN_BYTES` bytes followed
    /// by that many payload bytes. The prefix is read one byte at a time, so no byte beyond
    /// the prefix is consumed from `inner` before the payload length is known.
    ///
    /// Yields:
    /// - `None` when `inner` reaches EOF exactly at a frame boundary;
    /// - `Some(Ok(frame))` for every complete frame, including empty ones;
    /// - `Some(Err(e))` for read errors from `inner`;
    /// - `Some(Err(e))` with `UnexpectedEof` when EOF occurs inside a prefix or a payload;
    /// - `Some(Err(e))` with `InvalidData` when the prefix is malformed or would need more
    ///   than `MAX_LEN_BYTES` bytes.
    ///
    /// The `InvalidData` case is sticky. Polling again yields `InvalidData` again instead of
    /// resynchronising on an arbitrary byte or indexing past the prefix buffer.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            match this.read_state {
                ReadState::ReadLength { buf, pos } => {
                    // `pos` only reaches `MAX_LEN_BYTES` once the prefix has been rejected
                    // below. Without this guard, `buf[*pos .. *pos + 1]` would be out of
                    // bounds and panic if the caller polls again after that error.
                    if *pos >= MAX_LEN_BYTES as usize {
                        return Poll::Ready(Some(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            "invalid length prefix"))));
                    }

                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
                        Poll::Ready(Ok(0)) => {
                            // EOF before any prefix byte is a clean end of stream;
                            // EOF in the middle of a prefix is not.
                            if *pos == 0 {
                                return Poll::Ready(None);
                            } else {
                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
                            }
                        }
                        Poll::Ready(Ok(n)) => {
                            debug_assert_eq!(n, 1);
                            *pos += n;
                        }
                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                        Poll::Pending => return Poll::Pending,
                    };

                    if (buf[*pos - 1] & 0x80) == 0 {
                        // The continuation bit is clear, so this was the last prefix byte.
                        // Decode only the bytes actually read.
                        let len = match unsigned_varint::decode::u16(&buf[.. *pos]) {
                            Ok((len, _)) => len,
                            Err(e) => {
                                log::debug!("invalid length prefix: {}", e);
                                // Make the failure sticky (see the guard above).
                                *pos = MAX_LEN_BYTES as usize;
                                return Poll::Ready(Some(Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "invalid length prefix"))));
                            }
                        };

                        if len >= 1 {
                            *this.read_state = ReadState::ReadData { len, pos: 0 };
                            this.read_buffer.resize(len as usize, 0);
                        } else {
                            // Empty frame: there is no payload to read, so yield it now.
                            debug_assert_eq!(len, 0);
                            *this.read_state = ReadState::default();
                            return Poll::Ready(Some(Ok(Bytes::new())));
                        }
                    } else if *pos == MAX_LEN_BYTES as usize {
                        // The continuation bit is still set, but the prefix already has its
                        // maximum size, so the frame would exceed `MAX_FRAME_SIZE`.
                        return Poll::Ready(Some(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            "Maximum frame length exceeded"))));
                    }
                }
                ReadState::ReadData { len, pos } => {
                    match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
                        Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
                        Poll::Ready(Ok(n)) => *pos += n,
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                    };

                    if *pos == *len as usize {
                        // Hand the whole payload out and leave `read_buffer` empty
                        // for the next frame.
                        let frame = this.read_buffer.split_off(0).freeze();
                        *this.read_state = ReadState::default();
                        return Poll::Ready(Some(Ok(frame)));
                    }
                }
            }
        }
    }
}
impl<R> Sink<Bytes> for LengthDelimited<R>
where
    R: AsyncWrite,
{
    type Error = io::Error;

    /// Applies back-pressure: once at least `MAX_FRAME_SIZE` bytes are buffered, the whole
    /// write buffer must be written to `inner` before another frame is accepted.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let buffered = self.as_mut().project().write_buffer.len();
        if buffered < MAX_FRAME_SIZE as usize {
            return Poll::Ready(Ok(()));
        }

        let flushed = self.as_mut().poll_write_buffer(cx);
        match flushed {
            Poll::Ready(Ok(())) => {
                debug_assert!(self.as_mut().project().write_buffer.is_empty());
                Poll::Ready(Ok(()))
            }
            pending_or_err => pending_or_err,
        }
    }

    /// Appends `item` to the write buffer, prefixed by its length as an unsigned varint.
    ///
    /// Nothing is written to `inner` here; that happens in `poll_ready`, `poll_flush`
    /// and `poll_close`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if `item` is longer than `MAX_FRAME_SIZE` bytes.
    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
        let len = u16::try_from(item.len())
            .ok()
            .filter(|&n| n <= MAX_FRAME_SIZE)
            .ok_or_else(|| io::Error::new(
                io::ErrorKind::InvalidData,
                "Maximum frame size exceeded."))?;

        let mut prefix_buf = unsigned_varint::encode::u16_buffer();
        let prefix = unsigned_varint::encode::u16(len, &mut prefix_buf);

        let write_buffer = self.project().write_buffer;
        write_buffer.reserve(prefix.len() + item.len());
        write_buffer.extend_from_slice(prefix);
        write_buffer.extend_from_slice(&item);
        Ok(())
    }

    /// Writes every buffered frame to `inner`, then flushes `inner`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let flushed = self.as_mut().poll_write_buffer(cx);
        if let pending_or_err @ (Poll::Pending | Poll::Ready(Err(_))) = flushed {
            return pending_or_err;
        }

        let this = self.project();
        debug_assert!(this.write_buffer.is_empty());
        this.inner.poll_flush(cx)
    }

    /// Writes every buffered frame to `inner`, then closes `inner`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let flushed = self.as_mut().poll_write_buffer(cx);
        if let pending_or_err @ (Poll::Pending | Poll::Ready(Err(_))) = flushed {
            return pending_or_err;
        }

        let this = self.project();
        debug_assert!(this.write_buffer.is_empty());
        this.inner.poll_close(cx)
    }
}
/// A `LengthDelimited` that yields frames when read from (as a `Stream`) but forwards
/// writes to the underlying resource without framing them.
///
/// Created by `LengthDelimited::into_reader`.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
    /// The framed resource; its `Stream` impl serves reads, and its buffered frames are
    /// written out before any unframed write.
    #[pin]
    inner: LengthDelimited<R>
}
impl<R> LengthDelimitedReader<R> {
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
}
impl<R> Stream for LengthDelimitedReader<R>
where
    R: AsyncRead
{
    type Item = Result<Bytes, io::Error>;

    /// Yields the next frame decoded by the wrapped `LengthDelimited`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let framed = self.project().inner;
        framed.poll_next(cx)
    }
}
impl<R> AsyncWrite for LengthDelimitedReader<R>
where
    R: AsyncWrite
{
    /// Writes `buf` to the underlying resource without framing. Frames still buffered from
    /// the `Sink` side are written out first, so the byte order on the wire is preserved.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
        -> Poll<Result<usize, io::Error>>
    {
        let mut framed = self.project().inner;
        let flushed = framed.as_mut().poll_write_buffer(cx);
        match flushed {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => {
                debug_assert!(framed.write_buffer.is_empty());
                framed.project().inner.poll_write(cx, buf)
            }
        }
    }

    /// Writes out buffered frames and flushes the underlying resource.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let framed = self.project().inner;
        framed.poll_flush(cx)
    }

    /// Writes out buffered frames and closes the underlying resource.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let framed = self.project().inner;
        framed.poll_close(cx)
    }

    /// Vectored variant of `poll_write`. Buffered frames are written out first, then
    /// `bufs` is passed to the underlying resource unframed.
    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
        -> Poll<Result<usize, io::Error>>
    {
        let mut framed = self.project().inner;
        let flushed = framed.as_mut().poll_write_buffer(cx);
        match flushed {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => {
                debug_assert!(framed.write_buffer.is_empty());
                framed.project().inner.poll_write_vectored(cx, bufs)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::length_delimited::LengthDelimited;
    use async_std::net::{TcpListener, TcpStream};
    use futures::{prelude::*, io::Cursor};
    use quickcheck::*;
    use std::io::ErrorKind;

    // A single frame: a one-byte length prefix (6) followed by six payload bytes.
    #[test]
    fn basic_read() {
        let data = vec![6, 9, 8, 7, 6, 5, 4];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    // Two back-to-back frames (lengths 6 and 3) in a single input.
    #[test]
    fn basic_read_two() {
        let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    // A frame whose length (5000) needs a two-byte varint prefix: the low 7 bits with the
    // continuation bit set, then the remaining high bits.
    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        // NOTE(review): the codec's actual limit is MAX_FRAME_SIZE (2^14 - 1 = 16383), which
        // is tighter than this bound; 5000 satisfies both.
        assert!(len < (1 << 15));
        let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
        let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        data.extend(frame.clone().into_iter());
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move {
            framed.next().await
        }).unwrap();
        assert_eq!(recved.unwrap(), frame);
    }

    // 0x81 0x81 0x01 encodes 16513 as a three-byte varint, which exceeds MAX_LEN_BYTES. The
    // prefix must be rejected with InvalidData after its second byte, before any payload.
    #[test]
    fn packet_len_too_long() {
        let mut data = vec![0x81, 0x81, 0x1];
        data.extend((0..16513).map(|_| 0));
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move {
            framed.next().await.unwrap()
        });
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::InvalidData)
        } else {
            panic!()
        }
    }

    // Zero-length prefixes yield empty frames, interleaved with non-empty ones.
    #[test]
    fn empty_frames() {
        let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(
            recved,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    // EOF right after a prefix byte whose continuation bit (0x80) is set.
    #[test]
    fn unexpected_eof_in_len() {
        let data = vec![0x89];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    // EOF right after a prefix announcing 5 payload bytes, before any payload byte.
    #[test]
    fn unexpected_eof_in_data() {
        let data = vec![5];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    // EOF after only 3 of the 5 announced payload bytes.
    #[test]
    fn unexpected_eof_in_data2() {
        let data = vec![5, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    // Property test: frames sent through a `LengthDelimited` sink over a local TCP
    // connection are read back, in order, through `RwStreamSink` on the other end.
    // NOTE(review): `send(...).unwrap()` assumes quickcheck never generates a frame longer
    // than MAX_FRAME_SIZE (start_send rejects those); confirm the generator size.
    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            async_std::task::block_on(async move {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let listener_addr = listener.local_addr().unwrap();
                let expected_frames = frames.clone();
                let server = async_std::task::spawn(async move {
                    let socket = listener.accept().await.unwrap().0;
                    let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
                    let mut buf = vec![0u8; 0];
                    for expected in expected_frames {
                        // NOTE(review): empty frames are not checked, presumably because a
                        // zero-byte `read` cannot be told apart from EOF; confirm intent.
                        if expected.is_empty() {
                            continue;
                        }
                        // Grow the buffer so the whole frame fits; the assertion below
                        // assumes a single `read` returns exactly one complete frame.
                        if buf.len() < expected.len() {
                            buf.resize(expected.len(), 0);
                        }
                        let n = connec.read(&mut buf).await.unwrap();
                        assert_eq!(&buf[..n], &expected[..]);
                    }
                });
                let client = async_std::task::spawn(async move {
                    let socket = TcpStream::connect(&listener_addr).await.unwrap();
                    let mut connec = LengthDelimited::new(socket);
                    for frame in frames {
                        connec.send(From::from(frame)).await.unwrap();
                    }
                });
                server.await;
                client.await;
            });
            TestResult::passed()
        }
        quickcheck(prop as fn(_) -> _)
    }
}