use bytes::BytesMut;
use futures::prelude::*;
use asynchronous_codec::Framed;
use libp2p::core::{UpgradeInfo, InboundUpgrade, OutboundUpgrade, upgrade};
use log::error;
use std::{borrow::Cow, convert::{Infallible, TryFrom as _}, io, iter, mem, pin::Pin, task::{Context, Poll}};
use unsigned_varint::codec::UviBytes;
/// Maximum accepted length, in bytes, of the initial message and of the handshake
/// exchanged while a substream is being opened.
const MAX_HANDSHAKE_SIZE: usize = 1024;
/// Inbound upgrade: reads the initial message sent by the remote and yields a
/// [`NotificationsInSubstream`] together with that message.
#[derive(Debug, Clone)]
pub struct NotificationsIn {
/// Name advertised through `UpgradeInfo::protocol_info`.
protocol_name: Cow<'static, str>,
/// Upper bound handed to the length-prefixed codec of the resulting substream.
max_notification_size: u64,
}
/// Outbound upgrade: writes an initial message, reads the remote's handshake and
/// yields a [`NotificationsOutSubstream`] together with that handshake.
#[derive(Debug, Clone)]
pub struct NotificationsOut {
/// Name advertised through `UpgradeInfo::protocol_info`.
protocol_name: Cow<'static, str>,
/// Bytes written, length-prefixed, as soon as the substream is open.
initial_message: Vec<u8>,
/// Upper bound handed to the length-prefixed codec of the resulting substream.
max_notification_size: u64,
}
/// Substream produced by [`NotificationsIn`]. It implements `Stream` and yields the
/// frames received from the remote once the local handshake has been sent.
#[pin_project::pin_project]
pub struct NotificationsInSubstream<TSubstream> {
/// Underlying socket wrapped in a varint-length-prefixed framing codec.
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
/// Progress of the local handshake and of the closing sequence.
handshake: NotificationsInSubstreamHandshake,
}
/// State machine tracked by [`NotificationsInSubstream`].
enum NotificationsInSubstreamHandshake {
/// `send_handshake` has not been called yet.
NotSent,
/// `send_handshake` was called; the payload still has to be handed to the sink.
PendingSend(Vec<u8>),
/// The payload was handed to the sink, which now has to be flushed.
Flush,
/// The handshake was flushed; incoming frames are read from the socket.
/// This variant is also the temporary placeholder used with `mem::replace`.
Sent,
/// The socket's stream ended; our side of the sink is being closed in turn.
ClosingInResponseToRemote,
/// The sink finished closing; the stream is over.
BothSidesClosed,
}
/// Substream produced by [`NotificationsOut`]. It implements `Sink<Vec<u8>>`, each
/// item being written as one length-prefixed frame.
#[pin_project::pin_project]
pub struct NotificationsOutSubstream<TSubstream> {
/// Underlying socket wrapped in a varint-length-prefixed framing codec.
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
}
impl NotificationsIn {
pub fn new(protocol_name: impl Into<Cow<'static, str>>, max_notification_size: u64) -> Self {
NotificationsIn {
protocol_name: protocol_name.into(),
max_notification_size,
}
}
}
impl UpgradeInfo for NotificationsIn {
    type Info = Cow<'static, [u8]>;
    type InfoIter = iter::Once<Self::Info>;

