use bytes::Bytes;
use crate::{MplexConfig, MaxBufferBehaviour};
use crate::codec::{Codec, Frame, LocalStreamId, RemoteStreamId};
use log::{debug, trace};
use futures::{prelude::*, ready, stream::Fuse};
use futures::task::{AtomicWaker, ArcWake, waker_ref, WakerRef};
use asynchronous_codec::Framed;
use nohash_hasher::{IntMap, IntSet};
use parking_lot::Mutex;
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::{cmp, fmt, io, mem, sync::Arc, task::{Context, Poll, Waker}};
pub use std::io::{Result, Error, ErrorKind};
/// Random identifier of a multiplexed connection, used to tag log messages.
#[derive(Clone, Copy)]
struct ConnectionId(u64);

impl fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Same representation as `Display`.
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for ConnectionId {
    /// Renders the id as exactly 16 lowercase hex digits. Zero-padding keeps
    /// ids fixed-width without leading blanks in log lines.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}
/// A connection over which many substreams are multiplexed, exchanging
/// frames of the crate's `Codec` over the underlying I/O stream `C`.
pub struct Multiplexed<C> {
/// Random identifier of this connection, used in log messages.
id: ConnectionId,
/// Whether the connection is open, was closed, or failed with an error.
status: Status,
/// The underlying I/O stream, framed with `Codec`.
io: Fuse<Framed<C, Codec>>,
/// Limits and behaviour settings of this connection.
config: MplexConfig,
/// New inbound substreams accepted while reading frames on behalf of
/// another substream; handed out later by `poll_next_stream`.
open_buffer: VecDeque<LocalStreamId>,
/// Outbound substreams whose `Open` frame was given to the sink but not
/// yet flushed.
pending_flush_open: IntSet<LocalStreamId>,
/// A substream whose receive buffer exceeded `max_buffer_len` under
/// `MaxBufferBehaviour::Block`; no new frames are read until it is drained.
blocking_stream: Option<LocalStreamId>,
/// Control frames (`Close`/`Reset`) queued for sending; pushed at the
/// front and sent from the back.
pending_frames: VecDeque<Frame<LocalStreamId>>,
/// State (including the receive buffer) of every known substream.
substreams: IntMap<LocalStreamId, SubstreamState>,
/// Id assigned to the next outbound substream.
next_outbound_stream_id: LocalStreamId,
/// Wakers of tasks waiting to read from a substream or for a new one.
notifier_read: Arc<NotifierRead>,
/// Wakers of tasks waiting for write progress.
notifier_write: Arc<NotifierWrite>,
/// Wakers of tasks waiting for the substream count to drop below the limit.
notifier_open: NotifierOpen,
}
/// Lifecycle state of a [`Multiplexed`] connection.
#[derive(Debug)]
enum Status {
/// The connection is usable.
Open,
/// `poll_close` completed successfully.
Closed,
/// The connection failed; a copy of this error is returned by later
/// operations (see `guard_open`).
Err(io::Error),
}
impl<C> Multiplexed<C>
where
C: AsyncRead + AsyncWrite + Unpin
{
pub fn new(io: C, config: MplexConfig) -> Self {
let id = ConnectionId(rand::random());
debug!("New multiplexed connection: {}", id);
Multiplexed {
id,
config,
status: Status::Open,
io: Framed::new(io, Codec::new()).fuse(),
open_buffer: Default::default(),
substreams: Default::default(),
pending_flush_open: Default::default(),
pending_frames: Default::default(),
blocking_stream: None,
next_outbound_stream_id: LocalStreamId::dialer(0),
notifier_read: Arc::new(NotifierRead {
read_stream: Mutex::new(Default::default()),
next_stream: AtomicWaker::new(),
}),
notifier_write: Arc::new(NotifierWrite {
pending: Mutex::new(Default::default()),
}),
notifier_open: NotifierOpen {
pending: Default::default()
}
}
}
pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &self.status {
Status::Closed => return Poll::Ready(Ok(())),
Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
Status::Open => {}
}
ready!(self.send_pending_frames(cx))?;
let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
match ready!(self.io.poll_flush_unpin(&mut Context::from_waker(&waker))) {
Err(e) => Poll::Ready(self.on_error(e)),
Ok(()) => {
self.pending_flush_open = Default::default();
Poll::Ready(Ok(()))
}
}
}
pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &self.status {
Status::Closed => return Poll::Ready(Ok(())),
Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
Status::Open => {}
}
let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
match self.io.poll_close_unpin(&mut Context::from_waker(&waker)) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(self.on_error(e)),
Poll::Ready(Ok(())) => {
self.pending_frames = VecDeque::new();
self.open_buffer = Default::default();
self.substreams = Default::default();
self.status = Status::Closed;
Poll::Ready(Ok(()))
}
}
}
/// Waits for the next inbound substream.
///
/// Streams accepted earlier, while reading on behalf of another substream,
/// are returned first. Otherwise frames are read: data frames are buffered
/// for their substreams, and `Close`/`Reset` frames update substream
/// states. After buffering `max_buffer_len` data frames, the task is
/// immediately rescheduled and `Pending` is returned.
pub fn poll_next_stream(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<LocalStreamId>> {
    self.guard_open()?;

    if let Some(buffered) = self.open_buffer.pop_back() {
        return Poll::Ready(Ok(buffered));
    }
    debug_assert!(self.open_buffer.is_empty());

    let mut buffered_frames = 0;
    while buffered_frames != self.config.max_buffer_len {
        match ready!(self.poll_read_frame(cx, None))? {
            Frame::Open { stream_id } => {
                if let Some(opened) = self.on_open(stream_id)? {
                    return Poll::Ready(Ok(opened))
                }
            }
            Frame::Data { stream_id, data } => {
                self.buffer(stream_id.into_local(), data)?;
                buffered_frames += 1;
            }
            Frame::Close { stream_id } => self.on_close(stream_id.into_local())?,
            Frame::Reset { stream_id } => self.on_reset(stream_id.into_local()),
        }
    }

    cx.waker().clone().wake();
    Poll::Pending
}
pub fn poll_open_stream(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<LocalStreamId>> {
self.guard_open()?;
if self.substreams.len() >= self.config.max_substreams {
debug!("{}: Maximum number of substreams reached ({})",
self.id, self.config.max_substreams);
self.notifier_open.register(cx.waker());
return Poll::Pending
}
let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
Ok(()) => {
let stream_id = self.next_outbound_stream_id();
let frame = Frame::Open { stream_id };
match self.io.start_send_unpin(frame) {
Ok(()) => {
self.substreams.insert(stream_id, SubstreamState::Open {
buf: Default::default()
});
debug!("{}: New outbound substream: {} (total {})",
self.id, stream_id, self.substreams.len());
self.pending_flush_open.insert(stream_id);
Poll::Ready(Ok(stream_id))
}
Err(e) => Poll::Ready(self.on_error(e)),
}
},
Err(e) => Poll::Ready(self.on_error(e))
}
}
pub fn drop_stream(&mut self, id: LocalStreamId) {
match self.status {
Status::Closed | Status::Err(_) => return,
Status::Open => {},
}
self.notifier_read.wake_read_stream(id);
match self.substreams.remove(&id) {
None => return,
Some(state) => {
let below_limit = self.substreams.len() == self.config.max_substreams - 1;
if below_limit {
self.notifier_open.wake_all();
}
match state {
SubstreamState::Closed { .. } => {}
SubstreamState::SendClosed { .. } => {}
SubstreamState::Reset { .. } => {}
SubstreamState::RecvClosed { .. } => {
if self.check_max_pending_frames().is_err() {
return
}
trace!("{}: Pending close for stream {}", self.id, id);
self.pending_frames.push_front(Frame::Close { stream_id: id });
}
SubstreamState::Open { .. } => {
if self.check_max_pending_frames().is_err() {
return
}
trace!("{}: Pending reset for stream {}", self.id, id);
self.pending_frames.push_front(Frame::Reset { stream_id: id });
}
}
}
}
}
/// Writes at most `split_send_size` bytes of `buf` to substream `id` as one
/// `Data` frame and returns how many bytes were taken.
///
/// Fails with `BrokenPipe` for unknown or reset substreams, and with
/// `WriteZero` if our sending side is already closed.
pub fn poll_write_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId, buf: &[u8])
    -> Poll<io::Result<usize>>
{
    self.guard_open()?;

    let writable: io::Result<()> = match self.substreams.get(&id) {
        Some(SubstreamState::Open { .. }) | Some(SubstreamState::RecvClosed { .. }) => Ok(()),
        Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) =>
            Err(io::Error::from(io::ErrorKind::WriteZero)),
        None | Some(SubstreamState::Reset { .. }) =>
            Err(io::Error::from(io::ErrorKind::BrokenPipe)),
    };
    writable?;

    let chunk = &buf[.. cmp::min(buf.len(), self.config.split_send_size)];
    ready!(self.poll_send_frame(cx, || Frame::Data {
        stream_id: id,
        data: Bytes::copy_from_slice(chunk),
    }))?;
    Poll::Ready(Ok(chunk.len()))
}
/// Reads the next data frame of substream `id`.
///
/// Buffered data is returned first. Otherwise frames are read from the
/// connection: data for other substreams is buffered, new inbound streams
/// are queued for `poll_next_stream`, and `Close`/`Reset` frames update the
/// substream states. Returns `Ok(None)` once `id` can no longer receive
/// data (remote closed or reset it, or it is unknown).
pub fn poll_read_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId)
    -> Poll<io::Result<Option<Bytes>>>
{
    self.guard_open()?;

    // Serve already-buffered data first.
    if let Some(state) = self.substreams.get_mut(&id) {
        let buf = state.recv_buf();
        if !buf.is_empty() {
            if self.blocking_stream == Some(id) {
                // Taking a frame out frees buffer space: resume reading new
                // frames and wake every task waiting for that.
                self.blocking_stream = None;
                ArcWake::wake_by_ref(&self.notifier_read);
            }
            let data = buf.remove(0);
            return Poll::Ready(Ok(Some(data)))
        }
        // If the buffer spilled onto the heap, release that memory.
        buf.shrink_to_fit();
    }

    let mut num_buffered = 0;
    loop {
        // After buffering `max_buffer_len` frames for other substreams,
        // reschedule this task immediately and yield.
        if num_buffered == self.config.max_buffer_len {
            cx.waker().clone().wake();
            return Poll::Pending
        }
        if !self.can_read(&id) {
            return Poll::Ready(Ok(None))
        }
        match ready!(self.poll_read_frame(cx, Some(id)))? {
            Frame::Data { data, stream_id } if stream_id.into_local() == id => {
                // The frame is ours: hand it over directly, no clone needed.
                return Poll::Ready(Ok(Some(data)))
            },
            Frame::Data { stream_id, data } => {
                self.buffer(stream_id.into_local(), data)?;
                num_buffered += 1;
            }
            frame @ Frame::Open { .. } => {
                // `new_id` avoids shadowing `id`, the stream being read.
                if let Some(new_id) = self.on_open(frame.remote_id())? {
                    self.open_buffer.push_front(new_id);
                    trace!("{}: Buffered new inbound stream {} (total: {})", self.id, new_id, self.open_buffer.len());
                    self.notifier_read.wake_next_stream();
                }
            }
            Frame::Close { stream_id } => {
                let stream_id = stream_id.into_local();
                self.on_close(stream_id)?;
                if id == stream_id {
                    return Poll::Ready(Ok(None))
                }
            }
            Frame::Reset { stream_id } => {
                let stream_id = stream_id.into_local();
                self.on_reset(stream_id);
                if id == stream_id {
                    return Poll::Ready(Ok(None))
                }
            }
        }
    }
}
/// Flushes substream `id`. All substreams share one sink, so this flushes
/// the whole connection (see `poll_flush`).
pub fn poll_flush_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId)
    -> Poll<io::Result<()>>
{
    self.guard_open()?;
    match self.poll_flush(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Ready(Ok(())) => {
            trace!("{}: Flushed substream {}", self.id, id);
            Poll::Ready(Ok(()))
        }
    }
}
pub fn poll_close_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId)
-> Poll<io::Result<()>>
{
self.guard_open()?;
match self.substreams.remove(&id) {
None => Poll::Ready(Ok(())),
Some(SubstreamState::SendClosed { buf }) => {
self.substreams.insert(id, SubstreamState::SendClosed { buf });
Poll::Ready(Ok(()))
}
Some(SubstreamState::Closed { buf }) => {
self.substreams.insert(id, SubstreamState::Closed { buf });
Poll::Ready(Ok(()))
}
Some(SubstreamState::Reset { buf }) => {
self.substreams.insert(id, SubstreamState::Reset { buf });
Poll::Ready(Ok(()))
}
Some(SubstreamState::Open { buf }) => {
if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() {
self.substreams.insert(id, SubstreamState::Open { buf });
Poll::Pending
} else {
debug!("{}: Closed substream {} (half-close)", self.id, id);
self.substreams.insert(id, SubstreamState::SendClosed { buf });
Poll::Ready(Ok(()))
}
}
Some(SubstreamState::RecvClosed { buf }) => {
if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() {
self.substreams.insert(id, SubstreamState::RecvClosed { buf });
Poll::Pending
} else {
debug!("{}: Closed substream {}", self.id, id);
self.substreams.insert(id, SubstreamState::Closed { buf });
Poll::Ready(Ok(()))
}
}
}
}
/// Hands the frame produced by `frame` to the sink once it is ready.
///
/// The frame is only constructed after the sink reports readiness, and it
/// is not flushed here. Any I/O error fails the connection via `on_error`.
fn poll_send_frame<F>(&mut self, cx: &mut Context<'_>, frame: F)
    -> Poll<io::Result<()>>
where
    F: FnOnce() -> Frame<LocalStreamId>
{
    let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
    if let Err(e) = ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
        return Poll::Ready(self.on_error(e))
    }
    let frame = frame();
    trace!("{}: Sending {:?}", self.id, frame);
    Poll::Ready(match self.io.start_send_unpin(frame) {
        Ok(()) => Ok(()),
        Err(e) => self.on_error(e),
    })
}
/// Reads the next frame from the underlying I/O stream, on behalf of
/// substream `stream_id` or, if `None`, of the task waiting for new
/// inbound substreams.
///
/// Before reading, queued `pending_frames` are sent (only an error from
/// that aborts the read), and the connection is flushed if `stream_id`
/// still has an unflushed `Open` frame. While some substream's buffer is
/// full (`blocking_stream`), no frame is read and `Pending` is returned.
fn poll_read_frame(&mut self, cx: &mut Context<'_>, stream_id: Option<LocalStreamId>)
-> Poll<io::Result<Frame<RemoteStreamId>>>
{
// Make progress on queued control frames; `Pending` here is ignored.
if let Poll::Ready(Err(e)) = self.send_pending_frames(cx) {
return Poll::Ready(Err(e))
}
// Flush first if this outbound stream's `Open` frame was only buffered
// in the sink so far (see `poll_open_stream`).
if let Some(id) = &stream_id {
if self.pending_flush_open.contains(id) {
trace!("{}: Executing pending flush for {}.", self.id, id);
ready!(self.poll_flush(cx))?;
self.pending_flush_open = Default::default();
}
}
// While a substream's buffer is full, no new frames may be read.
if let Some(blocked_id) = &self.blocking_stream {
// Try to wake the task reading the blocked stream so it drains its buffer.
if !self.notifier_read.wake_read_stream(*blocked_id) {
// No reader registered for the blocked stream: reschedule the
// current task so it is polled again.
trace!("{}: No task to read from blocked stream. Waking current task.", self.id);
cx.waker().clone().wake();
} else {
// Another task was woken. Stay registered so this task is woken
// when reading resumes (unblocking wakes all registered readers).
if let Some(id) = stream_id {
debug_assert!(blocked_id != &id, "Unexpected attempt at reading a new \
frame from a substream with a full buffer.");
let _ = NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id);
} else {
let _ = NotifierRead::register_next_stream(&self.notifier_read, cx.waker());
}
}
return Poll::Pending
}
// Register interest and poll the I/O with a waker that, when woken, wakes
// every registered reader: the next frame may be for any of them.
let waker = match stream_id {
Some(id) => NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id),
None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker())
};
match ready!(self.io.poll_next_unpin(&mut Context::from_waker(&waker))) {
Some(Ok(frame)) => {
trace!("{}: Received {:?}", self.id, frame);
Poll::Ready(Ok(frame))
}
Some(Err(e)) => Poll::Ready(self.on_error(e)),
// End of the frame stream fails the connection with `UnexpectedEof`.
None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into()))
}
}
fn on_open(&mut self, id: RemoteStreamId) -> io::Result<Option<LocalStreamId>> {
let id = id.into_local();
if self.substreams.contains_key(&id) {
debug!("{}: Received unexpected `Open` frame for open substream {}", self.id, id);
return self.on_error(io::Error::new(io::ErrorKind::Other,
"Protocol error: Received `Open` frame for open substream."))
}
if self.substreams.len() >= self.config.max_substreams {
debug!("{}: Maximum number of substreams exceeded: {}",
self.id, self.config.max_substreams);
self.check_max_pending_frames()?;
debug!("{}: Pending reset for new stream {}", self.id, id);
self.pending_frames.push_front(Frame::Reset {
stream_id: id
});
return Ok(None)
}
self.substreams.insert(id, SubstreamState::Open {
buf: Default::default()
});
debug!("{}: New inbound substream: {} (total {})", self.id, id, self.substreams.len());
Ok(Some(id))
}
/// Handles a `Reset` frame from the remote for substream `id`.
///
/// A stream that could still send or receive becomes `Reset` (keeping its
/// buffered data) and its reader is woken. A stream that was already
/// `Closed` or `Reset` is removed from the substream map.
fn on_reset(&mut self, id: LocalStreamId) {
    let state = match self.substreams.remove(&id) {
        Some(state) => state,
        None => {
            trace!("{}: Ignoring `Reset` for unknown substream {}. Possibly dropped earlier.",
                self.id, id);
            return
        }
    };
    match state {
        SubstreamState::Closed { .. } =>
            trace!("{}: Ignoring reset for mutually closed substream {}.", self.id, id),
        SubstreamState::Reset { .. } =>
            trace!("{}: Ignoring redundant reset for already reset substream {}",
                self.id, id),
        SubstreamState::Open { buf }
        | SubstreamState::SendClosed { buf }
        | SubstreamState::RecvClosed { buf } => {
            debug!("{}: Substream {} reset by remote.", self.id, id);
            self.substreams.insert(id, SubstreamState::Reset { buf });
            NotifierRead::wake_read_stream(&self.notifier_read, id);
        }
    }
}
/// Handles a `Close` frame from the remote for substream `id`.
///
/// `Open` becomes `RecvClosed` and `SendClosed` becomes `Closed`; in both
/// cases the stream's reader is woken. Other states are kept unchanged.
/// Always returns `Ok(())`.
fn on_close(&mut self, id: LocalStreamId) -> io::Result<()> {
    let state = match self.substreams.remove(&id) {
        Some(state) => state,
        None => {
            trace!("{}: Ignoring `Close` for unknown substream {}. Possibly dropped earlier.",
                self.id, id);
            return Ok(())
        }
    };

    let (next, wake_reader) = match state {
        s @ SubstreamState::RecvClosed { .. } | s @ SubstreamState::Closed { .. } => {
            debug!("{}: Ignoring `Close` frame for closed substream {}",
                self.id, id);
            (s, false)
        }
        SubstreamState::Reset { buf } => {
            debug!("{}: Ignoring `Close` frame for already reset substream {}",
                self.id, id);
            (SubstreamState::Reset { buf }, false)
        }
        SubstreamState::SendClosed { buf } => {
            debug!("{}: Substream {} closed by remote (SendClosed -> Closed).",
                self.id, id);
            (SubstreamState::Closed { buf }, true)
        }
        SubstreamState::Open { buf } => {
            debug!("{}: Substream {} closed by remote (Open -> RecvClosed)",
                self.id, id);
            (SubstreamState::RecvClosed { buf }, true)
        }
    };

    self.substreams.insert(id, next);
    if wake_reader {
        self.notifier_read.wake_read_stream(id);
    }
    Ok(())
}
fn next_outbound_stream_id(&mut self) -> LocalStreamId {
let id = self.next_outbound_stream_id;
self.next_outbound_stream_id = self.next_outbound_stream_id.next();
id
}
/// Whether the remote may still send data on substream `id`, i.e. the
/// stream is known and in state `Open` or `SendClosed`.
fn can_read(&self, id: &LocalStreamId) -> bool {
    matches!(
        self.substreams.get(id),
        Some(SubstreamState::Open { .. }) | Some(SubstreamState::SendClosed { .. })
    )
}
/// Hands all queued control frames to the sink, oldest first.
///
/// Frames are queued with `push_front`, so popping from the back preserves
/// scheduling order. If the sink is not ready, the frame goes back to the
/// back of the queue and `Pending` is returned.
fn send_pending_frames(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    loop {
        let frame = match self.pending_frames.pop_back() {
            Some(frame) => frame,
            None => return Poll::Ready(Ok(())),
        };
        let sent = self.poll_send_frame(cx, || frame.clone())?;
        if sent.is_pending() {
            self.pending_frames.push_back(frame);
            return Poll::Pending
        }
    }
}
fn on_error<T>(&mut self, e: io::Error) -> io::Result<T> {
debug!("{}: Multiplexed connection failed: {:?}", self.id, e);
self.status = Status::Err(io::Error::new(e.kind(), e.to_string()));
self.pending_frames = Default::default();
self.substreams = Default::default();
self.open_buffer = Default::default();
Err(e)
}
fn guard_open(&self) -> io::Result<()> {
match &self.status {
Status::Closed => Err(io::Error::new(io::ErrorKind::Other, "Connection is closed")),
Status::Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
Status::Open => Ok(())
}
}
/// Fails the connection if `max_substreams + EXTRA_PENDING_FRAMES` or more
/// control frames are already queued.
fn check_max_pending_frames(&mut self) -> io::Result<()> {
    let limit = self.config.max_substreams + EXTRA_PENDING_FRAMES;
    if self.pending_frames.len() < limit {
        return Ok(())
    }
    self.on_error(io::Error::new(io::ErrorKind::Other,
        "Too many pending frames."))
}
fn buffer(&mut self, id: LocalStreamId, data: Bytes) -> io::Result<()> {
let state = if let Some(state) = self.substreams.get_mut(&id) {
state
} else {
trace!("{}: Dropping data {:?} for unknown substream {}", self.id, data, id);
return Ok(())
};
let buf = if let Some(buf) = state.recv_buf_open() {
buf
} else {
trace!("{}: Dropping data {:?} for closed or reset substream {}", self.id, data, id);
return Ok(())
};
debug_assert!(buf.len() <= self.config.max_buffer_len);
trace!("{}: Buffering {:?} for stream {} (total: {})", self.id, data, id, buf.len() + 1);
buf.push(data);
self.notifier_read.wake_read_stream(id);
if buf.len() > self.config.max_buffer_len {
debug!("{}: Frame buffer of stream {} is full.", self.id, id);
match self.config.max_buffer_behaviour {
MaxBufferBehaviour::ResetStream => {
let buf = buf.clone();
self.check_max_pending_frames()?;
self.substreams.insert(id, SubstreamState::Reset { buf });
debug!("{}: Pending reset for stream {}", self.id, id);
self.pending_frames.push_front(Frame::Reset {
stream_id: id
});
}
MaxBufferBehaviour::Block => {
self.blocking_stream = Some(id);
}
}
}
Ok(())
}
}
/// Receive buffer of a substream: data frames received but not yet read.
type RecvBuf = SmallVec<[Bytes; 10]>;
/// State of a substream, as seen from the local side. Every state keeps its
/// receive buffer so data received earlier can still be read.
#[derive(Clone, Debug)]
enum SubstreamState {
/// Both directions are open.
Open { buf: RecvBuf },
/// We sent `Close`; the remote may still send data.
SendClosed { buf: RecvBuf },
/// The remote sent `Close`; we may still send data.
RecvClosed { buf: RecvBuf },
/// Both sides sent `Close`.
Closed { buf: RecvBuf },
/// Reset by the remote, or by us after its buffer overflowed under
/// `MaxBufferBehaviour::ResetStream`.
Reset { buf: RecvBuf }
}
impl SubstreamState {
    /// The receive buffer, regardless of the state.
    fn recv_buf(&mut self) -> &mut RecvBuf {
        match self {
            SubstreamState::Open { buf }
            | SubstreamState::SendClosed { buf }
            | SubstreamState::RecvClosed { buf }
            | SubstreamState::Closed { buf }
            | SubstreamState::Reset { buf } => buf,
        }
    }

