JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Push to keep up to date with api

Amber - 2 years ago

parent: tbd commit: 1034e0a

src/model/authenticated.rs - 9889 bytes
Raw
1 use std::{any::type_name, ops::Deref, pin::Pin, str::FromStr};
2
3 use anyhow::Error;
4 use futures_util::{future::BoxFuture, Future, FutureExt};
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use rsa::{pkcs1::DecodeRsaPublicKey, RsaPublicKey};
7 use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9 use crate::{
10 authentication::UserTokenMetadata, connection::wrapper::ConnectionState, messages::MessageKind,
11 };
12
13 use super::{instance::Instance, user::User};
14
15 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
16 pub struct Authenticated<T: Serialize> {
17 #[serde(flatten)]
18 source: Vec<AuthenticationSource>,
19 message_type: String,
20 #[serde(flatten)]
21 message: T,
22 }
23
24 pub trait AuthenticationSourceProvider: Sized {
25 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
26 }
27
28 pub trait AuthenticationSourceProviders: Sized {
29 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
30 }
31
32 impl<A> AuthenticationSourceProviders for A
33 where
34 A: AuthenticationSourceProvider,
35 {
36 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
37 vec![self.authenticate(payload)]
38 }
39 }
40
41 impl<A, B> AuthenticationSourceProviders for (A, B)
42 where
43 A: AuthenticationSourceProvider,
44 B: AuthenticationSourceProvider,
45 {
46 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
47 let (first, second) = self;
48
49 vec![first.authenticate(payload), second.authenticate(payload)]
50 }
51 }
52
53 impl<T: Serialize> Authenticated<T> {
54 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
55 let message_payload = serde_json::to_vec(&message).unwrap();
56
57 let authentication = auth_sources.authenticate_all(&message_payload);
58
59 Self {
60 source: authentication,
61 message_type: type_name::<T>().to_string(),
62 message,
63 }
64 }
65
66 pub fn new_empty(message: T) -> Self {
67 Self {
68 source: vec![],
69 message_type: type_name::<T>().to_string(),
70 message,
71 }
72 }
73
74 pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
75 let message_payload = serde_json::to_vec(&self.message).unwrap();
76
77 self.source
78 .push(authentication.authenticate(&message_payload));
79 }
80 }
81
/// Placeholder module; it currently declares no items.
mod verified {
    // Intentionally empty.
}
83
84 #[derive(Clone, Debug)]
85 pub struct UserAuthenticator {
86 pub user: User,
87 pub token: UserAuthenticationToken,
88 }
89
90 impl AuthenticationSourceProvider for UserAuthenticator {
91 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
92 AuthenticationSource::User {
93 user: self.user,
94 token: self.token,
95 }
96 }
97 }
98
99 #[derive(Clone)]
100 pub struct InstanceAuthenticator<'a> {
101 pub instance: Instance,
102 pub private_key: &'a str,
103 }
104
105 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
106 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
107 todo!()
108 }
109 }
110
111 #[repr(transparent)]
112 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
113 pub struct UserAuthenticationToken(String);
114
115 impl From<String> for UserAuthenticationToken {
116 fn from(value: String) -> Self {
117 Self(value)
118 }
119 }
120
121 impl ToString for UserAuthenticationToken {
122 fn to_string(&self) -> String {
123 self.0.clone()
124 }
125 }
126
127 impl AsRef<str> for UserAuthenticationToken {
128 fn as_ref(&self) -> &str {
129 &self.0
130 }
131 }
132
133 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
134 pub struct InstanceSignature(Vec<u8>);
135
136 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
137 pub enum AuthenticationSource {
138 User {
139 user: User,
140 token: UserAuthenticationToken,
141 },
142 Instance {
143 instance: Instance,
144 signature: InstanceSignature,
145 },
146 }
147
/// Raw, undecoded bytes of a network message.
///
/// The extractors in this module parse these bytes as JSON.
pub struct NetworkMessage(pub Vec<u8>);
149
150 impl Deref for NetworkMessage {
151 type Target = [u8];
152
153 fn deref(&self) -> &Self::Target {
154 &self.0
155 }
156 }
157
158 pub struct AuthenticatedUser(pub User);
159
160 #[derive(Debug, thiserror::Error)]
161 pub enum UserAuthenticationError {
162 #[error("user authentication missing")]
163 Missing,
164 // #[error("{0}")]
165 // InstanceAuthentication(#[from] Error),
166 #[error("user token was invalid")]
167 InvalidToken,
168 #[error("an error has occured")]
169 Other(#[from] Error),
170 }
171
172 pub struct AuthenticatedInstance(Instance);
173
174 impl AuthenticatedInstance {
175 pub fn inner(&self) -> &Instance {
176 &self.0
177 }
178 }
179
180 #[async_trait::async_trait]
181 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
182 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
183 }
184
185 #[async_trait::async_trait]
186 impl FromMessage<ConnectionState> for AuthenticatedUser {
187 async fn from_message(
188 network_message: &NetworkMessage,
189 state: &ConnectionState,
190 ) -> Result<Self, Error> {
191 let message: Authenticated<MessageKind> =
192 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
193
194 let (auth_user, auth_token) = message
195 .source
196 .iter()
197 .filter_map(|auth| {
198 if let AuthenticationSource::User { user, token } = auth {
199 Some((user, token))
200 } else {
201 None
202 }
203 })
204 .next()
205 .ok_or_else(|| UserAuthenticationError::Missing)?;
206
207 let authenticated_instance =
208 AuthenticatedInstance::from_message(network_message, state).await?;
209
210 let public_key_raw = public_key(&auth_user.instance).await?;
211 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
212
213 let data: TokenData<UserTokenMetadata> = decode(
214 auth_token.as_ref(),
215 &verification_key,
216 &Validation::new(Algorithm::RS256),
217 )
218 .unwrap();
219
220 if data.claims.user != *auth_user
221 || data.claims.generated_for != *authenticated_instance.inner()
222 {
223 Err(Error::from(UserAuthenticationError::InvalidToken))
224 } else {
225 Ok(AuthenticatedUser(data.claims.user))
226 }
227 }
228 }
229
230 #[async_trait::async_trait]
231 impl FromMessage<ConnectionState> for AuthenticatedInstance {
232 async fn from_message(
233 message: &NetworkMessage,
234 state: &ConnectionState,
235 ) -> Result<Self, Error> {
236 todo!()
237 }
238 }
239
240 #[async_trait::async_trait]
241 impl FromMessage<ConnectionState> for MessageKind {
242 async fn from_message(
243 message: &NetworkMessage,
244 state: &ConnectionState,
245 ) -> Result<Self, Error> {
246 todo!()
247 }
248 }
249
250 #[async_trait::async_trait]
251 impl<S, T> FromMessage<S> for Option<T>
252 where
253 T: FromMessage<S>,
254 S: Send + Sync + 'static,
255 {
256 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
257 Ok(T::from_message(message, state).await.ok())
258 }
259 }
260
261 #[async_trait::async_trait]
262 pub trait MessageHandler<T, S, R> {
263 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
264 }
265 #[async_trait::async_trait]
266 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
267 where
268 T: FnOnce(T1) -> F + Clone + Send + 'static,
269 F: Future<Output = Result<R, E>> + Send,
270 T1: FromMessage<S> + Send,
271 S: Send + Sync,
272 E: std::error::Error + Send + Sync + 'static,
273 {
274 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
275 let value = T1::from_message(message, state).await?;
276 self(value).await.map_err(|e| Error::from(e))
277 }
278 }
279
280 #[async_trait::async_trait]
281 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
282 where
283 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
284 F: Future<Output = Result<R, E>> + Send,
285 T1: FromMessage<S> + Send,
286 T2: FromMessage<S> + Send,
287 S: Send + Sync,
288 E: std::error::Error + Send + Sync + 'static,
289 {
290 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
291 let value = T1::from_message(message, state).await?;
292 let value_2 = T2::from_message(message, state).await?;
293 self(value, value_2).await.map_err(|e| Error::from(e))
294 }
295 }
296
297 #[async_trait::async_trait]
298 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
299 where
300 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
301 F: Future<Output = Result<R, E>> + Send,
302 T1: FromMessage<S> + Send,
303 T2: FromMessage<S> + Send,
304 T3: FromMessage<S> + Send,
305 S: Send + Sync,
306 E: std::error::Error + Send + Sync + 'static,
307 {
308 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
309 let value = T1::from_message(message, state).await?;
310 let value_2 = T2::from_message(message, state).await?;
311 let value_3 = T3::from_message(message, state).await?;
312
313 self(value, value_2, value_3)
314 .await
315 .map_err(|e| Error::from(e))
316 }
317 }
318
/// Extractor that gives a handler its own clone of the shared state.
pub struct State<T>(pub T);
320
321 #[async_trait::async_trait]
322 impl<T> FromMessage<T> for State<T>
323 where
324 T: Clone + Send + Sync,
325 {
326 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
327 Ok(Self(state.clone()))
328 }
329 }
330
331 // Temp
332 #[async_trait::async_trait]
333 impl<T, S> FromMessage<S> for Message<T>
334 where
335 T: DeserializeOwned + Send + Sync + Serialize,
336 S: Clone + Send + Sync,
337 {
338 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
339 Ok(Message(serde_json::from_slice(&message)?))
340 }
341 }
342
343 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
344
345 async fn public_key(instance: &Instance) -> Result<String, Error> {
346 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
347 .await?
348 .text()
349 .await?;
350
351 Ok(key)
352 }
353