use std::{any::type_name, collections::HashMap, ops::Deref}; use anyhow::Error; use futures_util::Future; use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; use rsa::{ pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, pss::{Signature, SigningKey, VerifyingKey}, sha2::Sha256, signature::{RandomizedSigner, SignatureEncoding, Verifier}, RsaPrivateKey, RsaPublicKey, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Value; use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState}; use super::{instance::Instance, user::User}; #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct Authenticated { // TODO: Can not flatten vec's/enums, might want to just type tag the enum instead. not recommended for anything but languages like JSON anyways // #[serde(flatten)] source: Vec, message_type: String, #[serde(flatten)] message: T, } pub trait AuthenticationSourceProvider: Sized { fn authenticate(self, payload: &Vec) -> AuthenticationSource; } pub trait AuthenticationSourceProviders: Sized { fn authenticate_all(self, payload: &Vec) -> Vec; } impl AuthenticationSourceProviders for A where A: AuthenticationSourceProvider, { fn authenticate_all(self, payload: &Vec) -> Vec { vec![self.authenticate(payload)] } } impl AuthenticationSourceProviders for (A, B) where A: AuthenticationSourceProvider, B: AuthenticationSourceProvider, { fn authenticate_all(self, payload: &Vec) -> Vec { let (first, second) = self; vec![first.authenticate(payload), second.authenticate(payload)] } } impl Authenticated { pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self { let message_payload = serde_json::to_vec(&message).unwrap(); let authentication = auth_sources.authenticate_all(&message_payload); Self { source: authentication, message_type: type_name::().to_string(), message, } } pub fn new_empty(message: T) -> Self { Self { source: vec![], message_type: type_name::().to_string(), message, } } pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) { let message_payload = serde_json::to_vec(&self.message).unwrap(); self.source .push(authentication.authenticate(&message_payload)); } } mod verified {} #[derive(Clone, Debug)] pub struct UserAuthenticator { pub user: User, pub token: UserAuthenticationToken, } impl AuthenticationSourceProvider for UserAuthenticator { fn authenticate(self, _payload: &Vec) -> AuthenticationSource { AuthenticationSource::User { user: self.user, token: self.token, } } } #[derive(Clone)] pub struct InstanceAuthenticator<'a> { pub instance: Instance, pub private_key: &'a str, } impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { fn authenticate(self, payload: &Vec) -> AuthenticationSource { let mut rng = rand::thread_rng(); let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap(); let signing_key = SigningKey::::new(private_key); let signature = signing_key.sign_with_rng(&mut rng, &payload); AuthenticationSource::Instance { instance: self.instance, // TODO: Actually parse signature from private key signature: InstanceSignature(signature.to_bytes().into_vec()), } } } #[repr(transparent)] #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct UserAuthenticationToken(String); impl From for UserAuthenticationToken { fn from(value: String) -> Self { Self(value) } } impl ToString for UserAuthenticationToken { fn to_string(&self) -> String { self.0.clone() } } impl AsRef for UserAuthenticationToken { fn as_ref(&self) -> &str { &self.0 } } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct InstanceSignature(Vec); impl AsRef<[u8]> for InstanceSignature { fn as_ref(&self) -> &[u8] { &self.0 } } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum AuthenticationSource { User { user: User, token: UserAuthenticationToken, }, Instance { instance: Instance, signature: InstanceSignature, }, } pub struct NetworkMessage(pub Vec); impl Deref for NetworkMessage { type Target = [u8]; fn deref(&self) -> &Self::Target { &self.0 } } pub struct AuthenticatedUser(pub User); #[derive(Debug, thiserror::Error)] pub enum UserAuthenticationError { #[error("user authentication missing")] Missing, // #[error("{0}")] // InstanceAuthentication(#[from] Error), #[error("user token was invalid")] InvalidToken, #[error("an error has occured")] Other(#[from] Error), } pub struct AuthenticatedInstance(Instance); impl AuthenticatedInstance { pub fn inner(&self) -> &Instance { &self.0 } } #[async_trait::async_trait] pub trait FromMessage: Sized + Send + Sync { async fn from_message(message: &NetworkMessage, state: &S) -> Result; } #[async_trait::async_trait] impl FromMessage for AuthenticatedUser { async fn from_message( network_message: &NetworkMessage, state: &ConnectionState, ) -> Result { let message: Authenticated> = serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; let (auth_user, auth_token) = message .source .iter() .filter_map(|auth| { if let AuthenticationSource::User { user, token } = auth { Some((user, token)) } else { None } }) .next() .ok_or_else(|| UserAuthenticationError::Missing)?; let authenticated_instance = AuthenticatedInstance::from_message(network_message, state).await?; let public_key_raw = public_key(&auth_user.instance).await?; let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); let data: TokenData = decode( auth_token.as_ref(), &verification_key, &Validation::new(Algorithm::RS256), ) .unwrap(); if data.claims.user != *auth_user || data.claims.generated_for != *authenticated_instance.inner() { Err(Error::from(UserAuthenticationError::InvalidToken)) } else { Ok(AuthenticatedUser(data.claims.user)) } } } #[async_trait::async_trait] impl FromMessage for AuthenticatedInstance { async fn from_message( network_message: &NetworkMessage, state: &ConnectionState, ) -> Result { let message: Authenticated = serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; let (instance, signature) = message .source .iter() .filter_map(|auth: &AuthenticationSource| { if let AuthenticationSource::Instance { instance, signature, } = auth { Some((instance, signature)) } else { None } }) .next() // TODO: Instance authentication error .ok_or_else(|| UserAuthenticationError::Missing)?; let public_key = { let cached_keys = state.cached_keys.read().await; if let Some(key) = cached_keys.get(&instance) { key.clone() } else { drop(cached_keys); let mut cached_keys = state.cached_keys.write().await; let key = public_key(instance).await?; let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap(); cached_keys.insert(instance.clone(), public_key.clone()); public_key } }; let verifying_key: VerifyingKey = VerifyingKey::new(public_key); let message_json = serde_json::to_vec(&message.message).unwrap(); verifying_key.verify( &message_json, &Signature::try_from(signature.as_ref()).unwrap(), )?; Ok(AuthenticatedInstance(instance.clone())) } } #[async_trait::async_trait] impl FromMessage for Option where T: FromMessage, S: Send + Sync + 'static, { async fn from_message(message: &NetworkMessage, state: &S) -> Result { Ok(T::from_message(message, state).await.ok()) } } #[async_trait::async_trait] pub trait MessageHandler { async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result; } #[async_trait::async_trait] impl MessageHandler<(T1,), S, R> for T where T: FnOnce(T1) -> F + Clone + Send + 'static, F: Future> + Send, T1: FromMessage + Send, S: Send + Sync, E: std::error::Error + Send + Sync + 'static, { async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { let value = T1::from_message(message, state).await?; self(value).await.map_err(|e| Error::from(e)) } } #[async_trait::async_trait] impl MessageHandler<(T1, T2), S, R> for T where T: FnOnce(T1, T2) -> F + Clone + Send + 'static, F: Future> + Send, T1: FromMessage + Send, T2: FromMessage + Send, S: Send + Sync, E: std::error::Error + Send + Sync + 'static, { async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { let value = T1::from_message(message, state).await?; let value_2 = T2::from_message(message, state).await?; self(value, value_2).await.map_err(|e| Error::from(e)) } } #[async_trait::async_trait] impl MessageHandler<(T1, T2, T3), S, R> for T where T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static, F: Future> + Send, T1: FromMessage + Send, T2: FromMessage + Send, T3: FromMessage + Send, S: Send + Sync, E: std::error::Error + Send + Sync + 'static, { async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result { let value = T1::from_message(message, state).await?; let value_2 = T2::from_message(message, state).await?; let value_3 = T3::from_message(message, state).await?; self(value, value_2, value_3) .await .map_err(|e| Error::from(e)) } } pub struct State(pub T); #[async_trait::async_trait] impl FromMessage for State where T: Clone + Send + Sync, { async fn from_message(_: &NetworkMessage, state: &T) -> Result { Ok(Self(state.clone())) } } // Temp #[async_trait::async_trait] impl FromMessage for Message where T: DeserializeOwned + Send + Sync + Serialize, S: Clone + Send + Sync, { async fn from_message(message: &NetworkMessage, _: &S) -> Result { Ok(Message(serde_json::from_slice(&message)?)) } } pub struct Message(pub T); async fn public_key(instance: &Instance) -> Result { let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) .await? .text() .await?; Ok(key) }