use std::{ net::SocketAddr, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, }; use anyhow::Error; use futures_util::{SinkExt, StreamExt}; use serde::Serialize; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use crate::{ authentication::AuthenticationTokenGranter, backend::{DiscoveryBackend, RepositoryBackend, UserBackend}, connection::ConnectionError, listener::Listeners, model::{authenticated::NetworkMessage, instance::Instance}, }; use super::{ authentication::authentication_handle, connection_worker, handshake::handshake_handle, repository::repository_handle, user::user_handle, Connections, }; pub async fn connection_wrapper( mut socket: WebSocketStream, listeners: Arc>, connections: Arc>, repository_backend: Arc>, user_backend: Arc>, auth_granter: Arc>, discovery_backend: Arc>, addr: SocketAddr, instance: impl ToOwned, ) { let mut connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), listeners, connections, repository_backend, user_backend, auth_granter, discovery_backend, addr, instance: instance.to_owned(), handshaked: Arc::new(AtomicBool::new(false)), }; let mut handshaked = false; loop { let mut socket = connection_state.socket.lock().await; let message = socket.next().await; drop(socket); match message { Some(Ok(message)) => { let payload = match message { Message::Binary(payload) => payload, Message::Ping(_) => { let mut socket = connection_state.socket.lock().await; socket.send(Message::Pong(vec![])).await; drop(socket); continue; } Message::Close(_) => return, _ => continue, }; let message = NetworkMessage(payload); if !handshaked { if handshake_handle(&message, &connection_state).await.is_ok() { if connection_state.handshaked.load(Ordering::SeqCst) { handshaked = true; } } } else { if authentication_handle(&message, &connection_state) .await .is_ok() { continue; } else if repository_handle(&message, &connection_state).await.is_ok() { continue; } else if user_handle(&message, &connection_state).await.is_ok() { continue; } else { error!("Message completely unhandled"); continue; } } } _ => { error!("Closing connection for {}", addr); return; } } } } #[derive(Clone)] pub struct ConnectionState { socket: Arc>>, pub listeners: Arc>, pub connections: Arc>, pub repository_backend: Arc>, pub user_backend: Arc>, pub auth_granter: Arc>, pub discovery_backend: Arc>, pub addr: SocketAddr, pub instance: Instance, pub handshaked: Arc, } impl ConnectionState { pub async fn send(&self, message: T) -> Result<(), Error> { self.socket .lock() .await .send(Message::Binary(serde_json::to_vec(&message)?)) .await?; Ok(()) } }