Add more aggressive key caching
parent: tbd commit: 5bc92ad
Showing 6 changed files with 46 insertions and 39 deletions
giterated-daemon/src/authentication.rs
@@ -8,10 +8,12 @@ use giterated_models::{ | ||
8 | 8 | }, |
9 | 9 | }; |
10 | 10 | use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, TokenData, Validation}; |
11 | use std::time::SystemTime; | |
12 | use tokio::{fs::File, io::AsyncReadExt}; | |
11 | use std::{sync::Arc, time::SystemTime}; | |
12 | use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; | |
13 | 13 | use toml::Table; |
14 | 14 | |
15 | use crate::keys::PublicKeyCache; | |
16 | ||
15 | 17 | pub struct AuthenticationTokenGranter { |
16 | 18 | pub config: Table, |
17 | 19 | pub instance: Instance, |
@@ -110,9 +112,12 @@ impl AuthenticationTokenGranter { | ||
110 | 112 | pub async fn extension_request( |
111 | 113 | &mut self, |
112 | 114 | issued_for: &Instance, |
115 | key_cache: &Arc<Mutex<PublicKeyCache>>, | |
113 | 116 | token: UserAuthenticationToken, |
114 | 117 | ) -> Result<TokenExtensionResponse, Error> { |
115 | let server_public_key = public_key(&self.instance).await.unwrap(); | |
118 | let mut key_cache = key_cache.lock().await; | |
119 | let server_public_key = key_cache.get(issued_for).await?; | |
120 | drop(key_cache); | |
116 | 121 | |
117 | 122 | let verification_key = DecodingKey::from_rsa_pem(server_public_key.as_bytes()).unwrap(); |
118 | 123 | |
@@ -167,12 +172,3 @@ impl AuthenticationTokenGranter { | ||
167 | 172 | }) |
168 | 173 | } |
169 | 174 | } |
170 | ||
171 | async fn public_key(instance: &Instance) -> Result<String, Error> { | |
172 | let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) | |
173 | .await? | |
174 | .text() | |
175 | .await?; | |
176 | ||
177 | Ok(key) | |
178 | } |
giterated-daemon/src/connection/authentication.rs
@@ -96,7 +96,7 @@ async fn token_extension_request( | ||
96 | 96 | let mut token_granter = connection_state.auth_granter.lock().await; |
97 | 97 | |
98 | 98 | let response = token_granter |
99 | .extension_request(&issued_for, request.token) | |
99 | .extension_request(&issued_for, &connection_state.key_cache, request.token) | |
100 | 100 | .await |
101 | 101 | .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; |
102 | 102 |
giterated-daemon/src/connection/wrapper.rs
@@ -30,6 +30,7 @@ use crate::{ | ||
30 | 30 | backend::{RepositoryBackend, UserBackend}, |
31 | 31 | connection::forwarded::wrap_forwarded, |
32 | 32 | federation::connections::InstanceConnections, |
33 | keys::PublicKeyCache, | |
33 | 34 | message::NetworkMessage, |
34 | 35 | }; |
35 | 36 | |
@@ -57,7 +58,7 @@ pub async fn connection_wrapper( | ||
57 | 58 | addr, |
58 | 59 | instance: instance.to_owned(), |
59 | 60 | handshaked: Arc::new(AtomicBool::new(false)), |
60 | cached_keys: Arc::default(), | |
61 | key_cache: Arc::default(), | |
61 | 62 | }; |
62 | 63 | |
63 | 64 | let mut handshaked = false; |
@@ -181,7 +182,7 @@ pub struct ConnectionState { | ||
181 | 182 | pub addr: SocketAddr, |
182 | 183 | pub instance: Instance, |
183 | 184 | pub handshaked: Arc<AtomicBool>, |
184 | pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>, | |
185 | pub key_cache: Arc<Mutex<PublicKeyCache>>, | |
185 | 186 | } |
186 | 187 | |
187 | 188 | impl ConnectionState { |
@@ -208,4 +209,9 @@ impl ConnectionState { | ||
208 | 209 | |
209 | 210 | Ok(()) |
210 | 211 | } |
212 | ||
213 | pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> { | |
214 | let mut keys = self.key_cache.lock().await; | |
215 | keys.get(instance).await | |
216 | } | |
211 | 217 | } |
giterated-daemon/src/keys.rs
@@ -0,0 +1,26 @@ | ||
1 | use std::collections::HashMap; | |
2 | ||
3 | use anyhow::Error; | |
4 | use giterated_models::model::instance::Instance; | |
5 | ||
6 | #[derive(Default)] | |
7 | pub struct PublicKeyCache { | |
8 | pub keys: HashMap<Instance, String>, | |
9 | } | |
10 | ||
11 | impl PublicKeyCache { | |
12 | pub async fn get(&mut self, instance: &Instance) -> Result<String, Error> { | |
13 | if let Some(key) = self.keys.get(instance) { | |
14 | return Ok(key.clone()); | |
15 | } else { | |
16 | let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) | |
17 | .await? | |
18 | .text() | |
19 | .await?; | |
20 | ||
21 | self.keys.insert(instance.clone(), key); | |
22 | ||
23 | Ok(self.keys.get(instance).unwrap().clone()) | |
24 | } | |
25 | } | |
26 | } |
giterated-daemon/src/lib.rs
@@ -6,6 +6,7 @@ pub mod authentication; | ||
6 | 6 | pub mod backend; |
7 | 7 | pub mod connection; |
8 | 8 | pub mod federation; |
9 | pub mod keys; | |
9 | 10 | pub mod message; |
10 | 11 | |
11 | 12 | #[macro_use] |
giterated-daemon/src/message.rs
@@ -82,7 +82,7 @@ impl FromMessage<ConnectionState> for AuthenticatedUser { | ||
82 | 82 | let authenticated_instance = |
83 | 83 | AuthenticatedInstance::from_message(network_message, state).await?; |
84 | 84 | |
85 | let public_key_raw = public_key(&auth_user.instance).await?; | |
85 | let public_key_raw = state.public_key(&auth_user.instance).await?; | |
86 | 86 | let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); |
87 | 87 | |
88 | 88 | let data: TokenData<UserTokenMetadata> = decode( |
@@ -129,20 +129,7 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance { | ||
129 | 129 | // TODO: Instance authentication error |
130 | 130 | .ok_or_else(|| UserAuthenticationError::Missing)?; |
131 | 131 | |
132 | let public_key = { | |
133 | let cached_keys = state.cached_keys.read().await; | |
134 | ||
135 | if let Some(key) = cached_keys.get(&instance) { | |
136 | key.clone() | |
137 | } else { | |
138 | drop(cached_keys); | |
139 | let mut cached_keys = state.cached_keys.write().await; | |
140 | let key = public_key(&instance).await?; | |
141 | let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap(); | |
142 | cached_keys.insert(instance.clone(), public_key.clone()); | |
143 | public_key | |
144 | } | |
145 | }; | |
132 | let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap(); | |
146 | 133 | |
147 | 134 | let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key); |
148 | 135 | |
@@ -255,15 +242,6 @@ where | ||
255 | 242 | |
256 | 243 | pub struct Message<T: Serialize + DeserializeOwned>(pub T); |
257 | 244 | |
258 | async fn public_key(instance: &Instance) -> Result<String, Error> { | |
259 | let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) | |
260 | .await? | |
261 | .text() | |
262 | .await?; | |
263 | ||
264 | Ok(key) | |
265 | } | |
266 | ||
267 | 245 | /// Handshake-specific message type. |
268 | 246 | /// |
269 | 247 | /// Uses basic serde_json-based deserialization to maintain the highest |