    /// The receive buffer, but only in states where the remote may still
    /// send data (`Open` and `SendClosed`).
    fn recv_buf_open(&mut self) -> Option<&mut RecvBuf> {
        match self {
            SubstreamState::Open { buf } | SubstreamState::SendClosed { buf } => Some(buf),
            SubstreamState::RecvClosed { .. }
            | SubstreamState::Closed { .. }
            | SubstreamState::Reset { .. } => None,
        }
    }
}
/// Wakers of tasks waiting to read: at most one per substream, plus the
/// task waiting for new inbound substreams.
struct NotifierRead {
    next_stream: AtomicWaker,
    read_stream: Mutex<IntMap<LocalStreamId, Waker>>,
}

impl NotifierRead {
    /// Registers `waker` as the reader of substream `id`, replacing any
    /// earlier one, and returns a waker that wakes all registered readers.
    #[must_use]
    fn register_read_stream<'a>(self: &'a Arc<Self>, waker: &Waker, id: LocalStreamId)
        -> WakerRef<'a>
    {
        let mut readers = self.read_stream.lock();
        readers.insert(id, waker.clone());
        waker_ref(self)
    }

    /// Registers `waker` as the task waiting for new inbound substreams and
    /// returns a waker that wakes all registered readers.
    #[must_use]
    fn register_next_stream<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
        self.next_stream.register(waker);
        waker_ref(self)
    }

    /// Wakes and forgets the reader of substream `id`; returns whether one
    /// was registered.
    fn wake_read_stream(&self, id: LocalStreamId) -> bool {
        match self.read_stream.lock().remove(&id) {
            Some(waker) => {
                waker.wake();
                true
            }
            None => false,
        }
    }

    /// Wakes the task waiting for new inbound substreams, if any.
    fn wake_next_stream(&self) {
        self.next_stream.wake();
    }
}

