diff --git a/src/rusteriaServer/src/hysteria.rs b/src/rusteriaServer/src/hysteria.rs index e846f45..2da7dbf 100644 --- a/src/rusteriaServer/src/hysteria.rs +++ b/src/rusteriaServer/src/hysteria.rs @@ -9,6 +9,7 @@ use lib_rusteria::proto::HysteriaTcpRequest; use log::{debug, error, info, trace, warn}; use quiche::{Connection, Shutdown}; use std::collections::BTreeMap; +use std::ptr::read; use tokio::select; use tokio::sync::mpsc; use tokio_quiche::buf_factory::BufFactory; @@ -182,7 +183,7 @@ impl HysDriver { for frame in responses.response.iter() { sending_results.push(match frame { OutboundFrame::Headers(h) => { - + if self .h3conn_as_mut() .send_response(qconn, stream_id, h.as_ref(),responses.auth_res) @@ -236,6 +237,20 @@ impl HysDriver { // should send one of the HysEvent to the receiver fn handle_quic_request(&mut self, qconn: &mut Connection) -> QuicResult<()> { for stream_id in qconn.readable() { + if let Some(context) = self.quic_context_map.get_mut(&stream_id) { + if !context.read_fin && context.fin { + //discard all remaining info in the channel + if let Err(e) = qconn.stream_shutdown(stream_id, Shutdown::Read, 0) { + trace!( + "{} stream {} shutdown error: {}", + qconn.trace_id(), + stream_id, + e + ); + } + continue; + } + } let mut offset = 0; while qconn.stream_readable(stream_id) { //Fin signal will be handled at the end @@ -277,7 +292,7 @@ impl HysDriver { } } else { let (tx, rx) = mpsc::channel(MAX_BUF_SIZE); - let mut event: Option = None; + let mut events: Vec = Vec::new(); let mut read_data = std::mem::replace(&mut self.buffer, BufFactory::get_max_buf()); read_data.truncate(offset); //determine if this is a new proxy request or payload of an existing request @@ -285,7 +300,7 @@ impl HysDriver { // if self.quic_context_map.get_mut(&stream_id).is_none() { if self.quic_context_map.get_mut(&stream_id).is_none() { if let Some(req) = HysteriaTcpRequest::from_bytes(&**read_data) { - let _ = event.insert(HysEvent::QuicEvent( + let _ = events.push(HysEvent::QuicEvent( stream_id, ProxyEvent::Request(req.url, tx), )); @@ -297,40 +312,43 @@ impl HysDriver { read_fin: false, }, ); - let _ = read_data.split_at(offset); + let (_ ,temp) = read_data.split_at(req.offset); + read_data = BufFactory::buf_from_slice(temp); } else { error!( "client is sending invalid initial proxy request on stream {}", stream_id ); } + } + if !read_data.is_empty(){ + //have to clone + let inbound_bytes = Bytes::copy_from_slice(&read_data); + let _ = events.push(HysEvent::QuicEvent( + stream_id, + ProxyEvent::Payload(inbound_bytes), + )); } - else { - //have to clone - let inbound_bytes = Bytes::copy_from_slice(&read_data); - let _ = event.insert(HysEvent::QuicEvent( - stream_id, - ProxyEvent::Payload(inbound_bytes), - )); + + if(!events.is_empty()){ + self.waiting_streams + .push(WaitForStream::QuicStream(WaitForQuicStream { + stream_id, + chan: Some(rx), + })); } - trace!( + for event in events { + trace!( "{} sending event to the handler: {:?}", qconn.trace_id(), event ); - if event.is_some() { //TODO: investigate why client ignores the fin signal - if let Err(e) = self.event_sender.send(event.unwrap()) { + if let Err(e) = self.event_sender.send(event) { //TODO receiver closed error!("Failed to send event to the handler: {}", e); } } - - self.waiting_streams - .push(WaitForStream::QuicStream(WaitForQuicStream { - stream_id, - chan: Some(rx), - })); } } Ok(()) @@ -443,9 +461,9 @@ impl ApplicationOverQuic for HysDriver { //TODO: proper logging trace!("new unverified stream id: {}", stream_id); } - } + } if let Some(context) = self.quic_context_map.get_mut(&stream_id) { - // if !context.queued_bytes.is_empty() { + if !context.queued_bytes.is_empty() { debug!( "writing len {} bytes quic traffic to the client", context.queued_bytes.len() @@ -454,22 +472,10 @@ impl ApplicationOverQuic for HysDriver { "bytes content in ascii: {}", String::from_utf8_lossy(&context.queued_bytes) ); - if !context.read_fin && context.fin { - //discard all remaining info in the channel - if let Err(e) = qconn.stream_shutdown(stream_id, Shutdown::Read, 0) { - trace!( - "{} stream {} shutdown error: {}", - qconn.trace_id(), - stream_id, - e - ); - } - } let sent = qconn.stream_send(stream_id, &*context.queued_bytes, false)?; context.queued_bytes = context.queued_bytes.split_off(sent); context.queued_bytes.reserve(65535); - // } - if context.fin && context.queued_bytes.is_empty() { + }else if context.fin { if let Err(e) = qconn.stream_shutdown(stream_id, Shutdown::Write, 0) { error!( "{} stream {} shutdown error: {}", @@ -489,7 +495,6 @@ impl ApplicationOverQuic for HysDriver { self.quic_context_map.remove(&stream_id); info!("{} stream {} fin", qconn.trace_id(), stream_id); } - // } } } Ok(()) @@ -502,7 +507,7 @@ impl ApplicationOverQuic for HysDriver { connection_result: &QuicResult<()>, ){ self.process_writes(qconn).unwrap(); - + warn!("{}: quic connection closed", qconn.trace_id()); } }