JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fixes for authentication!

Amber - 2 years ago

parent: tbd commit: 969964d

src/model/authenticated.rs - 11717 bytes
Raw
1 use std::{any::type_name, collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use rsa::{
7 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
8 pss::{Signature, SigningKey, VerifyingKey},
9 sha2::Sha256,
10 signature::{RandomizedSigner, SignatureEncoding, Verifier},
11 RsaPrivateKey, RsaPublicKey,
12 };
13 use serde::{de::DeserializeOwned, Deserialize, Serialize};
14 use serde_json::Value;
15
16 use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
17
18 use super::{instance::Instance, user::User};
19
20 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
21 pub struct Authenticated<T: Serialize> {
22 // TODO: Can not flatten vec's/enums, might want to just type tag the enum instead. not recommended for anything but languages like JSON anyways
23 // #[serde(flatten)]
24 source: Vec<AuthenticationSource>,
25 message_type: String,
26 #[serde(flatten)]
27 message: T,
28 }
29
30 pub trait AuthenticationSourceProvider: Sized {
31 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
32 }
33
34 pub trait AuthenticationSourceProviders: Sized {
35 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
36 }
37
38 impl<A> AuthenticationSourceProviders for A
39 where
40 A: AuthenticationSourceProvider,
41 {
42 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
43 vec![self.authenticate(payload)]
44 }
45 }
46
47 impl<A, B> AuthenticationSourceProviders for (A, B)
48 where
49 A: AuthenticationSourceProvider,
50 B: AuthenticationSourceProvider,
51 {
52 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
53 let (first, second) = self;
54
55 vec![first.authenticate(payload), second.authenticate(payload)]
56 }
57 }
58
59 impl<T: Serialize> Authenticated<T> {
60 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
61 let message_payload = serde_json::to_vec(&message).unwrap();
62
63 let authentication = auth_sources.authenticate_all(&message_payload);
64
65 Self {
66 source: authentication,
67 message_type: type_name::<T>().to_string(),
68 message,
69 }
70 }
71
72 pub fn new_empty(message: T) -> Self {
73 Self {
74 source: vec![],
75 message_type: type_name::<T>().to_string(),
76 message,
77 }
78 }
79
80 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
81 let message_payload = serde_json::to_vec(&self.message).unwrap();
82
83 self.source
84 .push(authentication.authenticate(&message_payload));
85 }
86 }
87
88 mod verified {}
89
90 #[derive(Clone, Debug)]
91 pub struct UserAuthenticator {
92 pub user: User,
93 pub token: UserAuthenticationToken,
94 }
95
96 impl AuthenticationSourceProvider for UserAuthenticator {
97 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
98 AuthenticationSource::User {
99 user: self.user,
100 token: self.token,
101 }
102 }
103 }
104
105 #[derive(Clone)]
106 pub struct InstanceAuthenticator<'a> {
107 pub instance: Instance,
108 pub private_key: &'a str,
109 }
110
111 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
112 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
113 let mut rng = rand::thread_rng();
114
115 let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap();
116 let signing_key = SigningKey::<Sha256>::new(private_key);
117 let signature = signing_key.sign_with_rng(&mut rng, &payload);
118
119 AuthenticationSource::Instance {
120 instance: self.instance,
121 // TODO: Actually parse signature from private key
122 signature: InstanceSignature(signature.to_bytes().into_vec()),
123 }
124 }
125 }
126
127 #[repr(transparent)]
128 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
129 pub struct UserAuthenticationToken(String);
130
131 impl From<String> for UserAuthenticationToken {
132 fn from(value: String) -> Self {
133 Self(value)
134 }
135 }
136
137 impl ToString for UserAuthenticationToken {
138 fn to_string(&self) -> String {
139 self.0.clone()
140 }
141 }
142
143 impl AsRef<str> for UserAuthenticationToken {
144 fn as_ref(&self) -> &str {
145 &self.0
146 }
147 }
148
149 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
150 pub struct InstanceSignature(Vec<u8>);
151
152 impl AsRef<[u8]> for InstanceSignature {
153 fn as_ref(&self) -> &[u8] {
154 &self.0
155 }
156 }
157
158 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
159 pub enum AuthenticationSource {
160 User {
161 user: User,
162 token: UserAuthenticationToken,
163 },
164 Instance {
165 instance: Instance,
166 signature: InstanceSignature,
167 },
168 }
169
/// Raw bytes of an incoming message; the `FromMessage` extractors in this
/// module parse them as JSON.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used wherever a byte slice is expected.
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
179
180 pub struct AuthenticatedUser(pub User);
181
182 #[derive(Debug, thiserror::Error)]
183 pub enum UserAuthenticationError {
184 #[error("user authentication missing")]
185 Missing,
186 // #[error("{0}")]
187 // InstanceAuthentication(#[from] Error),
188 #[error("user token was invalid")]
189 InvalidToken,
190 #[error("an error has occured")]
191 Other(#[from] Error),
192 }
193
194 pub struct AuthenticatedInstance(Instance);
195
196 impl AuthenticatedInstance {
197 pub fn inner(&self) -> &Instance {
198 &self.0
199 }
200 }
201
202 #[async_trait::async_trait]
203 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
204 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
205 }
206
207 #[async_trait::async_trait]
208 impl FromMessage<ConnectionState> for AuthenticatedUser {
209 async fn from_message(
210 network_message: &NetworkMessage,
211 state: &ConnectionState,
212 ) -> Result<Self, Error> {
213 let message: Authenticated<HashMap<String, Value>> =
214 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
215
216 let (auth_user, auth_token) = message
217 .source
218 .iter()
219 .filter_map(|auth| {
220 if let AuthenticationSource::User { user, token } = auth {
221 Some((user, token))
222 } else {
223 None
224 }
225 })
226 .next()
227 .ok_or_else(|| UserAuthenticationError::Missing)?;
228
229 let authenticated_instance =
230 AuthenticatedInstance::from_message(network_message, state).await?;
231
232 let public_key_raw = public_key(&auth_user.instance).await?;
233 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
234
235 let data: TokenData<UserTokenMetadata> = decode(
236 auth_token.as_ref(),
237 &verification_key,
238 &Validation::new(Algorithm::RS256),
239 )
240 .unwrap();
241
242 if data.claims.user != *auth_user
243 || data.claims.generated_for != *authenticated_instance.inner()
244 {
245 Err(Error::from(UserAuthenticationError::InvalidToken))
246 } else {
247 Ok(AuthenticatedUser(data.claims.user))
248 }
249 }
250 }
251
252 #[async_trait::async_trait]
253 impl FromMessage<ConnectionState> for AuthenticatedInstance {
254 async fn from_message(
255 network_message: &NetworkMessage,
256 state: &ConnectionState,
257 ) -> Result<Self, Error> {
258 let message: Authenticated<Value> =
259 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
260
261 let (instance, signature) = message
262 .source
263 .iter()
264 .filter_map(|auth| {
265 if let AuthenticationSource::Instance {
266 instance,
267 signature,
268 } = auth
269 {
270 Some((instance, signature))
271 } else {
272 None
273 }
274 })
275 .next()
276 // TODO: Instance authentication error
277 .ok_or_else(|| UserAuthenticationError::Missing)?;
278
279 let public_key = public_key(instance).await?;
280 let public_key = RsaPublicKey::from_pkcs1_pem(&public_key).unwrap();
281
282 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
283
284 let message_json = serde_json::to_vec(&message.message).unwrap();
285
286 verifying_key
287 .verify(
288 &message_json,
289 &Signature::try_from(signature.as_ref()).unwrap(),
290 )
291 .unwrap();
292
293 Ok(AuthenticatedInstance(instance.clone()))
294 }
295 }
296
297 #[async_trait::async_trait]
298 impl<S, T> FromMessage<S> for Option<T>
299 where
300 T: FromMessage<S>,
301 S: Send + Sync + 'static,
302 {
303 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
304 Ok(T::from_message(message, state).await.ok())
305 }
306 }
307
308 #[async_trait::async_trait]
309 pub trait MessageHandler<T, S, R> {
310 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
311 }
312 #[async_trait::async_trait]
313 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
314 where
315 T: FnOnce(T1) -> F + Clone + Send + 'static,
316 F: Future<Output = Result<R, E>> + Send,
317 T1: FromMessage<S> + Send,
318 S: Send + Sync,
319 E: std::error::Error + Send + Sync + 'static,
320 {
321 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
322 let value = T1::from_message(message, state).await?;
323 self(value).await.map_err(|e| Error::from(e))
324 }
325 }
326
327 #[async_trait::async_trait]
328 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
329 where
330 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
331 F: Future<Output = Result<R, E>> + Send,
332 T1: FromMessage<S> + Send,
333 T2: FromMessage<S> + Send,
334 S: Send + Sync,
335 E: std::error::Error + Send + Sync + 'static,
336 {
337 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
338 let value = T1::from_message(message, state).await?;
339 let value_2 = T2::from_message(message, state).await?;
340 self(value, value_2).await.map_err(|e| Error::from(e))
341 }
342 }
343
344 #[async_trait::async_trait]
345 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
346 where
347 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
348 F: Future<Output = Result<R, E>> + Send,
349 T1: FromMessage<S> + Send,
350 T2: FromMessage<S> + Send,
351 T3: FromMessage<S> + Send,
352 S: Send + Sync,
353 E: std::error::Error + Send + Sync + 'static,
354 {
355 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
356 let value = T1::from_message(message, state).await?;
357 let value_2 = T2::from_message(message, state).await?;
358 let value_3 = T3::from_message(message, state).await?;
359
360 self(value, value_2, value_3)
361 .await
362 .map_err(|e| Error::from(e))
363 }
364 }
365
366 pub struct State<T>(pub T);
367
368 #[async_trait::async_trait]
369 impl<T> FromMessage<T> for State<T>
370 where
371 T: Clone + Send + Sync,
372 {
373 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
374 Ok(Self(state.clone()))
375 }
376 }
377
378 // Temp
379 #[async_trait::async_trait]
380 impl<T, S> FromMessage<S> for Message<T>
381 where
382 T: DeserializeOwned + Send + Sync + Serialize,
383 S: Clone + Send + Sync,
384 {
385 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
386 Ok(Message(serde_json::from_slice(&message)?))
387 }
388 }
389
390 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
391
392 async fn public_key(instance: &Instance) -> Result<String, Error> {
393 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
394 .await?
395 .text()
396 .await?;
397
398 Ok(key)
399 }
400