diff --git a/giterated-daemon/src/authentication.rs b/giterated-daemon/src/authentication.rs index 518c0ec..0476306 100644 --- a/giterated-daemon/src/authentication.rs +++ b/giterated-daemon/src/authentication.rs @@ -8,10 +8,12 @@ use giterated_models::{ }, }; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, TokenData, Validation}; -use std::time::SystemTime; -use tokio::{fs::File, io::AsyncReadExt}; +use std::{sync::Arc, time::SystemTime}; +use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; use toml::Table; +use crate::keys::PublicKeyCache; + pub struct AuthenticationTokenGranter { pub config: Table, pub instance: Instance, @@ -110,9 +112,12 @@ impl AuthenticationTokenGranter { pub async fn extension_request( &mut self, issued_for: &Instance, + key_cache: &Arc>, token: UserAuthenticationToken, ) -> Result { - let server_public_key = public_key(&self.instance).await.unwrap(); + let mut key_cache = key_cache.lock().await; + let server_public_key = key_cache.get(issued_for).await?; + drop(key_cache); let verification_key = DecodingKey::from_rsa_pem(server_public_key.as_bytes()).unwrap(); @@ -167,12 +172,3 @@ impl AuthenticationTokenGranter { }) } } - -async fn public_key(instance: &Instance) -> Result { - let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) - .await? - .text() - .await?; - - Ok(key) -} diff --git a/giterated-daemon/src/connection/authentication.rs b/giterated-daemon/src/connection/authentication.rs index 48f29ea..7f0f676 100644 --- a/giterated-daemon/src/connection/authentication.rs +++ b/giterated-daemon/src/connection/authentication.rs @@ -96,7 +96,7 @@ async fn token_extension_request( let mut token_granter = connection_state.auth_granter.lock().await; let response = token_granter - .extension_request(&issued_for, request.token) + .extension_request(&issued_for, &connection_state.key_cache, request.token) .await .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; diff --git a/giterated-daemon/src/connection/wrapper.rs b/giterated-daemon/src/connection/wrapper.rs index b0c3f74..ff086f8 100644 --- a/giterated-daemon/src/connection/wrapper.rs +++ b/giterated-daemon/src/connection/wrapper.rs @@ -30,6 +30,7 @@ use crate::{ backend::{RepositoryBackend, UserBackend}, connection::forwarded::wrap_forwarded, federation::connections::InstanceConnections, + keys::PublicKeyCache, message::NetworkMessage, }; @@ -57,7 +58,7 @@ pub async fn connection_wrapper( addr, instance: instance.to_owned(), handshaked: Arc::new(AtomicBool::new(false)), - cached_keys: Arc::default(), + key_cache: Arc::default(), }; let mut handshaked = false; @@ -181,7 +182,7 @@ pub struct ConnectionState { pub addr: SocketAddr, pub instance: Instance, pub handshaked: Arc, - pub cached_keys: Arc>>, + pub key_cache: Arc>, } impl ConnectionState { @@ -208,4 +209,9 @@ impl ConnectionState { Ok(()) } + + pub async fn public_key(&self, instance: &Instance) -> Result { + let mut keys = self.key_cache.lock().await; + keys.get(instance).await + } } diff --git a/giterated-daemon/src/keys.rs b/giterated-daemon/src/keys.rs new file mode 100644 index 0000000..e30e06b --- /dev/null +++ b/giterated-daemon/src/keys.rs @@ -0,0 +1,26 @@ +use std::collections::HashMap; + +use anyhow::Error; +use giterated_models::model::instance::Instance; + +#[derive(Default)] +pub struct PublicKeyCache { + pub keys: HashMap, +} + +impl PublicKeyCache { + pub async fn get(&mut self, instance: &Instance) -> Result { + if let Some(key) = self.keys.get(instance) { + return Ok(key.clone()); + } else { + let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) + .await? + .text() + .await?; + + self.keys.insert(instance.clone(), key); + + Ok(self.keys.get(instance).unwrap().clone()) + } + } +} diff --git a/giterated-daemon/src/lib.rs b/giterated-daemon/src/lib.rs index 05494fb..592deb1 100644 --- a/giterated-daemon/src/lib.rs +++ b/giterated-daemon/src/lib.rs @@ -6,6 +6,7 @@ pub mod authentication; pub mod backend; pub mod connection; pub mod federation; +pub mod keys; pub mod message; #[macro_use] diff --git a/giterated-daemon/src/message.rs b/giterated-daemon/src/message.rs index c9ba2f9..e3641b7 100644 --- a/giterated-daemon/src/message.rs +++ b/giterated-daemon/src/message.rs @@ -82,7 +82,7 @@ impl FromMessage for AuthenticatedUser { let authenticated_instance = AuthenticatedInstance::from_message(network_message, state).await?; - let public_key_raw = public_key(&auth_user.instance).await?; + let public_key_raw = state.public_key(&auth_user.instance).await?; let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); let data: TokenData = decode( @@ -129,20 +129,7 @@ impl FromMessage for AuthenticatedInstance { // TODO: Instance authentication error .ok_or_else(|| UserAuthenticationError::Missing)?; - let public_key = { - let cached_keys = state.cached_keys.read().await; - - if let Some(key) = cached_keys.get(&instance) { - key.clone() - } else { - drop(cached_keys); - let mut cached_keys = state.cached_keys.write().await; - let key = public_key(&instance).await?; - let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap(); - cached_keys.insert(instance.clone(), public_key.clone()); - public_key - } - }; + let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap(); let verifying_key: VerifyingKey = VerifyingKey::new(public_key); @@ -255,15 +242,6 @@ where pub struct Message(pub T); -async fn public_key(instance: &Instance) -> Result { - let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) - .await? - .text() - .await?; - - Ok(key) -} - /// Handshake-specific message type. /// /// Uses basic serde_json-based deserialization to maintain the highest