JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add settings

Amber - 2 years ago

parent: tbd commit: 0448edb

src/model/authenticated.rs - 12127 bytes
Raw
1 use std::{any::type_name, collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use rsa::{
7 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
8 pss::{Signature, SigningKey, VerifyingKey},
9 sha2::Sha256,
10 signature::{RandomizedSigner, SignatureEncoding, Verifier},
11 RsaPrivateKey, RsaPublicKey,
12 };
13 use serde::{de::DeserializeOwned, Deserialize, Serialize};
14 use serde_json::Value;
15
16 use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
17
18 use super::{instance::Instance, user::User};
19
20 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
21 pub struct Authenticated<T: Serialize> {
22 // TODO: Can not flatten vec's/enums, might want to just type tag the enum instead. not recommended for anything but languages like JSON anyways
23 // #[serde(flatten)]
24 source: Vec<AuthenticationSource>,
25 message_type: String,
26 #[serde(flatten)]
27 message: T,
28 }
29
/// A single way of authenticating a serialized message payload.
pub trait AuthenticationSourceProvider: Sized {
    /// Consumes the provider and produces the [`AuthenticationSource`] for `payload`.
    ///
    /// In this module `payload` is always the `serde_json::to_vec` serialization of the
    /// message being authenticated.
    fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
}

/// One or more providers that together authenticate a serialized message payload.
///
/// This module implements it for every [`AuthenticationSourceProvider`] (yielding one
/// source) and for pairs of providers (yielding two, in tuple order).
pub trait AuthenticationSourceProviders: Sized {
    /// Consumes the provider(s) and returns the authentication sources for `payload`.
    fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
}
37
38 impl<A> AuthenticationSourceProviders for A
39 where
40 A: AuthenticationSourceProvider,
41 {
42 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
43 vec![self.authenticate(payload)]
44 }
45 }
46
47 impl<A, B> AuthenticationSourceProviders for (A, B)
48 where
49 A: AuthenticationSourceProvider,
50 B: AuthenticationSourceProvider,
51 {
52 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
53 let (first, second) = self;
54
55 vec![first.authenticate(payload), second.authenticate(payload)]
56 }
57 }
58
59 impl<T: Serialize> Authenticated<T> {
60 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
61 let message_payload = serde_json::to_vec(&message).unwrap();
62
63 let authentication = auth_sources.authenticate_all(&message_payload);
64
65 Self {
66 source: authentication,
67 message_type: type_name::<T>().to_string(),
68 message,
69 }
70 }
71
72 pub fn new_empty(message: T) -> Self {
73 Self {
74 source: vec![],
75 message_type: type_name::<T>().to_string(),
76 message,
77 }
78 }
79
80 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
81 let message_payload = serde_json::to_vec(&self.message).unwrap();
82
83 self.source
84 .push(authentication.authenticate(&message_payload));
85 }
86 }
87
// Empty placeholder module: nothing is declared in it yet.
// NOTE(review): presumably reserved for "verified message" types — confirm intent or remove.
mod verified {}
89
90 #[derive(Clone, Debug)]
91 pub struct UserAuthenticator {
92 pub user: User,
93 pub token: UserAuthenticationToken,
94 }
95
96 impl AuthenticationSourceProvider for UserAuthenticator {
97 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
98 AuthenticationSource::User {
99 user: self.user,
100 token: self.token,
101 }
102 }
103 }
104
105 #[derive(Clone)]
106 pub struct InstanceAuthenticator<'a> {
107 pub instance: Instance,
108 pub private_key: &'a str,
109 }
110
111 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
112 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
113 let mut rng = rand::thread_rng();
114
115 let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap();
116 let signing_key = SigningKey::<Sha256>::new(private_key);
117 let signature = signing_key.sign_with_rng(&mut rng, &payload);
118
119 AuthenticationSource::Instance {
120 instance: self.instance,
121 // TODO: Actually parse signature from private key
122 signature: InstanceSignature(signature.to_bytes().into_vec()),
123 }
124 }
125 }
126
127 #[repr(transparent)]
128 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
129 pub struct UserAuthenticationToken(String);
130
131 impl From<String> for UserAuthenticationToken {
132 fn from(value: String) -> Self {
133 Self(value)
134 }
135 }
136
137 impl ToString for UserAuthenticationToken {
138 fn to_string(&self) -> String {
139 self.0.clone()
140 }
141 }
142
143 impl AsRef<str> for UserAuthenticationToken {
144 fn as_ref(&self) -> &str {
145 &self.0
146 }
147 }
148
149 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
150 pub struct InstanceSignature(Vec<u8>);
151
152 impl AsRef<[u8]> for InstanceSignature {
153 fn as_ref(&self) -> &[u8] {
154 &self.0
155 }
156 }
157
/// One piece of authentication attached to an [`Authenticated`] message.
///
/// NOTE: variant and field order may be part of the serialized form for index-based serde
/// formats; do not reorder.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthenticationSource {
    /// A user vouching for the message with a token. In this module the token is decoded
    /// and checked as an RS256 JWT by `AuthenticatedUser`'s `FromMessage` impl.
    User {
        user: User,
        token: UserAuthenticationToken,
    },
    /// An instance vouching for the message with a signature over its serialized payload.
    Instance {
        instance: Instance,
        signature: InstanceSignature,
    },
}
169
/// Raw bytes of a network message, before any parsing.
///
/// Dereferences to `[u8]`, so it can be handed directly to byte-slice parsers such as
/// `serde_json::from_slice`.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
179
/// A user whose token has been verified for the current message by the
/// `FromMessage<ConnectionState>` impl for this type.
///
/// NOTE(review): the field is `pub`, so values can also be constructed directly without
/// any verification; consider making it private if this type is meant as proof of auth.
pub struct AuthenticatedUser(pub User);
181
182 #[derive(Debug, thiserror::Error)]
183 pub enum UserAuthenticationError {
184 #[error("user authentication missing")]
185 Missing,
186 // #[error("{0}")]
187 // InstanceAuthentication(#[from] Error),
188 #[error("user token was invalid")]
189 InvalidToken,
190 #[error("an error has occured")]
191 Other(#[from] Error),
192 }
193
194 pub struct AuthenticatedInstance(Instance);
195
196 impl AuthenticatedInstance {
197 pub fn inner(&self) -> &Instance {
198 &self.0
199 }
200 }
201
/// Extracts a value of `Self` from a raw [`NetworkMessage`], with access to shared state `S`.
///
/// Implementations serve as the argument extractors for [`MessageHandler`] handlers.
#[async_trait::async_trait]
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
    /// Attempts the extraction; an `Err` means the message does not satisfy this extractor.
    async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
}
206
207 #[async_trait::async_trait]
208 impl FromMessage<ConnectionState> for AuthenticatedUser {
209 async fn from_message(
210 network_message: &NetworkMessage,
211 state: &ConnectionState,
212 ) -> Result<Self, Error> {
213 let message: Authenticated<HashMap<String, Value>> =
214 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
215
216 let (auth_user, auth_token) = message
217 .source
218 .iter()
219 .filter_map(|auth| {
220 if let AuthenticationSource::User { user, token } = auth {
221 Some((user, token))
222 } else {
223 None
224 }
225 })
226 .next()
227 .ok_or_else(|| UserAuthenticationError::Missing)?;
228
229 let authenticated_instance =
230 AuthenticatedInstance::from_message(network_message, state).await?;
231
232 let public_key_raw = public_key(&auth_user.instance).await?;
233 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
234
235 let data: TokenData<UserTokenMetadata> = decode(
236 auth_token.as_ref(),
237 &verification_key,
238 &Validation::new(Algorithm::RS256),
239 )
240 .unwrap();
241
242 if data.claims.user != *auth_user
243 || data.claims.generated_for != *authenticated_instance.inner()
244 {
245 Err(Error::from(UserAuthenticationError::InvalidToken))
246 } else {
247 Ok(AuthenticatedUser(data.claims.user))
248 }
249 }
250 }
251
252 #[async_trait::async_trait]
253 impl FromMessage<ConnectionState> for AuthenticatedInstance {
254 async fn from_message(
255 network_message: &NetworkMessage,
256 state: &ConnectionState,
257 ) -> Result<Self, Error> {
258 let message: Authenticated<Value> =
259 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
260
261 let (instance, signature) = message
262 .source
263 .iter()
264 .filter_map(|auth: &AuthenticationSource| {
265 if let AuthenticationSource::Instance {
266 instance,
267 signature,
268 } = auth
269 {
270 Some((instance, signature))
271 } else {
272 None
273 }
274 })
275 .next()
276 // TODO: Instance authentication error
277 .ok_or_else(|| UserAuthenticationError::Missing)?;
278
279 let public_key = {
280 let cached_keys = state.cached_keys.read().await;
281
282 if let Some(key) = cached_keys.get(&instance) {
283 key.clone()
284 } else {
285 drop(cached_keys);
286 let mut cached_keys = state.cached_keys.write().await;
287 let key = public_key(instance).await?;
288 let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
289 cached_keys.insert(instance.clone(), public_key.clone());
290 public_key
291 }
292 };
293
294 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
295
296 let message_json = serde_json::to_vec(&message.message).unwrap();
297
298 verifying_key.verify(
299 &message_json,
300 &Signature::try_from(signature.as_ref()).unwrap(),
301 )?;
302
303 Ok(AuthenticatedInstance(instance.clone()))
304 }
305 }
306
307 #[async_trait::async_trait]
308 impl<S, T> FromMessage<S> for Option<T>
309 where
310 T: FromMessage<S>,
311 S: Send + Sync + 'static,
312 {
313 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
314 Ok(T::from_message(message, state).await.ok())
315 }
316 }
317
/// A handler that can be invoked with a raw [`NetworkMessage`].
///
/// In this module's impls, `T` is a tuple of the handler's argument types (each extracted
/// via [`FromMessage`]), `S` is the shared state, and `R` is the handler's success value.
#[async_trait::async_trait]
pub trait MessageHandler<T, S, R> {
    /// Extracts the handler's arguments from `message` and runs the handler, consuming it.
    async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
}
322 #[async_trait::async_trait]
323 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
324 where
325 T: FnOnce(T1) -> F + Clone + Send + 'static,
326 F: Future<Output = Result<R, E>> + Send,
327 T1: FromMessage<S> + Send,
328 S: Send + Sync,
329 E: std::error::Error + Send + Sync + 'static,
330 {
331 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
332 let value = T1::from_message(message, state).await?;
333 self(value).await.map_err(|e| Error::from(e))
334 }
335 }
336
337 #[async_trait::async_trait]
338 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
339 where
340 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
341 F: Future<Output = Result<R, E>> + Send,
342 T1: FromMessage<S> + Send,
343 T2: FromMessage<S> + Send,
344 S: Send + Sync,
345 E: std::error::Error + Send + Sync + 'static,
346 {
347 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
348 let value = T1::from_message(message, state).await?;
349 let value_2 = T2::from_message(message, state).await?;
350 self(value, value_2).await.map_err(|e| Error::from(e))
351 }
352 }
353
354 #[async_trait::async_trait]
355 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
356 where
357 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
358 F: Future<Output = Result<R, E>> + Send,
359 T1: FromMessage<S> + Send,
360 T2: FromMessage<S> + Send,
361 T3: FromMessage<S> + Send,
362 S: Send + Sync,
363 E: std::error::Error + Send + Sync + 'static,
364 {
365 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
366 let value = T1::from_message(message, state).await?;
367 let value_2 = T2::from_message(message, state).await?;
368 let value_3 = T3::from_message(message, state).await?;
369
370 self(value, value_2, value_3)
371 .await
372 .map_err(|e| Error::from(e))
373 }
374 }
375
376 pub struct State<T>(pub T);
377
378 #[async_trait::async_trait]
379 impl<T> FromMessage<T> for State<T>
380 where
381 T: Clone + Send + Sync,
382 {
383 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
384 Ok(Self(state.clone()))
385 }
386 }
387
388 // Temp
389 #[async_trait::async_trait]
390 impl<T, S> FromMessage<S> for Message<T>
391 where
392 T: DeserializeOwned + Send + Sync + Serialize,
393 S: Clone + Send + Sync,
394 {
395 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
396 Ok(Message(serde_json::from_slice(&message)?))
397 }
398 }
399
400 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
401
402 async fn public_key(instance: &Instance) -> Result<String, Error> {
403 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
404 .await?
405 .text()
406 .await?;
407
408 Ok(key)
409 }
410