    /// Advertises exactly one protocol: the configured name, as bytes.
    fn protocol_info(&self) -> Self::InfoIter {
        // A borrowed name stays borrowed; an owned name is copied into a fresh buffer.
        iter::once(match &self.protocol_name {
            Cow::Borrowed(name) => Cow::Borrowed(name.as_bytes()),
            Cow::Owned(name) => Cow::Owned(name.clone().into_bytes()),
        })
    }
}
impl<TSubstream> InboundUpgrade<TSubstream> for NotificationsIn
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = (Vec<u8>, NotificationsInSubstream<TSubstream>);
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type Error = NotificationsHandshakeError;

    /// Reads the remote's length-prefixed initial message and wraps the socket into a
    /// [`NotificationsInSubstream`] whose handshake has not been sent yet.
    ///
    /// Fails with `TooLarge` when the announced length exceeds `MAX_HANDSHAKE_SIZE`.
    fn upgrade_inbound(self, mut socket: TSubstream, _: Self::Info) -> Self::Future {
        Box::pin(async move {
            // The length is checked before any buffer is allocated for the payload.
            let announced_len = unsigned_varint::aio::read_usize(&mut socket).await?;
            if announced_len > MAX_HANDSHAKE_SIZE {
                return Err(NotificationsHandshakeError::TooLarge {
                    requested: announced_len,
                    max: MAX_HANDSHAKE_SIZE,
                });
            }

            let mut initial_message = vec![0u8; announced_len];
            if announced_len != 0 {
                socket.read_exact(&mut initial_message).await?;
            }

            // A limit that does not fit in `usize` is clamped to the largest `usize`.
            let max_len = usize::try_from(self.max_notification_size).unwrap_or(usize::MAX);
            let mut codec = UviBytes::default();
            codec.set_max_len(max_len);

            Ok((initial_message, NotificationsInSubstream {
                socket: Framed::new(socket, codec),
                handshake: NotificationsInSubstreamHandshake::NotSent,
            }))
        })
    }
}
impl<TSubstream> NotificationsInSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    /// Queues the handshake to be sent to the remote.
    ///
    /// Nothing is written here; the bytes go out when `poll_process` or the `Stream`
    /// implementation is polled. Calling this in any state other than `NotSent` logs
    /// an error and leaves the state untouched.
    pub fn send_handshake(&mut self, message: impl Into<Vec<u8>>) {
        if !matches!(self.handshake, NotificationsInSubstreamHandshake::NotSent) {
            error!(target: "sub-libp2p", "Tried to send handshake twice");
            return;
        }

        self.handshake = NotificationsInSubstreamHandshake::PendingSend(message.into());
    }

    /// Drives the sending of the handshake without ever reading from the socket.
    ///
    /// The success type is `Infallible`, so this only ever returns `Poll::Pending` or
    /// `Poll::Ready(Err(_))`. In the `NotSent`, `Sent`, `ClosingInResponseToRemote` and
    /// `BothSidesClosed` states it returns `Poll::Pending` without touching the socket.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the sink while it is being readied, fed or
    /// flushed. If readying the sink fails, the handshake is kept pending.
    pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Infallible, io::Error>> {
        let mut this = self.project();

        loop {
            // `Sent` is only a placeholder while the real state is inspected by value;
            // every arm that does not advance the state machine writes the state back.
            match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
                NotificationsInSubstreamHandshake::PendingSend(msg) =>
                    match Sink::poll_ready(this.socket.as_mut(), cx) {
                        Poll::Ready(Ok(())) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
                                Ok(()) => {},
                                Err(err) => return Poll::Ready(Err(err)),
                            }
                        },
                        // A sink that failed to become ready must not be fed; report
                        // the failure instead of silently discarding it.
                        Poll::Ready(Err(err)) => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Ready(Err(err))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Pending
                        }
                    },

                NotificationsInSubstreamHandshake::Flush =>
                    match Sink::poll_flush(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) =>
                            *this.handshake = NotificationsInSubstreamHandshake::Sent,
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            return Poll::Pending
                        }
                    },

                // Nothing to send in these states: restore the state and wait.
                st @ NotificationsInSubstreamHandshake::NotSent |
                st @ NotificationsInSubstreamHandshake::Sent |
                st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
                st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
                    *this.handshake = st;
                    return Poll::Pending;
                }
            }
        }
    }
}
impl<TSubstream> Stream for NotificationsInSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<BytesMut, io::Error>;

    /// Sends the pending handshake if there is one, then yields the frames received
    /// from the remote.
    ///
    /// While `send_handshake` has not been called this returns `Poll::Pending`
    /// without touching the socket. When the remote's side ends, the sink is closed
    /// and the stream then ends; once ended it keeps returning `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // `Sent` is only a placeholder while the real state is inspected by value.
            match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
                NotificationsInSubstreamHandshake::NotSent => {
                    *this.handshake = NotificationsInSubstreamHandshake::NotSent;
                    return Poll::Pending
                },

                NotificationsInSubstreamHandshake::PendingSend(msg) =>
                    match Sink::poll_ready(this.socket.as_mut(), cx) {
                        Poll::Ready(Ok(())) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
                                Ok(()) => {},
                                Err(err) => return Poll::Ready(Some(Err(err))),
                            }
                        },
                        // A sink that failed to become ready must not be fed; report
                        // the failure instead of silently discarding it.
                        Poll::Ready(Err(err)) => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Ready(Some(Err(err)))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Pending
                        }
                    },

                NotificationsInSubstreamHandshake::Flush =>
                    match Sink::poll_flush(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) =>
                            *this.handshake = NotificationsInSubstreamHandshake::Sent,
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            return Poll::Pending
                        }
                    },

                NotificationsInSubstreamHandshake::Sent => {
                    match Stream::poll_next(this.socket.as_mut(), cx) {
                        // The socket's stream ended: start closing our side as well.
                        Poll::Ready(None) => *this.handshake =
                            NotificationsInSubstreamHandshake::ClosingInResponseToRemote,
                        Poll::Ready(Some(msg)) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Sent;
                            return Poll::Ready(Some(msg))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Sent;
                            return Poll::Pending
                        },
                    }
                },

                NotificationsInSubstreamHandshake::ClosingInResponseToRemote =>
                    match Sink::poll_close(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) =>
                            *this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed,
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
                            return Poll::Pending
                        }
                    },

                NotificationsInSubstreamHandshake::BothSidesClosed => {
                    // Write the terminal state back so that it is not overwritten by
                    // the `Sent` placeholder; later polls keep reporting the end of
                    // the stream instead of reading from a closed socket.
                    *this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed;
                    return Poll::Ready(None)
                },
            }
        }
    }
}
impl NotificationsOut {
    /// Builds a new outbound upgrade.
    ///
    /// An `initial_message` longer than `MAX_HANDSHAKE_SIZE` is only reported through
    /// the log; the upgrade is still constructed with it.
    pub fn new(
        protocol_name: impl Into<Cow<'static, str>>,
        initial_message: impl Into<Vec<u8>>,
        max_notification_size: u64,
    ) -> Self {
        let initial_message: Vec<u8> = initial_message.into();
        let oversized = initial_message.len() > MAX_HANDSHAKE_SIZE;
        if oversized {
            error!(target: "sub-libp2p", "Outbound networking handshake is above allowed protocol limit");
        }

        let protocol_name = protocol_name.into();
        Self { protocol_name, initial_message, max_notification_size }
    }
}
impl UpgradeInfo for NotificationsOut {
    type Info = Cow<'static, [u8]>;
    type InfoIter = iter::Once<Self::Info>;

