diff --git a/crates/libtortillas/src/engine/actor.rs b/crates/libtortillas/src/engine/actor.rs index dadab926..7c4b533c 100644 --- a/crates/libtortillas/src/engine/actor.rs +++ b/crates/libtortillas/src/engine/actor.rs @@ -195,7 +195,7 @@ impl Actor for EngineActor { signal = mailbox_rx.recv() => signal, peer_stream = self.tcp_socket.accept() => match peer_stream { Ok((stream, _)) => { - let peer_stream = PeerStream::Tcp(stream); + let peer_stream = PeerStream::tcp(stream); let Some(actor_ref) = actor_ref.upgrade() else { error!("Failed to upgrade weak actor reference"); @@ -218,7 +218,7 @@ impl Actor for EngineActor { }, peer_stream = self.utp_socket.accept() => match peer_stream { Ok(stream) => { - let peer_stream = PeerStream::Utp(stream); + let peer_stream = PeerStream::utp(stream); let Some(actor_ref) = actor_ref.upgrade() else { error!("Failed to upgrade weak actor reference"); diff --git a/crates/libtortillas/src/peer/actor.rs b/crates/libtortillas/src/peer/actor.rs index 7ef92769..d543df11 100644 --- a/crates/libtortillas/src/peer/actor.rs +++ b/crates/libtortillas/src/peer/actor.rs @@ -402,10 +402,16 @@ impl Actor for PeerActor { // ourselves. } - Ok(tokio::select! { - signal = mailbox_rx.recv() => signal, - msg = self.stream.recv() => self.check_message_signal(actor_ref, msg) - }) + loop { + tokio::select! { + signal = mailbox_rx.recv() => return Ok(signal), + msg = self.stream.recv() => { + if let Some(signal) = self.check_message_signal(actor_ref.clone(), msg) { + return Ok(Some(signal)); + } + } + } + } } } @@ -646,17 +652,34 @@ pub(crate) mod commands { } #[message(derive(Clone, Debug))] - pub(crate) fn cancel_piece(&mut self, index: usize, begin: usize, length: usize) { + #[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))] + pub(crate) async fn cancel_piece(&mut self, index: usize, begin: usize, length: usize) { if !self .pending_block_requests .contains(&(index, begin, length)) { return; // Silently ignore if we don't have the request } - // TODO: Refactor PeerStream to allow for cancelling requests - // This can't be done yet because it would require a refactor of PeerStream, for - // now we'll just ignore the request. + self.pending_block_requests.remove(&(index, begin, length)); + + if let Err(err) = self + .stream + .send(PeerMessages::Cancel( + index as u32, + begin as u32, + length as u32, + )) + .await + { + warn!( + ?err, + piece_index = index, + begin, + length, + "Failed to send cancel request" + ); + } } #[message(derive(Clone, Debug))] diff --git a/crates/libtortillas/src/protocol/stream.rs b/crates/libtortillas/src/protocol/stream.rs index 15fdbd30..0ac32220 100644 --- a/crates/libtortillas/src/protocol/stream.rs +++ b/crates/libtortillas/src/protocol/stream.rs @@ -9,7 +9,7 @@ use std::{ use anyhow::Result; use async_trait::async_trait; -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use librqbit_utp::{UtpSocketUdp, UtpStream, UtpStreamReadHalf, UtpStreamWriteHalf}; use tokio::{ io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, @@ -30,8 +30,14 @@ use crate::{ /// possible to simply make a blanket function as it implements both [AsyncRead] /// and [AsyncWrite] pub enum PeerStream { - Tcp(TcpStream), - Utp(UtpStream), + Tcp { + stream: TcpStream, + read_buffer: BytesMut, + }, + Utp { + stream: UtpStream, + read_buffer: BytesMut, + }, } #[async_trait] @@ -125,6 +131,20 @@ pub trait PeerRecv: AsyncRead + Unpin { } impl PeerStream { + pub fn tcp(stream: TcpStream) -> Self { + Self::Tcp { + stream, + read_buffer: BytesMut::new(), + } + } + + pub fn utp(stream: UtpStream) -> Self { + Self::Utp { + stream, + read_buffer: BytesMut::new(), + } + } + /// Connect to a peer with the given peer_addr (ip & port in the form of a /// [SocketAddr]) /// @@ -143,16 +163,16 @@ impl PeerStream { tokio::select! { stream = utp_socket.connect(peer_addr) => { trace!(protocol = "uTP", "Connected to peer"); - Ok(PeerStream::Utp(stream?)) + Ok(PeerStream::utp(stream?)) }, stream = TcpStream::connect(peer_addr) => { trace!(protocol = "TCP", "Connected to peer"); - Ok(PeerStream::Tcp(stream?)) + Ok(PeerStream::tcp(stream?)) } } } else { trace!(protocol = "TCP", "Connecting to peer"); - Ok(PeerStream::Tcp(TcpStream::connect(peer_addr).await?)) + Ok(PeerStream::tcp(TcpStream::connect(peer_addr).await?)) } } @@ -196,19 +216,38 @@ impl PeerStream { /// Returns the addr of the connected peer pub fn remote_addr(&self) -> Result { match self { - PeerStream::Tcp(s) => Ok(s.peer_addr()?), - PeerStream::Utp(s) => Ok(s.remote_addr()), + PeerStream::Tcp { stream, .. } => Ok(stream.peer_addr()?), + PeerStream::Utp { stream, .. } => Ok(stream.remote_addr()), } } - /// Splits the PeerStream into separate reader and writer halves + /// Splits the PeerStream into separate reader and writer halves. + /// + /// Panics if `read_buffer` contains bytes buffered by `PeerRecv::recv()`. + /// Callers must split before using buffered reads, or ensure + /// `recv_handshake_message()` and other direct reads did not leave data for + /// `PeerRecv::recv()` to process. pub fn split(self) -> (PeerReader, PeerWriter) { match self { - PeerStream::Tcp(stream) => { + PeerStream::Tcp { + stream, + read_buffer, + } => { + assert!( + read_buffer.is_empty(), + "PeerStream::split would discard buffered read data" + ); let (reader, writer) = stream.into_split(); (PeerReader::Tcp(reader), PeerWriter::Tcp(writer)) } - PeerStream::Utp(stream) => { + PeerStream::Utp { + stream, + read_buffer, + } => { + assert!( + read_buffer.is_empty(), + "PeerStream::split would discard buffered read data" + ); let (reader, writer) = stream.split(); (PeerReader::Utp(reader), PeerWriter::Utp(writer)) } @@ -217,8 +256,8 @@ impl PeerStream { pub fn protocol(&self) -> String { match self { - PeerStream::Tcp(_) => "TCP".to_string(), - PeerStream::Utp(_) => "uTP".to_string(), + PeerStream::Tcp { .. } => "TCP".to_string(), + PeerStream::Utp { .. } => "uTP".to_string(), } } } @@ -233,15 +272,76 @@ impl Display for PeerStream { } impl PeerSend for PeerStream {} -impl PeerRecv for PeerStream {} +#[async_trait] +impl PeerRecv for PeerStream { + async fn recv(&mut self) -> Result { + loop { + match self { + PeerStream::Tcp { + stream, + read_buffer, + } => { + if let Some(message) = buffered_message(read_buffer) { + return message; + } + if stream.read_buf(read_buffer).await? == 0 { + return Err(PeerActorError::ReceiveFailed(io::Error::new( + io::ErrorKind::UnexpectedEof, + "peer closed connection", + ))); + } + } + PeerStream::Utp { + stream, + read_buffer, + } => { + if let Some(message) = buffered_message(read_buffer) { + return message; + } + if stream.read_buf(read_buffer).await? == 0 { + return Err(PeerActorError::ReceiveFailed(io::Error::new( + io::ErrorKind::UnexpectedEof, + "peer closed connection", + ))); + } + } + } + } + } +} + +fn buffered_message(read_buffer: &mut BytesMut) -> Option> { + if read_buffer.len() < 4 { + return None; + } + + let length = u32::from_be_bytes( + read_buffer[..4] + .try_into() + .expect("slice has exactly 4 bytes"), + ) as usize; + let frame_len = 4 + length; + + if length == 0 { + read_buffer.advance(4); + return Some(Ok(PeerMessages::KeepAlive)); + } + + if read_buffer.len() < frame_len { + return None; + } + + let frame = read_buffer.split_to(frame_len).freeze(); + Some(PeerMessages::from_bytes(frame)) +} impl AsyncRead for PeerStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match &mut *self { - PeerStream::Tcp(s) => Pin::new(s).poll_read(cx, buf), - PeerStream::Utp(s) => Pin::new(s).poll_read(cx, buf), + PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_read(cx, buf), + PeerStream::Utp { stream, .. } => Pin::new(stream).poll_read(cx, buf), } } } @@ -251,22 +351,22 @@ impl AsyncWrite for PeerStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match &mut *self { - PeerStream::Tcp(s) => Pin::new(s).poll_write(cx, buf), - PeerStream::Utp(s) => Pin::new(s).poll_write(cx, buf), + PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_write(cx, buf), + PeerStream::Utp { stream, .. } => Pin::new(stream).poll_write(cx, buf), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { - PeerStream::Tcp(s) => Pin::new(s).poll_flush(cx), - PeerStream::Utp(s) => Pin::new(s).poll_flush(cx), + PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_flush(cx), + PeerStream::Utp { stream, .. } => Pin::new(stream).poll_flush(cx), } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { - PeerStream::Tcp(s) => Pin::new(s).poll_shutdown(cx), - PeerStream::Utp(s) => Pin::new(s).poll_shutdown(cx), + PeerStream::Tcp { stream, .. } => Pin::new(stream).poll_shutdown(cx), + PeerStream::Utp { stream, .. } => Pin::new(stream).poll_shutdown(cx), } } } @@ -392,7 +492,7 @@ mod tests { // Spawn client that sends handshake let client_info_hash = info_hash.clone(); let client = tokio::spawn(async move { - let mut stream = PeerStream::Tcp(TcpStream::connect(addr).await.unwrap()); + let mut stream = PeerStream::tcp(TcpStream::connect(addr).await.unwrap()); stream .send_handshake(client_id, client_info_hash) @@ -405,7 +505,7 @@ mod tests { .await .expect("client should connect before timeout") .unwrap(); - let mut peer_stream = PeerStream::Tcp(stream); + let mut peer_stream = PeerStream::tcp(stream); let (incoming_id, _) = timeout(Duration::from_secs(1), peer_stream.recv_handshake()) .await @@ -435,7 +535,7 @@ mod tests { .await .expect("client should connect before timeout") .unwrap(); - let mut peer_stream = PeerStream::Tcp(stream); + let mut peer_stream = PeerStream::tcp(stream); let received_handshake = timeout(Duration::from_secs(1), peer_stream.recv_handshake_message()) diff --git a/crates/libtortillas/src/torrent/swarm.rs b/crates/libtortillas/src/torrent/swarm.rs index 6dae3148..9e85c758 100644 --- a/crates/libtortillas/src/torrent/swarm.rs +++ b/crates/libtortillas/src/torrent/swarm.rs @@ -65,7 +65,7 @@ impl TorrentActor { } } Err(err) => { - trace!(error = %err, "Failed to connect to peer; exiting"); + trace!(error = %err, peer_addr = %peer.socket_addr(), "Failed to connect to peer; exiting"); return; } }