Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/libtortillas/src/engine/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl Actor for EngineActor {
signal = mailbox_rx.recv() => signal,
peer_stream = self.tcp_socket.accept() => match peer_stream {
Ok((stream, _)) => {
let peer_stream = PeerStream::Tcp(stream);
let peer_stream = PeerStream::tcp(stream);

let Some(actor_ref) = actor_ref.upgrade() else {
error!("Failed to upgrade weak actor reference");
Expand All @@ -218,7 +218,7 @@ impl Actor for EngineActor {
},
peer_stream = self.utp_socket.accept() => match peer_stream {
Ok(stream) => {
let peer_stream = PeerStream::Utp(stream);
let peer_stream = PeerStream::utp(stream);

let Some(actor_ref) = actor_ref.upgrade() else {
error!("Failed to upgrade weak actor reference");
Expand Down
39 changes: 31 additions & 8 deletions crates/libtortillas/src/peer/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,16 @@ impl Actor for PeerActor {
// ourselves.
}

Ok(tokio::select! {
signal = mailbox_rx.recv() => signal,
msg = self.stream.recv() => self.check_message_signal(actor_ref, msg)
})
loop {
tokio::select! {
signal = mailbox_rx.recv() => return Ok(signal),
msg = self.stream.recv() => {
if let Some(signal) = self.check_message_signal(actor_ref.clone(), msg) {
return Ok(Some(signal));
}
}
}
}
}
}

Expand Down Expand Up @@ -646,17 +652,34 @@ pub(crate) mod commands {
}

#[message(derive(Clone, Debug))]
pub(crate) fn cancel_piece(&mut self, index: usize, begin: usize, length: usize) {
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
pub(crate) async fn cancel_piece(&mut self, index: usize, begin: usize, length: usize) {
if !self
.pending_block_requests
.contains(&(index, begin, length))
{
return; // Silently ignore if we don't have the request
}
// TODO: Refactor PeerStream to allow for cancelling requests
// This can't be done yet because it would require a refactor of PeerStream, for
// now we'll just ignore the request.

self.pending_block_requests.remove(&(index, begin, length));

if let Err(err) = self
.stream
.send(PeerMessages::Cancel(
index as u32,
begin as u32,
length as u32,
))
.await
{
warn!(
?err,
piece_index = index,
begin,
length,
"Failed to send cancel request"
);
}
}

#[message(derive(Clone, Debug))]
Expand Down
150 changes: 125 additions & 25 deletions crates/libtortillas/src/protocol/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{

use anyhow::Result;
use async_trait::async_trait;
use bytes::BytesMut;
use bytes::{Buf, BytesMut};
use librqbit_utp::{UtpSocketUdp, UtpStream, UtpStreamReadHalf, UtpStreamWriteHalf};
use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
Expand All @@ -30,8 +30,14 @@ use crate::{
/// possible to simply make a blanket function as it implements both [AsyncRead]
/// and [AsyncWrite]
pub enum PeerStream {
Tcp(TcpStream),
Utp(UtpStream),
Tcp {
stream: TcpStream,
read_buffer: BytesMut,
},
Utp {
stream: UtpStream,
read_buffer: BytesMut,
},
}

#[async_trait]
Expand Down Expand Up @@ -125,6 +131,20 @@ pub trait PeerRecv: AsyncRead + Unpin {
}

impl PeerStream {
pub fn tcp(stream: TcpStream) -> Self {
Self::Tcp {
stream,
read_buffer: BytesMut::new(),
}
}

pub fn utp(stream: UtpStream) -> Self {
Self::Utp {
stream,
read_buffer: BytesMut::new(),
}
}

/// Connect to a peer with the given peer_addr (ip & port in the form of a
/// [SocketAddr])
///
Expand All @@ -143,16 +163,16 @@ impl PeerStream {
tokio::select! {
stream = utp_socket.connect(peer_addr) => {
trace!(protocol = "uTP", "Connected to peer");
Ok(PeerStream::Utp(stream?))
Ok(PeerStream::utp(stream?))
},
stream = TcpStream::connect(peer_addr) => {
trace!(protocol = "TCP", "Connected to peer");
Ok(PeerStream::Tcp(stream?))
Ok(PeerStream::tcp(stream?))
}
}
} else {
trace!(protocol = "TCP", "Connecting to peer");
Ok(PeerStream::Tcp(TcpStream::connect(peer_addr).await?))
Ok(PeerStream::tcp(TcpStream::connect(peer_addr).await?))
}
}

Expand Down Expand Up @@ -196,19 +216,38 @@ impl PeerStream {
/// Returns the addr of the connected peer
pub fn remote_addr(&self) -> Result<SocketAddr> {
match self {
PeerStream::Tcp(s) => Ok(s.peer_addr()?),
PeerStream::Utp(s) => Ok(s.remote_addr()),
PeerStream::Tcp { stream, .. } => Ok(stream.peer_addr()?),
PeerStream::Utp { stream, .. } => Ok(stream.remote_addr()),
}
}

/// Splits the PeerStream into separate reader and writer halves
/// Splits the PeerStream into separate reader and writer halves.
///
/// Panics if `read_buffer` contains bytes buffered by `PeerRecv::recv()`.
/// Callers must split before using buffered reads, or ensure
/// `recv_handshake_message()` and other direct reads did not leave data for
/// `PeerRecv::recv()` to process.
pub fn split(self) -> (PeerReader, PeerWriter) {
match self {
PeerStream::Tcp(stream) => {
PeerStream::Tcp {
stream,
read_buffer,
} => {
assert!(
read_buffer.is_empty(),
"PeerStream::split would discard buffered read data"
);
let (reader, writer) = stream.into_split();
(PeerReader::Tcp(reader), PeerWriter::Tcp(writer))
}
PeerStream::Utp(stream) => {
PeerStream::Utp {
stream,
read_buffer,
} => {
assert!(
read_buffer.is_empty(),
"PeerStream::split would discard buffered read data"
);
let (reader, writer) = stream.split();
(PeerReader::Utp(reader), PeerWriter::Utp(writer))
}
Expand All @@ -217,8 +256,8 @@ impl PeerStream {

pub fn protocol(&self) -> String {
match self {
PeerStream::Tcp(_) => "TCP".to_string(),
PeerStream::Utp(_) => "uTP".to_string(),
PeerStream::Tcp { .. } => "TCP".to_string(),
PeerStream::Utp { .. } => "uTP".to_string(),
}
}
}
Expand All @@ -233,15 +272,76 @@ impl Display for PeerStream {
}

impl PeerSend for PeerStream {}
impl PeerRecv for PeerStream {}
#[async_trait]
impl PeerRecv for PeerStream {
async fn recv(&mut self) -> Result<PeerMessages, PeerActorError> {
loop {
match self {
PeerStream::Tcp {
stream,
read_buffer,
} => {
if let Some(message) = buffered_message(read_buffer) {
return message;
}
if stream.read_buf(read_buffer).await? == 0 {
return Err(PeerActorError::ReceiveFailed(io::Error::new(
io::ErrorKind::UnexpectedEof,
"peer closed connection",
)));
}
}
PeerStream::Utp {
stream,
read_buffer,
} => {
if let Some(message) = buffered_message(read_buffer) {
return message;
}
if stream.read_buf(read_buffer).await? == 0 {
return Err(PeerActorError::ReceiveFailed(io::Error::new(
io::ErrorKind::UnexpectedEof,
"peer closed connection",
)));
}
}
}
}
}
}

fn buffered_message(read_buffer: &mut BytesMut) -> Option<Result<PeerMessages, PeerActorError>> {
if read_buffer.len() < 4 {
return None;
}

let length = u32::from_be_bytes(
read_buffer[..4]
.try_into()
.expect("slice has exactly 4 bytes"),
) as usize;
let frame_len = 4 + length;

if length == 0 {
read_buffer.advance(4);
return Some(Ok(PeerMessages::KeepAlive));
}

if read_buffer.len() < frame_len {
return None;
}

let frame = read_buffer.split_to(frame_len).freeze();
Some(PeerMessages::from_bytes(frame))
}

impl AsyncRead for PeerStream {
fn poll_read(
mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
PeerStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
PeerStream::Utp(s) => Pin::new(s).poll_read(cx, buf),
PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_read(cx, buf),
PeerStream::Utp { stream, .. } => Pin::new(stream).poll_read(cx, buf),
}
}
}
Expand All @@ -251,22 +351,22 @@ impl AsyncWrite for PeerStream {
mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match &mut *self {
PeerStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
PeerStream::Utp(s) => Pin::new(s).poll_write(cx, buf),
PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_write(cx, buf),
PeerStream::Utp { stream, .. } => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match &mut *self {
PeerStream::Tcp(s) => Pin::new(s).poll_flush(cx),
PeerStream::Utp(s) => Pin::new(s).poll_flush(cx),
PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_flush(cx),
PeerStream::Utp { stream, .. } => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match &mut *self {
PeerStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
PeerStream::Utp(s) => Pin::new(s).poll_shutdown(cx),
PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_shutdown(cx),
PeerStream::Utp { stream, .. } => Pin::new(stream).poll_shutdown(cx),
}
}
}
Expand Down Expand Up @@ -392,7 +492,7 @@ mod tests {
// Spawn client that sends handshake
let client_info_hash = info_hash.clone();
let client = tokio::spawn(async move {
let mut stream = PeerStream::Tcp(TcpStream::connect(addr).await.unwrap());
let mut stream = PeerStream::tcp(TcpStream::connect(addr).await.unwrap());

stream
.send_handshake(client_id, client_info_hash)
Expand All @@ -405,7 +505,7 @@ mod tests {
.await
.expect("client should connect before timeout")
.unwrap();
let mut peer_stream = PeerStream::Tcp(stream);
let mut peer_stream = PeerStream::tcp(stream);

let (incoming_id, _) = timeout(Duration::from_secs(1), peer_stream.recv_handshake())
.await
Expand Down Expand Up @@ -435,7 +535,7 @@ mod tests {
.await
.expect("client should connect before timeout")
.unwrap();
let mut peer_stream = PeerStream::Tcp(stream);
let mut peer_stream = PeerStream::tcp(stream);

let received_handshake =
timeout(Duration::from_secs(1), peer_stream.recv_handshake_message())
Expand Down
2 changes: 1 addition & 1 deletion crates/libtortillas/src/torrent/swarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl TorrentActor {
}
}
Err(err) => {
trace!(error = %err, "Failed to connect to peer; exiting");
trace!(error = %err, peer_addr = %peer.socket_addr(), "Failed to connect to peer; exiting");
return;
}
}
Expand Down
Loading