use std::{ net::SocketAddr, ops::Deref, sync::{atomic::AtomicBool, Arc}, }; use anyhow::Error; use futures_util::{SinkExt, StreamExt}; use giterated_models::{ authenticated::{AuthenticationSource, UserTokenMetadata}, error::OperationError, instance::Instance, }; use giterated_models::authenticated::AuthenticatedPayload; use giterated_stack::{ AuthenticatedInstance, AuthenticatedUser, GiteratedStack, StackOperationState, }; use jsonwebtoken::{DecodingKey, TokenData, Validation}; use rsa::{ pkcs1::DecodeRsaPublicKey, pss::{Signature, VerifyingKey}, sha2::Sha256, signature::Verifier, RsaPublicKey, }; use serde::Serialize; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use toml::Table; use crate::{ authentication::AuthenticationTokenGranter, backend::{MetadataBackend, RepositoryBackend, UserBackend}, federation::connections::InstanceConnections, keys::PublicKeyCache, }; use super::Connections; pub async fn connection_wrapper( socket: WebSocketStream, connections: Arc>, repository_backend: Arc>, user_backend: Arc>, auth_granter: Arc>, settings_backend: Arc>, addr: SocketAddr, instance: impl ToOwned, instance_connections: Arc>, config: Table, runtime: Arc, mut operation_state: StackOperationState, ) { let connection_state = ConnectionState { socket: Arc::new(Mutex::new(socket)), connections, repository_backend, user_backend, auth_granter, settings_backend, addr, instance: instance.to_owned(), handshaked: Arc::new(AtomicBool::new(false)), key_cache: Arc::default(), instance_connections: instance_connections.clone(), config, }; let _handshaked = false; let mut key_cache = PublicKeyCache::default(); loop { let mut socket = connection_state.socket.lock().await; let message = socket.next().await; drop(socket); match message { Some(Ok(message)) => { let payload = match message { Message::Binary(payload) => payload, Message::Ping(_) => { let mut socket = connection_state.socket.lock().await; let _ = socket.send(Message::Pong(vec![])).await; drop(socket); continue; } Message::Close(_) => return, _ => continue, }; let message: AuthenticatedPayload = bincode::deserialize(&payload).unwrap(); // Get authentication let instance = { let mut verified_instance: Option = None; for source in &message.source { if let AuthenticationSource::Instance { instance, signature, } = source { let public_key = key_cache.get(&instance).await.unwrap(); let public_key = RsaPublicKey::from_pkcs1_pem(&public_key).unwrap(); let verifying_key = VerifyingKey::::new(public_key); if verifying_key .verify( &message.payload, &Signature::try_from(signature.as_ref()).unwrap(), ) .is_ok() { verified_instance = Some(AuthenticatedInstance::new(instance.clone())); break; } } } verified_instance }; let user = { let mut verified_user = None; if let Some(verified_instance) = &instance { for source in &message.source { if let AuthenticationSource::User { user, token } = source { // Get token let public_key = key_cache.get(&verified_instance).await.unwrap(); let token: TokenData = jsonwebtoken::decode( token.as_ref(), &DecodingKey::from_rsa_pem(public_key.as_bytes()).unwrap(), &Validation::new(jsonwebtoken::Algorithm::RS256), ) .unwrap(); if token.claims.generated_for != *verified_instance.deref() { // Nope! break; } if token.claims.user != *user { // Nope! break; } verified_user = Some(AuthenticatedUser::new(user.clone())); break; } } } verified_user }; operation_state.user = user; operation_state.instance = instance; let result = runtime .handle_network_message(message, &operation_state) .await; // Asking for exploits here operation_state.user = None; operation_state.instance = None; if let Err(OperationError::Internal(internal_error)) = &result { error!("An internal error has occured: {}", internal_error); } let mut socket = connection_state.socket.lock().await; let _ = socket .send(Message::Binary(bincode::serialize(&result).unwrap())) .await; drop(socket); } _ => { return; } } } } #[derive(Clone)] pub struct ConnectionState { socket: Arc>>, pub connections: Arc>, pub repository_backend: Arc>, pub user_backend: Arc>, pub auth_granter: Arc>, pub settings_backend: Arc>, pub addr: SocketAddr, pub instance: Instance, pub handshaked: Arc, pub key_cache: Arc>, pub instance_connections: Arc>, pub config: Table, } impl ConnectionState { pub async fn send(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn send_raw(&self, message: T) -> Result<(), Error> { let payload = serde_json::to_string(&message)?; self.socket .lock() .await .send(Message::Binary(payload.into_bytes())) .await?; Ok(()) } pub async fn public_key(&self, instance: &Instance) -> Result { let mut keys = self.key_cache.lock().await; keys.get(instance).await } }