use std::{ net::SocketAddr, sync::{atomic::AtomicBool, Arc}, }; use anyhow::Error; use futures_util::{SinkExt, StreamExt}; use giterated_models::{error::OperationError, instance::Instance}; use giterated_models::object_backend::ObjectBackend; use giterated_models::{ authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject, operation::AnyOperation, }; use giterated_stack::{handler::GiteratedBackend, StackOperationState}; use serde::Serialize; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use toml::Table; use crate::{ authentication::AuthenticationTokenGranter, backend::{MetadataBackend, RepositoryBackend, UserBackend}, database_backend::DatabaseBackend, federation::connections::InstanceConnections, keys::PublicKeyCache, }; use super::Connections; pub async fn connection_wrapper( socket: WebSocketStream, connections: Arc>, repository_backend: Arc>, user_backend: Arc>, auth_granter: Arc>, settings_backend: Arc>, addr: SocketAddr, instance: impl ToOwned, instance_connections: Arc>, config: Table, backend: GiteratedBackend, operation_state: StackOperationState, ) { let connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), connections, repository_backend, user_backend, auth_granter, settings_backend, addr, instance: instance.to_owned(), handshaked: Arc::new(AtomicBool::new(false)), key_cache: Arc::default(), instance_connections: instance_connections.clone(), config, }; let _handshaked = false; loop { let mut socket = connection_state.socket.lock().await; let message = socket.next().await; drop(socket); match message { Some(Ok(message)) => { let payload = match message { Message::Binary(payload) => payload, Message::Ping(_) => { let mut socket = connection_state.socket.lock().await; let _ = socket.send(Message::Pong(vec![])).await; drop(socket); continue; } Message::Close(_) => return, _ => continue, }; let message: AuthenticatedPayload = bincode::deserialize(&payload).unwrap(); let message: GiteratedMessage = message.into_message(); let result = backend .object_operation( message.object, &message.operation, message.payload, &operation_state, ) .await; // Map result to Vec on both let result = match result { Ok(result) => Ok(serde_json::to_vec(&result).unwrap()), Err(err) => Err(match err { OperationError::Operation(err) => { OperationError::Operation(serde_json::to_vec(&err).unwrap()) } OperationError::Internal(err) => OperationError::Internal(err), OperationError::Unhandled => OperationError::Unhandled, }), }; let mut socket = connection_state.socket.lock().await; let _ = socket .send(Message::Binary(bincode::serialize(&result).unwrap())) .await; drop(socket); } _ => { return; } } } // loop { // let mut socket = connection_state.socket.lock().await; // let message = socket.next().await; // drop(socket); // match message { // Some(Ok(message)) => { // let payload = match message { // Message::Binary(payload) => payload, // Message::Ping(_) => { // let mut socket = connection_state.socket.lock().await; // let _ = socket.send(Message::Pong(vec![])).await; // drop(socket); // continue; // } // Message::Close(_) => return, // _ => continue, // }; // let message = NetworkMessage(payload.clone()); // if !handshaked { // if handshake_handle(&message, &connection_state).await.is_ok() { // if connection_state.handshaked.load(Ordering::SeqCst) { // handshaked = true; // } // } // } else { // let raw = serde_json::from_slice::(&payload).unwrap(); // if let Some(target_instance) = &raw.target_instance { // if connection_state.instance != *target_instance { // // Forward request // info!("Forwarding message to {}", target_instance.url); // let mut instance_connections = instance_connections.lock().await; // let pool = instance_connections.get_or_open(&target_instance).unwrap(); // let pool_clone = pool.clone(); // drop(pool); // let result = wrap_forwarded(&pool_clone, raw).await; // let mut socket = connection_state.socket.lock().await; // let _ = socket.send(result).await; // continue; // } // } // let message_type = &raw.message_type; // match authentication_handle(message_type, &message, &connection_state).await { // Err(e) => { // let _ = connection_state // .send_raw(ConnectionError(e.to_string())) // .await; // } // Ok(true) => continue, // Ok(false) => {} // } // match repository_handle(message_type, &message, &connection_state).await { // Err(e) => { // let _ = connection_state // .send_raw(ConnectionError(e.to_string())) // .await; // } // Ok(true) => continue, // Ok(false) => {} // } // match user_handle(message_type, &message, &connection_state).await { // Err(e) => { // let _ = connection_state // .send_raw(ConnectionError(e.to_string())) // .await; // } // Ok(true) => continue, // Ok(false) => {} // } // match authentication_handle(message_type, &message, &connection_state).await { // Err(e) => { // let _ = connection_state // .send_raw(ConnectionError(e.to_string())) // .await; // } // Ok(true) => continue, // Ok(false) => {} // } // error!( // "Message completely unhandled: {}", // std::str::from_utf8(&payload).unwrap() // ); // } // } // Some(Err(e)) => { // error!("Closing connection for {:?} for {}", e, addr); // return; // } // _ => { // continue; // } // } // } } #[derive(Clone)] pub struct ConnectionState { socket: Arc>>, pub connections: Arc>, pub repository_backend: Arc>, pub user_backend: Arc>, pub auth_granter: Arc>, pub settings_backend: Arc>, pub addr: SocketAddr, pub instance: Instance, pub handshaked: Arc, pub key_cache: Arc>, pub instance_connections: Arc>, pub config: Table, } impl ConnectionState { pub async fn send(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn send_raw(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn public_key(&self, instance: &Instance) -> Result { let mut keys = self.key_cache.lock().await; keys.get(instance).await } }