JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Check if requesting user is allowed to see the repository

Type: Fix

emilia - 2 years ago

parent: tbd commit: 1b40c1d

src/model/authenticated.rs - 10698 bytes
Raw
1 use std::{any::type_name, collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use serde::{de::DeserializeOwned, Deserialize, Serialize};
7 use serde_json::Value;
8
9 use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
10
11 use super::{instance::Instance, user::User};
12
/// A message of type `T` bundled with the authentication sources attached to it.
///
/// `message_type` is filled with `std::any::type_name::<T>()` by the constructors in the
/// `impl Authenticated<T>` block, and `message` is flattened into the same serialized
/// object as `source` and `message_type`.
#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Authenticated<T: Serialize> {
    // TODO: Can not flatten vec's/enums, might want to just type tag the enum instead. not recommended for anything but languages like JSON anyways
    // #[serde(flatten)]
    source: Vec<AuthenticationSource>,
    message_type: String,
    #[serde(flatten)]
    message: T,
}
22
/// Something that can produce a single [`AuthenticationSource`] for a serialized payload.
///
/// `Authenticated` passes the JSON serialization of the wrapped message as `payload`.
/// Implementations consume `self`.
pub trait AuthenticationSourceProvider: Sized {
    fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
}
26
/// One or more providers that together produce a list of [`AuthenticationSource`]s.
///
/// In this file it is implemented for every single [`AuthenticationSourceProvider`] and
/// for pairs `(A, B)` of providers.
pub trait AuthenticationSourceProviders: Sized {
    fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
}
30
31 impl<A> AuthenticationSourceProviders for A
32 where
33 A: AuthenticationSourceProvider,
34 {
35 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
36 vec![self.authenticate(payload)]
37 }
38 }
39
40 impl<A, B> AuthenticationSourceProviders for (A, B)
41 where
42 A: AuthenticationSourceProvider,
43 B: AuthenticationSourceProvider,
44 {
45 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
46 let (first, second) = self;
47
48 vec![first.authenticate(payload), second.authenticate(payload)]
49 }
50 }
51
52 impl<T: Serialize> Authenticated<T> {
53 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
54 let message_payload = serde_json::to_vec(&message).unwrap();
55
56 let authentication = auth_sources.authenticate_all(&message_payload);
57
58 Self {
59 source: authentication,
60 message_type: type_name::<T>().to_string(),
61 message,
62 }
63 }
64
65 pub fn new_empty(message: T) -> Self {
66 Self {
67 source: vec![],
68 message_type: type_name::<T>().to_string(),
69 message,
70 }
71 }
72
73 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
74 let message_payload = serde_json::to_vec(&self.message).unwrap();
75
76 self.source
77 .push(authentication.authenticate(&message_payload));
78 }
79 }
80
// Empty placeholder module; nothing is defined in it yet.
mod verified {}
82
83 #[derive(Clone, Debug)]
84 pub struct UserAuthenticator {
85 pub user: User,
86 pub token: UserAuthenticationToken,
87 }
88
89 impl AuthenticationSourceProvider for UserAuthenticator {
90 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
91 AuthenticationSource::User {
92 user: self.user,
93 token: self.token,
94 }
95 }
96 }
97
/// Authenticates a message as coming from `instance`.
///
/// NOTE(review): `private_key` is borrowed as a string and, in the `authenticate` impl
/// below, is copied verbatim into the produced signature.
#[derive(Clone)]
pub struct InstanceAuthenticator<'a> {
    pub instance: Instance,
    pub private_key: &'a str,
}

impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
    fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
        AuthenticationSource::Instance {
            instance: self.instance,
            // TODO: Actually parse signature from private key
            // SECURITY(review): this stores the raw private key bytes as the
            // `InstanceSignature`. Since `Authenticated` serializes its `source` list, the
            // key itself is serialized with every message using this authenticator, and
            // `_payload` is never signed. Replace with a real signature over the payload.
            signature: InstanceSignature(self.private_key.as_bytes().to_vec()),
        }
    }
}
113
114 #[repr(transparent)]
115 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
116 pub struct UserAuthenticationToken(String);
117
118 impl From<String> for UserAuthenticationToken {
119 fn from(value: String) -> Self {
120 Self(value)
121 }
122 }
123
124 impl ToString for UserAuthenticationToken {
125 fn to_string(&self) -> String {
126 self.0.clone()
127 }
128 }
129
130 impl AsRef<str> for UserAuthenticationToken {
131 fn as_ref(&self) -> &str {
132 &self.0
133 }
134 }
135
/// Bytes presented as an instance's signature.
///
/// NOTE(review): `InstanceAuthenticator` currently fills this with the private key bytes
/// instead of a real signature (see the TODO in its `authenticate` impl).
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct InstanceSignature(Vec<u8>);

