use std::{ collections::HashMap, net::SocketAddr, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, }; use anyhow::Error; use futures_util::{SinkExt, StreamExt}; use giterated_models::{ messages::error::ConnectionError, model::{ authenticated::{Authenticated, AuthenticatedPayload}, instance::Instance, }, }; use rsa::RsaPublicKey; use serde::Serialize; use serde_json::Value; use tokio::{ net::TcpStream, sync::{Mutex, RwLock}, }; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use crate::{ authentication::AuthenticationTokenGranter, backend::{RepositoryBackend, UserBackend}, connection::forwarded::wrap_forwarded, federation::connections::InstanceConnections, keys::PublicKeyCache, message::NetworkMessage, }; use super::{ authentication::authentication_handle, handshake::handshake_handle, repository::repository_handle, user::user_handle, Connections, }; pub async fn connection_wrapper( socket: WebSocketStream, connections: Arc>, repository_backend: Arc>, user_backend: Arc>, auth_granter: Arc>, addr: SocketAddr, instance: impl ToOwned, instance_connections: Arc>, ) { let connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), connections, repository_backend, user_backend, auth_granter, addr, instance: instance.to_owned(), handshaked: Arc::new(AtomicBool::new(false)), key_cache: Arc::default(), }; let mut handshaked = false; loop { let mut socket = connection_state.socket.lock().await; let message = socket.next().await; drop(socket); match message { Some(Ok(message)) => { let payload = match message { Message::Binary(payload) => payload, Message::Ping(_) => { let mut socket = connection_state.socket.lock().await; let _ = socket.send(Message::Pong(vec![])).await; drop(socket); continue; } Message::Close(_) => return, _ => continue, }; let message = NetworkMessage(payload.clone()); if !handshaked { info!("im foo baring"); if handshake_handle(&message, &connection_state).await.is_ok() { if connection_state.handshaked.load(Ordering::SeqCst) { handshaked = true; } } } else { let raw = serde_json::from_slice::(&payload).unwrap(); if let Some(target_instance) = &raw.target_instance { // Forward request info!("Forwarding message to {}", target_instance.url); let mut instance_connections = instance_connections.lock().await; let pool = instance_connections.get_or_open(&target_instance).unwrap(); let pool_clone = pool.clone(); drop(pool); let result = wrap_forwarded(&pool_clone, raw).await; let mut socket = connection_state.socket.lock().await; let _ = socket.send(result).await; continue; } let message_type = &raw.message_type; info!("Handling message with type: {}", message_type); match authentication_handle(message_type, &message, &connection_state).await { Err(e) => { let _ = connection_state .send_raw(ConnectionError(e.to_string())) .await; } Ok(true) => continue, Ok(false) => {} } match repository_handle(message_type, &message, &connection_state).await { Err(e) => { let _ = connection_state .send_raw(ConnectionError(e.to_string())) .await; } Ok(true) => continue, Ok(false) => {} } match user_handle(message_type, &message, &connection_state).await { Err(e) => { let _ = connection_state .send_raw(ConnectionError(e.to_string())) .await; } Ok(true) => continue, Ok(false) => {} } match authentication_handle(message_type, &message, &connection_state).await { Err(e) => { let _ = connection_state .send_raw(ConnectionError(e.to_string())) .await; } Ok(true) => continue, Ok(false) => {} } error!( "Message completely unhandled: {}", std::str::from_utf8(&payload).unwrap() ); } } Some(Err(e)) => { error!("Closing connection for {:?} for {}", e, addr); return; } _ => { info!("Unhandled"); continue; } } } } #[derive(Clone)] pub struct ConnectionState { socket: Arc>>, pub connections: Arc>, pub repository_backend: Arc>, pub user_backend: Arc>, pub auth_granter: Arc>, pub addr: SocketAddr, pub instance: Instance, pub handshaked: Arc, pub key_cache: Arc>, } impl ConnectionState { pub async fn send(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; info!("Sending payload: {}", &payload); self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn send_raw(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; info!("Sending payload: {}", &payload); self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn public_key(&self, instance: &Instance) -> Result { let mut keys = self.key_cache.lock().await; keys.get(instance).await } }