    /// Advertises exactly one protocol: the configured name, as bytes.
    fn protocol_info(&self) -> Self::InfoIter {
        // A borrowed name stays borrowed; an owned name is copied into a fresh buffer.
        iter::once(match &self.protocol_name {
            Cow::Borrowed(name) => Cow::Borrowed(name.as_bytes()),
            Cow::Owned(name) => Cow::Owned(name.clone().into_bytes()),
        })
    }
}
impl<TSubstream> OutboundUpgrade<TSubstream> for NotificationsOut
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = (Vec<u8>, NotificationsOutSubstream<TSubstream>);
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type Error = NotificationsHandshakeError;

    /// Writes the length-prefixed initial message, reads the remote's length-prefixed
    /// handshake, and wraps the socket into a [`NotificationsOutSubstream`].
    ///
    /// Fails with `TooLarge` when the announced handshake length exceeds
    /// `MAX_HANDSHAKE_SIZE`.
    fn upgrade_outbound(self, mut socket: TSubstream, _: Self::Info) -> Self::Future {
        Box::pin(async move {
            upgrade::write_with_len_prefix(&mut socket, &self.initial_message).await?;

            // The length is checked before any buffer is allocated for the payload.
            let reply_len = unsigned_varint::aio::read_usize(&mut socket).await?;
            if reply_len > MAX_HANDSHAKE_SIZE {
                return Err(NotificationsHandshakeError::TooLarge {
                    requested: reply_len,
                    max: MAX_HANDSHAKE_SIZE,
                });
            }

            let mut handshake = vec![0u8; reply_len];
            if reply_len != 0 {
                socket.read_exact(&mut handshake).await?;
            }

            // A limit that does not fit in `usize` is clamped to the largest `usize`.
            let max_len = usize::try_from(self.max_notification_size).unwrap_or(usize::MAX);
            let mut codec = UviBytes::default();
            codec.set_max_len(max_len);

            let substream = NotificationsOutSubstream {
                socket: Framed::new(socket, codec),
            };
            Ok((handshake, substream))
        })
    }
}
/// Every method forwards to the framed socket and wraps its I/O error into
/// `NotificationsOutError::Io`.
impl<TSubstream> Sink<Vec<u8>> for NotificationsOutSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    type Error = NotificationsOutError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().socket;
        Sink::poll_ready(socket, cx).map_err(NotificationsOutError::Io)
    }

    fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
        let socket = self.project().socket;
        // The codec's item type is a cursor over the bytes to write.
        let frame = io::Cursor::new(item);
        Sink::start_send(socket, frame).map_err(NotificationsOutError::Io)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().socket;
        Sink::poll_flush(socket, cx).map_err(NotificationsOutError::Io)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().socket;
        Sink::poll_close(socket, cx).map_err(NotificationsOutError::Io)
    }
}
/// Error returned by the inbound and outbound upgrades while the initial message
/// and the handshake are being exchanged.
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum NotificationsHandshakeError {
/// I/O error on the substream.
Io(io::Error),
/// The announced length of the initial message or handshake exceeds the limit.
#[display(fmt = "Initial message or handshake was too large: {}", requested)]
TooLarge {
/// Length that was announced.
requested: usize,
/// Largest length that is accepted.
max: usize,
},
/// The varint length prefix could not be decoded.
VarintDecode(unsigned_varint::decode::Error),
}
impl From<unsigned_varint::io::ReadError> for NotificationsHandshakeError {
    /// Maps the I/O and decoding variants one-to-one. Any other variant is logged
    /// and reported as an `InvalidData` I/O error.
    fn from(err: unsigned_varint::io::ReadError) -> Self {
        use unsigned_varint::io::ReadError;

        match err {
            ReadError::Io(io_err) => Self::Io(io_err),
            ReadError::Decode(decode_err) => Self::VarintDecode(decode_err),
            _ => {
                log::warn!("Unrecognized varint decoding error");
                Self::Io(io::ErrorKind::InvalidData.into())
            }
        }
    }
}
/// Error reported by the `Sink` implementation of [`NotificationsOutSubstream`].
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum NotificationsOutError {
/// I/O error on the substream.
Io(io::Error),
}
#[cfg(test)]
mod tests {
use super::{NotificationsIn, NotificationsOut};
use async_std::net::{TcpListener, TcpStream};
use futures::{prelude::*, channel::oneshot};
use libp2p::core::upgrade;
use std::borrow::Cow;
// Full round trip: the dialer sends an initial message, the listener answers with a
// handshake, then one notification travels from the dialer to the listener.
#[test]
fn basic_works() {
const PROTO_NAME: Cow<'static, str> = Cow::Borrowed("/test/proto/1");
// Carries the listener's bound address over to the client task.
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
// Dialing side, spawned as a separate task.
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let (handshake, mut substream) = upgrade::apply_outbound(
socket,
NotificationsOut::new(PROTO_NAME, &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1
).await.unwrap();
assert_eq!(handshake, b"hello world");
substream.send(b"test message".to_vec()).await.unwrap();
});
// Listening side, driven on the current thread.
async_std::task::block_on(async move {
// Port 0: any free port; the actual address is sent to the client.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let (initial_message, mut substream) = upgrade::apply_inbound(
socket,
NotificationsIn::new(PROTO_NAME, 1024 * 1024)
).await.unwrap();
assert_eq!(initial_message, b"initial message");
// The handshake is only queued here; polling the stream below sends it.
substream.send_handshake(&b"hello world"[..]);
let msg = substream.next().await.unwrap().unwrap();
assert_eq!(msg.as_ref(), b"test message");
});
// Wait for the client task to finish.
async_std::task::block_on(client);
}
// Same round trip as `basic_works`, but the initial message, the handshake and the
// notification are all empty.
#[test]
fn empty_handshake() {
const PROTO_NAME: Cow<'static, str> = Cow::Borrowed("/test/proto/1");
// Carries the listener's bound address over to the client task.
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
// Dialing side, spawned as a separate task.
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let (handshake, mut substream) = upgrade::apply_outbound(
socket,
NotificationsOut::new(PROTO_NAME, vec![], 1024 * 1024),
upgrade::Version::V1
).await.unwrap();
assert!(handshake.is_empty());
// `Default::default()` is an empty `Vec<u8>` here.
substream.send(Default::default()).await.unwrap();
});
// Listening side, driven on the current thread.
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let (initial_message, mut substream) = upgrade::apply_inbound(
socket,
NotificationsIn::new(PROTO_NAME, 1024 * 1024)
).await.unwrap();
assert!(initial_message.is_empty());
// The handshake is only queued here; polling the stream below sends it.
substream.send_handshake(vec![]);
let msg = substream.next().await.unwrap().unwrap();
assert!(msg.as_ref().is_empty());
});
// Wait for the client task to finish.
async_std::task::block_on(client);
}
// The listener drops the substream without ever sending a handshake; the dialer's
// upgrade is expected to fail.
#[test]
fn refused() {
const PROTO_NAME: Cow<'static, str> = Cow::Borrowed("/test/proto/1");
// Carries the listener's bound address over to the client task.
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
// Dialing side, spawned as a separate task.
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let outcome = upgrade::apply_outbound(
socket,
NotificationsOut::new(PROTO_NAME, &b"hello"[..], 1024 * 1024),
upgrade::Version::V1
).await;
assert!(outcome.is_err());
});
// Listening side, driven on the current thread.
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let (initial_msg, substream) = upgrade::apply_inbound(
socket,
NotificationsIn::new(PROTO_NAME, 1024 * 1024)
).await.unwrap();
assert_eq!(initial_msg, b"hello");
// Refuse by dropping: `send_handshake` is never called.
drop(substream);
});
// Wait for the client task to finish.
async_std::task::block_on(client);
}
// The dialer sends a 32768-byte initial message, well above `MAX_HANDSHAKE_SIZE`;
// both upgrades are expected to fail.
#[test]
fn large_initial_message_refused() {
const PROTO_NAME: Cow<'static, str> = Cow::Borrowed("/test/proto/1");
// Carries the listener's bound address over to the client task.
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
// Dialing side, spawned as a separate task.
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let ret = upgrade::apply_outbound(
socket,
// We check that an initial message that is too large gets refused.
NotificationsOut::new(PROTO_NAME, (0..32768).map(|_| 0).collect::<Vec<_>>(), 1024 * 1024),
upgrade::Version::V1
).await;
assert!(ret.is_err());
});
// Listening side, driven on the current thread.
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let ret = upgrade::apply_inbound(
socket,
NotificationsIn::new(PROTO_NAME, 1024 * 1024)
).await;
assert!(ret.is_err());
});
// Wait for the client task to finish.
async_std::task::block_on(client);
}
// The listener answers with a 32768-byte handshake, well above `MAX_HANDSHAKE_SIZE`;
// the dialer's upgrade is expected to fail.
#[test]
fn large_handshake_refused() {
const PROTO_NAME: Cow<'static, str> = Cow::Borrowed("/test/proto/1");
// Carries the listener's bound address over to the client task.
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
// Dialing side, spawned as a separate task.
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let ret = upgrade::apply_outbound(
socket,
NotificationsOut::new(PROTO_NAME, &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1
).await;
assert!(ret.is_err());
});
// Listening side, driven on the current thread.
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let (initial_message, mut substream) = upgrade::apply_inbound(
socket,
NotificationsIn::new(PROTO_NAME, 1024 * 1024)
).await.unwrap();
assert_eq!(initial_message, b"initial message");
// We check that a handshake that is too large gets refused.
substream.send_handshake((0..32768).map(|_| 0).collect::<Vec<_>>());
// Polling the stream is what pushes the queued handshake onto the socket; the
// outcome of the poll itself is deliberately ignored.
let _ = substream.next().await;
});
// Wait for the client task to finish.
async_std::task::block_on(client);
}
}