use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use futures_util::{stream::StreamExt, SinkExt}; use tokio::{ net::TcpStream, sync::{ broadcast::{Receiver, Sender}, Mutex, }, task::JoinHandle, }; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use crate::{ authentication::AuthenticationTokenGranter, backend::{IssuesBackend, RepositoryBackend}, handshake::{HandshakeFinalize, HandshakeMessage, HandshakeResponse}, listener::Listeners, messages::{ authentication::{AuthenticationMessage, AuthenticationRequest, TokenExtensionResponse}, repository::{ RepositoryMessage, RepositoryMessageKind, RepositoryRequest, RepositoryResponse, }, MessageKind, }, model::{ instance::{Instance, InstanceMeta}, repository::Repository, user::User, }, }; pub struct RawConnection { pub task: JoinHandle<()>, } pub struct InstanceConnection { pub instance: InstanceMeta, pub sender: Sender, pub task: JoinHandle<()>, } /// Represents a connection which hasn't finished the handshake. pub struct UnestablishedConnection { pub socket: WebSocketStream, } #[derive(Default)] pub struct Connections { pub connections: Vec, pub instance_connections: HashMap, } pub async fn connection_worker( mut socket: WebSocketStream, listeners: Arc>, connections: Arc>, backend: Arc>, auth_granter: Arc>, addr: SocketAddr, ) { let mut handshaked = false; let this_instance = Instance { url: String::from("giterated.dev"), }; while let Some(message) = socket.next().await { let message = match message { Ok(message) => message, Err(err) => { error!("Error reading message: {:?}", err); continue; } }; let payload = match message { Message::Text(text) => text.into_bytes(), Message::Binary(bytes) => bytes, Message::Ping(_) => continue, Message::Pong(_) => continue, Message::Close(_) => { info!("Closing connection with {}.", addr); return; } _ => unreachable!(), }; let message = match serde_json::from_slice::(&payload) { Ok(message) => message, Err(err) => { error!("Error deserializing message from {}: {:?}", addr, err); continue; } }; info!("Read payload: {}", std::str::from_utf8(&payload).unwrap()); if let MessageKind::Handshake(handshake) = message { match handshake { HandshakeMessage::Initiate(_) => { // Send HandshakeMessage::Response let message = HandshakeResponse { identity: Instance { url: String::from("foo.com"), }, version: String::from("0.1.0"), }; socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Handshake( HandshakeMessage::Response(message), )) .unwrap(), )) .await .unwrap(); continue; } HandshakeMessage::Response(_) => { // Send HandshakeMessage::Finalize let message = HandshakeFinalize { success: true }; socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Handshake( HandshakeMessage::Finalize(message), )) .unwrap(), )) .await .unwrap(); continue; } HandshakeMessage::Finalize(_) => { handshaked = true; // Send HandshakeMessage::Finalize let message = HandshakeFinalize { success: true }; socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Handshake( HandshakeMessage::Finalize(message), )) .unwrap(), )) .await .unwrap(); continue; } } } if !handshaked { continue; } if let MessageKind::Repository(repository) = &message { if repository.target.instance != this_instance { info!("Forwarding command to {}", repository.target.instance.url); // We need to send this command to a different instance let mut listener = send_and_get_listener(message, &listeners, &connections).await; // Wait for response while let Ok(message) = listener.recv().await { if let MessageKind::Repository(RepositoryMessage { command: RepositoryMessageKind::Response(_), .. }) = message { socket .send(Message::Binary(serde_json::to_vec(&message).unwrap())) .await .unwrap(); } } continue; } else { // This message is targeting this instance match &repository.command { RepositoryMessageKind::Request(request) => match request { RepositoryRequest::CreateRepository(request) => { let mut backend = backend.lock().await; let response = backend.create_repository(request).await; let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::CreateRepository(response), ), }, )) .unwrap(), )) .await .unwrap(); continue; } RepositoryRequest::RepositoryFileInspect(request) => { let mut backend = backend.lock().await; let response = backend.repository_file_inspect(request); let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::RepositoryFileInspection( response, ), ), }, )) .unwrap(), )) .await .unwrap(); continue; } RepositoryRequest::RepositoryInfo(request) => { let mut backend = backend.lock().await; let response = backend.repository_info(request).await; let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::RepositoryInfo(response), ), }, )) .unwrap(), )) .await .unwrap(); continue; } RepositoryRequest::IssuesCount(request) => { let mut backend = backend.lock().await; let response = backend.issues_count(request); let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::IssuesCount(response), ), }, )) .unwrap(), )) .await .unwrap(); continue; } RepositoryRequest::IssueLabels(request) => { let mut backend = backend.lock().await; let response = backend.issue_labels(request); let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::IssueLabels(response), ), }, )) .unwrap(), )) .await .unwrap(); continue; } RepositoryRequest::Issues(request) => { let mut backend = backend.lock().await; let response = backend.issues(request); let response = match response { Ok(response) => response, Err(err) => { error!("Error handling request: {:?}", err); continue; } }; drop(backend); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Repository( RepositoryMessage { target: repository.target.clone(), command: RepositoryMessageKind::Response( RepositoryResponse::Issues(response), ), }, )) .unwrap(), )) .await .unwrap(); continue; } }, RepositoryMessageKind::Response(_response) => { unreachable!() } } } } if let MessageKind::Authentication(authentication) = &message { match authentication { AuthenticationMessage::Request(request) => match request { AuthenticationRequest::AuthenticationToken(token) => { let mut granter = auth_granter.lock().await; let response = granter.token_request(token.clone()).await.unwrap(); drop(granter); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Authentication( AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::AuthenticationToken(response)) )) .unwrap(), )) .await .unwrap(); continue; } AuthenticationRequest::TokenExtension(request) => { let mut granter = auth_granter.lock().await; let response = granter .extension_request(request.clone()) .await .unwrap_or(TokenExtensionResponse { new_token: None }); drop(granter); socket .send(Message::Binary( serde_json::to_vec(&MessageKind::Authentication( AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::TokenExtension(response)) )) .unwrap(), )) .await .unwrap(); continue; } }, AuthenticationMessage::Response(_) => unreachable!(), } } } info!("Connection closed"); } async fn send_and_get_listener( message: MessageKind, listeners: &Arc>, connections: &Arc>, ) -> Receiver { let (instance, user, repository): (Option, Option, Option) = match &message { MessageKind::Handshake(_) => { todo!() } MessageKind::Repository(repository) => (None, None, Some(repository.target.clone())), MessageKind::Authentication(_) => todo!(), }; let target = match (&instance, &user, &repository) { (Some(instance), _, _) => instance.clone(), (_, Some(user), _) => user.instance.clone(), (_, _, Some(repository)) => repository.instance.clone(), _ => unreachable!(), }; let mut listeners = listeners.lock().await; let listener = listeners.add(instance, user, repository); drop(listeners); let connections = connections.lock().await; if let Some(connection) = connections.instance_connections.get(&target) { connection.sender.send(message); } else { error!("Unable to message {}, this is a bug.", target.url); panic!(); } drop(connections); listener }