JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Major post-refactor cleanup

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨f90d7fb

⁨src/model/authenticated.rs⁩ - ⁨9581⁩ bytes
Raw
1 use std::{any::type_name, collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use serde::{de::DeserializeOwned, Deserialize, Serialize};
7 use serde_json::Value;
8
9 use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
10
11 use super::{instance::Instance, user::User};
12
13 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
14 pub struct Authenticated<T: Serialize> {
15 #[serde(flatten)]
16 source: Vec<AuthenticationSource>,
17 message_type: String,
18 #[serde(flatten)]
19 message: T,
20 }
21
22 pub trait AuthenticationSourceProvider: Sized {
23 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
24 }
25
26 pub trait AuthenticationSourceProviders: Sized {
27 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
28 }
29
30 impl<A> AuthenticationSourceProviders for A
31 where
32 A: AuthenticationSourceProvider,
33 {
34 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
35 vec![self.authenticate(payload)]
36 }
37 }
38
39 impl<A, B> AuthenticationSourceProviders for (A, B)
40 where
41 A: AuthenticationSourceProvider,
42 B: AuthenticationSourceProvider,
43 {
44 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
45 let (first, second) = self;
46
47 vec![first.authenticate(payload), second.authenticate(payload)]
48 }
49 }
50
51 impl<T: Serialize> Authenticated<T> {
52 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
53 let message_payload = serde_json::to_vec(&message).unwrap();
54
55 let authentication = auth_sources.authenticate_all(&message_payload);
56
57 Self {
58 source: authentication,
59 message_type: type_name::<T>().to_string(),
60 message,
61 }
62 }
63
64 pub fn new_empty(message: T) -> Self {
65 Self {
66 source: vec![],
67 message_type: type_name::<T>().to_string(),
68 message,
69 }
70 }
71
72 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
73 let message_payload = serde_json::to_vec(&self.message).unwrap();
74
75 self.source
76 .push(authentication.authenticate(&message_payload));
77 }
78 }
79
// NOTE(review): empty placeholder module; nothing in the visible file uses it.
mod verified {}
81
82 #[derive(Clone, Debug)]
83 pub struct UserAuthenticator {
84 pub user: User,
85 pub token: UserAuthenticationToken,
86 }
87
88 impl AuthenticationSourceProvider for UserAuthenticator {
89 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
90 AuthenticationSource::User {
91 user: self.user,
92 token: self.token,
93 }
94 }
95 }
96
97 #[derive(Clone)]
98 pub struct InstanceAuthenticator<'a> {
99 pub instance: Instance,
100 pub private_key: &'a str,
101 }
102
103 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
104 fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
105 todo!()
106 }
107 }
108
109 #[repr(transparent)]
110 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
111 pub struct UserAuthenticationToken(String);
112
113 impl From<String> for UserAuthenticationToken {
114 fn from(value: String) -> Self {
115 Self(value)
116 }
117 }
118
119 impl ToString for UserAuthenticationToken {
120 fn to_string(&self) -> String {
121 self.0.clone()
122 }
123 }
124
125 impl AsRef<str> for UserAuthenticationToken {
126 fn as_ref(&self) -> &str {
127 &self.0
128 }
129 }
130
131 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
132 pub struct InstanceSignature(Vec<u8>);
133
134 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
135 pub enum AuthenticationSource {
136 User {
137 user: User,
138 token: UserAuthenticationToken,
139 },
140 Instance {
141 instance: Instance,
142 signature: InstanceSignature,
143 },
144 }
145
/// The raw bytes of a message as received from the network.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used wherever a byte slice is expected.
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
155
156 pub struct AuthenticatedUser(pub User);
157
158 #[derive(Debug, thiserror::Error)]
159 pub enum UserAuthenticationError {
160 #[error("user authentication missing")]
161 Missing,
162 // #[error("{0}")]
163 // InstanceAuthentication(#[from] Error),
164 #[error("user token was invalid")]
165 InvalidToken,
166 #[error("an error has occured")]
167 Other(#[from] Error),
168 }
169
170 pub struct AuthenticatedInstance(Instance);
171
172 impl AuthenticatedInstance {
173 pub fn inner(&self) -> &Instance {
174 &self.0
175 }
176 }
177
178 #[async_trait::async_trait]
179 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
180 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
181 }
182
183 #[async_trait::async_trait]
184 impl FromMessage<ConnectionState> for AuthenticatedUser {
185 async fn from_message(
186 network_message: &NetworkMessage,
187 state: &ConnectionState,
188 ) -> Result<Self, Error> {
189 let message: Authenticated<HashMap<String, Value>> =
190 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
191
192 let (auth_user, auth_token) = message
193 .source
194 .iter()
195 .filter_map(|auth| {
196 if let AuthenticationSource::User { user, token } = auth {
197 Some((user, token))
198 } else {
199 None
200 }
201 })
202 .next()
203 .ok_or_else(|| UserAuthenticationError::Missing)?;
204
205 let authenticated_instance =
206 AuthenticatedInstance::from_message(network_message, state).await?;
207
208 let public_key_raw = public_key(&auth_user.instance).await?;
209 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
210
211 let data: TokenData<UserTokenMetadata> = decode(
212 auth_token.as_ref(),
213 &verification_key,
214 &Validation::new(Algorithm::RS256),
215 )
216 .unwrap();
217
218 if data.claims.user != *auth_user
219 || data.claims.generated_for != *authenticated_instance.inner()
220 {
221 Err(Error::from(UserAuthenticationError::InvalidToken))
222 } else {
223 Ok(AuthenticatedUser(data.claims.user))
224 }
225 }
226 }
227
228 #[async_trait::async_trait]
229 impl FromMessage<ConnectionState> for AuthenticatedInstance {
230 async fn from_message(
231 _message: &NetworkMessage,
232 _state: &ConnectionState,
233 ) -> Result<Self, Error> {
234 todo!()
235 }
236 }
237
238 #[async_trait::async_trait]
239 impl<S, T> FromMessage<S> for Option<T>
240 where
241 T: FromMessage<S>,
242 S: Send + Sync + 'static,
243 {
244 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
245 Ok(T::from_message(message, state).await.ok())
246 }
247 }
248
249 #[async_trait::async_trait]
250 pub trait MessageHandler<T, S, R> {
251 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
252 }
253 #[async_trait::async_trait]
254 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
255 where
256 T: FnOnce(T1) -> F + Clone + Send + 'static,
257 F: Future<Output = Result<R, E>> + Send,
258 T1: FromMessage<S> + Send,
259 S: Send + Sync,
260 E: std::error::Error + Send + Sync + 'static,
261 {
262 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
263 let value = T1::from_message(message, state).await?;
264 self(value).await.map_err(|e| Error::from(e))
265 }
266 }
267
268 #[async_trait::async_trait]
269 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
270 where
271 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
272 F: Future<Output = Result<R, E>> + Send,
273 T1: FromMessage<S> + Send,
274 T2: FromMessage<S> + Send,
275 S: Send + Sync,
276 E: std::error::Error + Send + Sync + 'static,
277 {
278 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
279 let value = T1::from_message(message, state).await?;
280 let value_2 = T2::from_message(message, state).await?;
281 self(value, value_2).await.map_err(|e| Error::from(e))
282 }
283 }
284
285 #[async_trait::async_trait]
286 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
287 where
288 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
289 F: Future<Output = Result<R, E>> + Send,
290 T1: FromMessage<S> + Send,
291 T2: FromMessage<S> + Send,
292 T3: FromMessage<S> + Send,
293 S: Send + Sync,
294 E: std::error::Error + Send + Sync + 'static,
295 {
296 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
297 let value = T1::from_message(message, state).await?;
298 let value_2 = T2::from_message(message, state).await?;
299 let value_3 = T3::from_message(message, state).await?;
300
301 self(value, value_2, value_3)
302 .await
303 .map_err(|e| Error::from(e))
304 }
305 }
306
307 pub struct State<T>(pub T);
308
309 #[async_trait::async_trait]
310 impl<T> FromMessage<T> for State<T>
311 where
312 T: Clone + Send + Sync,
313 {
314 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
315 Ok(Self(state.clone()))
316 }
317 }
318
319 // Temp
320 #[async_trait::async_trait]
321 impl<T, S> FromMessage<S> for Message<T>
322 where
323 T: DeserializeOwned + Send + Sync + Serialize,
324 S: Clone + Send + Sync,
325 {
326 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
327 Ok(Message(serde_json::from_slice(&message)?))
328 }
329 }
330
331 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
332
333 async fn public_key(instance: &Instance) -> Result<String, Error> {
334 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
335 .await?
336 .text()
337 .await?;
338
339 Ok(key)
340 }
341