diff --git a/Cargo.lock b/Cargo.lock index 1107fa0..5fc3dff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,7 +629,7 @@ dependencies = [ [[package]] name = "giterated-daemon" -version = "0.0.5" +version = "0.0.6" dependencies = [ "aes-gcm", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 838bcb5..2522895 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "giterated-daemon" -version = "0.0.5" +version = "0.0.6" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/connection/wrapper.rs b/src/connection/wrapper.rs index 0e44212..a78830a 100644 --- a/src/connection/wrapper.rs +++ b/src/connection/wrapper.rs @@ -1,10 +1,13 @@ use std::{ net::SocketAddr, - sync::{atomic::AtomicBool, Arc}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; use anyhow::Error; -use futures_util::SinkExt; +use futures_util::{SinkExt, StreamExt}; use serde::Serialize; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; @@ -14,10 +17,13 @@ use crate::{ backend::{DiscoveryBackend, RepositoryBackend, UserBackend}, connection::ConnectionError, listener::Listeners, - model::instance::Instance, + model::{authenticated::NetworkMessage, instance::Instance}, }; -use super::{connection_worker, Connections}; +use super::{ + authentication::authentication_handle, connection_worker, handshake::handshake_handle, + repository::repository_handle, user::user_handle, Connections, +}; pub async fn connection_wrapper( mut socket: WebSocketStream, @@ -28,26 +34,68 @@ pub async fn connection_wrapper( auth_granter: Arc>, discovery_backend: Arc>, addr: SocketAddr, + instance: impl ToOwned, ) { + let mut connection_state = ConnectionState { + socket: Arc::new(Mutex::new(socket)), + listeners, + connections, + repository_backend, + user_backend, + auth_granter, + discovery_backend, + addr, + instance: instance.to_owned(), + handshaked: Arc::new(AtomicBool::new(false)), + }; + let mut handshaked = false; + loop { - if let Err(e) = connection_worker( - &mut socket, - &mut handshaked, - &listeners, - &connections, - &repository_backend, - &user_backend, - &auth_granter, - &discovery_backend, - &addr, - ) - .await - { - error!("Error handling message: {:?}", e); + let mut socket = connection_state.socket.lock().await; + let message = socket.next().await; + drop(socket); - if let ConnectionError::Shutdown = &e { - info!("Closing connection {}", addr); + match message { + Some(Ok(message)) => { + let payload = match message { + Message::Binary(payload) => payload, + Message::Ping(_) => { + let mut socket = connection_state.socket.lock().await; + socket.send(Message::Pong(vec![])).await; + drop(socket); + continue; + } + Message::Close(_) => return, + _ => continue, + }; + + let message = NetworkMessage(payload); + + if !handshaked { + if handshake_handle(&message, &connection_state).await.is_ok() { + if connection_state.handshaked.load(Ordering::SeqCst) { + handshaked = true; + } + } + } else { + if authentication_handle(&message, &connection_state) + .await + .is_ok() + { + continue; + } else if repository_handle(&message, &connection_state).await.is_ok() { + continue; + } else if user_handle(&message, &connection_state).await.is_ok() { + continue; + } else { + error!("Message completely unhandled"); + continue; + } + } + } + _ => { + error!("Closing connection for {}", addr); return; } } diff --git a/src/lib.rs b/src/lib.rs index 74a8a25..bc3ddf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,8 +10,6 @@ pub mod listener; pub mod messages; pub mod model; -pub(crate) use std::error::Error as StdError; - #[macro_use] extern crate tracing; @@ -20,7 +18,7 @@ pub fn version() -> Version { } pub fn validate_version(other: &Version) -> bool { - let version_req = VersionReq::from_str("=0.0.5").unwrap(); + let version_req = VersionReq::from_str("=0.0.6").unwrap(); version_req.matches(other) } diff --git a/src/main.rs b/src/main.rs index 648fe67..4f84048 100644 --- a/src/main.rs +++ b/src/main.rs @@ -118,6 +118,7 @@ async fn main() -> Result<(), Error> { token_granter.clone(), discovery_backend.clone(), address, + Instance::from_str("giterated.dev").unwrap(), )), };