impl ArcWake for NotifierRead {
    /// Wakes every registered reader, per-substream and next-stream alike.
    fn wake_by_ref(this: &Arc<Self>) {
        let readers = mem::take(&mut *this.read_stream.lock());
        readers.into_iter().for_each(|(_, waker)| waker.wake());
        this.wake_next_stream();
    }
}
/// Wakers of tasks waiting for the underlying sink to make write progress.
struct NotifierWrite {
    pending: Mutex<Vec<Waker>>,
}

impl NotifierWrite {
    /// Adds `waker` unless an equivalent one is already registered, and
    /// returns a waker that wakes all registered writers.
    #[must_use]
    fn register<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
        let mut writers = self.pending.lock();
        let already_known = writers.iter().any(|w| w.will_wake(waker));
        if !already_known {
            writers.push(waker.clone());
        }
        waker_ref(self)
    }
}

impl ArcWake for NotifierWrite {
    /// Wakes and forgets every registered writer.
    fn wake_by_ref(this: &Arc<Self>) {
        let writers = mem::take(&mut *this.pending.lock());
        writers.into_iter().for_each(Waker::wake);
    }
}
/// Wakers of tasks waiting in `poll_open_stream` for the number of
/// substreams to drop below `max_substreams`.
struct NotifierOpen {
    pending: Vec<Waker>,
}