/// A single piece of authentication attached to an `Authenticated` message.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthenticationSource {
    /// A user identity plus the token meant to prove it.
    User {
        user: User,
        token: UserAuthenticationToken,
    },
    /// An instance identity plus its signature.
    Instance {
        instance: Instance,
        signature: InstanceSignature,
    },
}
150
/// Raw bytes of a message as received from the connection.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used anywhere a byte slice is expected.
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
160
/// A user produced by the `FromMessage<ConnectionState>` impl after checking their token.
///
/// NOTE(review): the field is public, so values can also be constructed directly without
/// any verification.
pub struct AuthenticatedUser(pub User);
162
163 #[derive(Debug, thiserror::Error)]
164 pub enum UserAuthenticationError {
165 #[error("user authentication missing")]
166 Missing,
167 // #[error("{0}")]
168 // InstanceAuthentication(#[from] Error),
169 #[error("user token was invalid")]
170 InvalidToken,
171 #[error("an error has occured")]
172 Other(#[from] Error),
173 }
174
175 pub struct AuthenticatedInstance(Instance);
176
177 impl AuthenticatedInstance {
178 pub fn inner(&self) -> &Instance {
179 &self.0
180 }
181 }
182
/// Extracts a value of `Self` from a raw [`NetworkMessage`] and some state `S`.
///
/// The `MessageHandler` impls use this to build each handler argument.
#[async_trait::async_trait]
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
    async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
}
187
188 #[async_trait::async_trait]
189 impl FromMessage<ConnectionState> for AuthenticatedUser {
190 async fn from_message(
191 network_message: &NetworkMessage,
192 state: &ConnectionState,
193 ) -> Result<Self, Error> {
194 let message: Authenticated<HashMap<String, Value>> =
195 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
196
197 let (auth_user, auth_token) = message
198 .source
199 .iter()
200 .filter_map(|auth| {
201 if let AuthenticationSource::User { user, token } = auth {
202 Some((user, token))
203 } else {
204 None
205 }
206 })
207 .next()
208 .ok_or_else(|| UserAuthenticationError::Missing)?;
209
210 let authenticated_instance =
211 AuthenticatedInstance::from_message(network_message, state).await?;
212
213 let public_key_raw = public_key(&auth_user.instance).await?;
214 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
215
216 let data: TokenData<UserTokenMetadata> = decode(
217 auth_token.as_ref(),
218 &verification_key,
219 &Validation::new(Algorithm::RS256),
220 )
221 .unwrap();
222
223 if data.claims.user != *auth_user
224 || data.claims.generated_for != *authenticated_instance.inner()
225 {
226 Err(Error::from(UserAuthenticationError::InvalidToken))
227 } else {
228 Ok(AuthenticatedUser(data.claims.user))
229 }
230 }
231 }
232
233 #[async_trait::async_trait]
234 impl FromMessage<ConnectionState> for AuthenticatedInstance {
235 async fn from_message(
236 network_message: &NetworkMessage,
237 _state: &ConnectionState,
238 ) -> Result<Self, Error> {
239 let message: Authenticated<HashMap<String, Value>> =
240 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
241
242 let (instance, signature) = message
243 .source
244 .iter()
245 .filter_map(|auth| {
246 if let AuthenticationSource::Instance {
247 instance,
248 signature,
249 } = auth
250 {
251 Some((instance, signature))
252 } else {
253 None
254 }
255 })
256 .next()
257 // TODO: Instance authentication error
258 .ok_or_else(|| UserAuthenticationError::Missing)?;
259
260 // TODO: Actually validate
261
262 Ok(Self(instance.clone()))
263 }
264 }
265
266 #[async_trait::async_trait]
267 impl<S, T> FromMessage<S> for Option<T>
268 where
269 T: FromMessage<S>,
270 S: Send + Sync + 'static,
271 {
272 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
273 Ok(T::from_message(message, state).await.ok())
274 }
275 }
276
277 #[async_trait::async_trait]
278 pub trait MessageHandler<T, S, R> {
279 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
280 }
281 #[async_trait::async_trait]
282 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
283 where
284 T: FnOnce(T1) -> F + Clone + Send + 'static,
285 F: Future<Output = Result<R, E>> + Send,
286 T1: FromMessage<S> + Send,
287 S: Send + Sync,
288 E: std::error::Error + Send + Sync + 'static,
289 {
290 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
291 let value = T1::from_message(message, state).await?;
292 self(value).await.map_err(|e| Error::from(e))
293 }
294 }
295
296 #[async_trait::async_trait]
297 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
298 where
299 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
300 F: Future<Output = Result<R, E>> + Send,
301 T1: FromMessage<S> + Send,
302 T2: FromMessage<S> + Send,
303 S: Send + Sync,
304 E: std::error::Error + Send + Sync + 'static,
305 {
306 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
307 let value = T1::from_message(message, state).await?;
308 let value_2 = T2::from_message(message, state).await?;
309 self(value, value_2).await.map_err(|e| Error::from(e))
310 }
311 }
312
313 #[async_trait::async_trait]
314 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
315 where
316 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
317 F: Future<Output = Result<R, E>> + Send,
318 T1: FromMessage<S> + Send,
319 T2: FromMessage<S> + Send,
320 T3: FromMessage<S> + Send,
321 S: Send + Sync,
322 E: std::error::Error + Send + Sync + 'static,
323 {
324 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
325 let value = T1::from_message(message, state).await?;
326 let value_2 = T2::from_message(message, state).await?;
327 let value_3 = T3::from_message(message, state).await?;
328
329 self(value, value_2, value_3)
330 .await
331 .map_err(|e| Error::from(e))
332 }
333 }
334
335 pub struct State<T>(pub T);
336
337 #[async_trait::async_trait]
338 impl<T> FromMessage<T> for State<T>
339 where
340 T: Clone + Send + Sync,
341 {
342 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
343 Ok(Self(state.clone()))
344 }
345 }
346
347 // Temp
348 #[async_trait::async_trait]
349 impl<T, S> FromMessage<S> for Message<T>
350 where
351 T: DeserializeOwned + Send + Sync + Serialize,
352 S: Clone + Send + Sync,
353 {
354 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
355 Ok(Message(serde_json::from_slice(&message)?))
356 }
357 }
358
359 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
360
361 async fn public_key(instance: &Instance) -> Result<String, Error> {
362 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
363 .await?
364 .text()
365 .await?;
366
367 Ok(key)
368 }
369