impl NotifierOpen {
    /// Remembers `waker` unless an equivalent one is already registered.
    fn register(&mut self, waker: &Waker) {
        if !self.pending.iter().any(|w| w.will_wake(waker)) {
            self.pending.push(waker.clone());
        }
    }

    /// Wakes and forgets all registered wakers.
    fn wake_all(&mut self) {
        for waker in mem::take(&mut self.pending) {
            waker.wake();
        }
    }
}
/// Allowance, on top of `max_substreams`, for queued control frames in
/// `pending_frames`; reaching the sum makes `check_max_pending_frames`
/// fail the connection.
const EXTRA_PENDING_FRAMES: usize = 1000;
// Property-based tests for `Multiplexed`, driven over an in-memory connection.
#[cfg(test)]
mod tests {
use async_std::task;
use bytes::BytesMut;
use futures::prelude::*;
use asynchronous_codec::{Decoder, Encoder};
use quickcheck::*;
use rand::prelude::*;
use std::collections::HashSet;
use std::num::NonZeroU8;
use std::ops::DerefMut;
use std::pin::Pin;
use super::*;
// Picks one of the two buffer-overflow behaviours at random.
impl Arbitrary for MaxBufferBehaviour {
fn arbitrary<G: Gen>(g: &mut G) -> MaxBufferBehaviour {
*[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream].choose(g).unwrap()
}
}
// Random configurations; every numeric limit is at least 1.
impl Arbitrary for MplexConfig {
fn arbitrary<G: Gen>(g: &mut G) -> MplexConfig {
MplexConfig {
max_substreams: g.gen_range(1, 100),
max_buffer_len: g.gen_range(1, 1000),
max_buffer_behaviour: MaxBufferBehaviour::arbitrary(g),
split_send_size: g.gen_range(1, 10000),
}
}
}
// In-memory stand-in for the underlying I/O stream.
struct Connection {
// Bytes "sent by the remote", consumed by reads.
r_buf: BytesMut,
// Bytes written by the multiplexer.
w_buf: BytesMut,
// When set, every read fails with `UnexpectedEof`.
eof: bool,
}
impl AsyncRead for Connection {
// Drains `r_buf`; returns `Pending` (without storing a waker) when it is empty.
fn poll_read(
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &mut [u8]
) -> Poll<io::Result<usize>> {
if self.eof {
return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
}
let n = std::cmp::min(buf.len(), self.r_buf.len());
let data = self.r_buf.split_to(n);
buf[..n].copy_from_slice(&data[..]);
if n == 0 {
Poll::Pending
} else {
Poll::Ready(Ok(n))
}
}
}
// Writes always succeed immediately by appending to `w_buf`; flush and
// close are no-ops.
impl AsyncWrite for Connection {
fn poll_write(
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8]
) -> Poll<io::Result<usize>> {
self.w_buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
_: &mut Context<'_>
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_: &mut Context<'_>
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[test]
fn max_buffer_behaviour() {
let _ = env_logger::try_init();
// Property: once more than `max_buffer_len` data frames are buffered for
// one substream, the configured `MaxBufferBehaviour` takes effect.
fn prop(cfg: MplexConfig, overflow: NonZeroU8) {
let mut r_buf = BytesMut::new();
let mut codec = Codec::new();
// The remote opens `max_substreams` streams...
for i in 0 .. cfg.max_substreams {
let stream_id = LocalStreamId::dialer(i as u32);
codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap();
}
// ...then sends `max_buffer_len + overflow` data frames on its first stream.
let stream_id = LocalStreamId::dialer(0);
let data = Bytes::from("Hello world");
for _ in 0 .. cfg.max_buffer_len + overflow.get() as usize {
codec.encode(Frame::Data { stream_id, data: data.clone() }, &mut r_buf).unwrap();
}
let conn = Connection { r_buf, w_buf: BytesMut::new(), eof: false };
let mut m = Multiplexed::new(conn, cfg.clone());
task::block_on(future::poll_fn(move |cx| {
// Accept all inbound streams; locally they are listener streams.
for i in 0 .. cfg.max_substreams {
match m.poll_next_stream(cx) {
Poll::Pending => panic!("Expected new inbound stream."),
Poll::Ready(Err(e)) => panic!("{:?}", e),
Poll::Ready(Ok(id)) => {
assert_eq!(id, LocalStreamId::listener(i as u32));
}
};
}
let id = LocalStreamId::listener(0);
// The next call buffers `max_buffer_len` frames for stream 0 and
// yields; the call after that buffers one more, overflowing the buffer.
match m.poll_next_stream(cx) {
Poll::Ready(r) => panic!("Unexpected result for next stream: {:?}", r),
Poll::Pending => {
assert_eq!(
m.substreams.get_mut(&id).unwrap().recv_buf().len(),
cfg.max_buffer_len
);
match m.poll_next_stream(cx) {
Poll::Ready(r) => panic!("Unexpected result for next stream: {:?}", r),
Poll::Pending => {
assert_eq!(
m.substreams.get_mut(&id).unwrap().recv_buf().len(),
cfg.max_buffer_len + 1
);
}
}
}
}
match cfg.max_buffer_behaviour {
MaxBufferBehaviour::ResetStream => {
// Flushing writes out the queued `Reset` for the overflowing stream.
let _ = m.poll_flush_stream(cx, id);
let w_buf = &mut m.io.get_mut().deref_mut().w_buf;
let frame = codec.decode(w_buf).unwrap();
let stream_id = stream_id.into_remote();
assert_eq!(frame, Some(Frame::Reset { stream_id }));
}
MaxBufferBehaviour::Block => {
// While blocked, no frames are read, neither for new streams
// nor for any other substream.
assert!(m.poll_next_stream(cx).is_pending());
for i in 1 .. cfg.max_substreams {
let id = LocalStreamId::listener(i as u32);
assert!(m.poll_read_stream(cx, id).is_pending());
}
}
}
// Everything buffered, including the overflowing frame, is readable.
for _ in 0 .. cfg.max_buffer_len + 1 {
match m.poll_read_stream(cx, id) {
Poll::Ready(Ok(Some(bytes))) => {
assert_eq!(bytes, data);
}
x => panic!("Unexpected: {:?}", x)
}
}
// ResetStream: the drained stream is finished. Block: reading resumed,
// so the next frame arrives, unless `overflow == 1` left nothing to read.
match cfg.max_buffer_behaviour {
MaxBufferBehaviour::ResetStream => {
match m.poll_read_stream(cx, id) {
Poll::Ready(Ok(None)) => {},
poll => panic!("Unexpected: {:?}", poll)
}
}
MaxBufferBehaviour::Block => {
match m.poll_read_stream(cx, id) {
Poll::Ready(Ok(Some(bytes))) => assert_eq!(bytes, data),
Poll::Pending => assert_eq!(overflow.get(), 1),
poll => panic!("Unexpected: {:?}", poll)
}
}
}
Poll::Ready(())
}));
}
quickcheck(prop as fn(_,_))
}
#[test]
fn close_on_error() {
let _ = env_logger::try_init();
// Property: after the underlying I/O fails, every substream reports the
// error and all substream state is discarded.
fn prop(cfg: MplexConfig, num_streams: NonZeroU8) {
let num_streams = cmp::min(cfg.max_substreams, num_streams.get() as usize);
let conn = Connection {
r_buf: BytesMut::new(),
w_buf: BytesMut::new(),
eof: false
};
let mut m = Multiplexed::new(conn, cfg.clone());
let mut opened = HashSet::new();
task::block_on(future::poll_fn(move |cx| {
// Open outbound streams; reading from them is pending as no data arrived.
for _ in 0 .. num_streams {
let id = ready!(m.poll_open_stream(cx)).unwrap();
assert!(opened.insert(id));
assert!(m.poll_read_stream(cx, id).is_pending());
}
// Make every further read from the underlying connection fail.
m.io.get_mut().deref_mut().eof = true;
assert!(opened.iter().all(|id| match m.poll_read_stream(cx, *id) {
Poll::Ready(Err(e)) => e.kind() == io::ErrorKind::UnexpectedEof,
_ => false
}));
assert!(m.substreams.is_empty());
Poll::Ready(())
}))
}
quickcheck(prop as fn(_,_))
}
}