JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Major connection refactor base

Type: Refactor

Amber - 2 years ago

parent: tbd commit: 8dcc111

src/model/authenticated.rs - 9403 bytes
Raw
1 use std::{any::type_name, ops::Deref, pin::Pin, str::FromStr};
2
3 use anyhow::Error;
4 use futures_util::{future::BoxFuture, Future, FutureExt};
5 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
6 use rsa::{pkcs1::DecodeRsaPublicKey, RsaPublicKey};
7 use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9 use crate::{
10 authentication::UserTokenMetadata, connection::wrapper::ConnectionState, messages::MessageKind,
11 };
12
13 use super::{instance::Instance, user::User};
14
15 #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
16 pub struct Authenticated<T: Serialize> {
17 #[serde(flatten)]
18 source: Vec<AuthenticationSource>,
19 message_type: String,
20 #[serde(flatten)]
21 message: T,
22 }
23
24 pub trait AuthenticationSourceProvider: Sized {
25 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
26 }
27
28 pub trait AuthenticationSourceProviders: Sized {
29 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
30 }
31
32 impl<A> AuthenticationSourceProviders for A
33 where
34 A: AuthenticationSourceProvider,
35 {
36 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
37 vec![self.authenticate(payload)]
38 }
39 }
40
41 impl<A, B> AuthenticationSourceProviders for (A, B)
42 where
43 A: AuthenticationSourceProvider,
44 B: AuthenticationSourceProvider,
45 {
46 fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
47 let (first, second) = self;
48
49 vec![first.authenticate(payload), second.authenticate(payload)]
50 }
51 }
52
53 impl<T: Serialize> Authenticated<T> {
54 pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
55 let message_payload = serde_json::to_vec(&message).unwrap();
56
57 let authentication = auth_sources.authenticate_all(&message_payload);
58
59 Self {
60 source: authentication,
61 message_type: type_name::<T>().to_string(),
62 message,
63 }
64 }
65 }
66
mod verified {
    // Intentionally empty placeholder; nothing in this file references it yet.
}
68
69 pub struct UserAuthenticator {
70 pub user: User,
71 pub token: UserAuthenticationToken,
72 }
73
74 impl AuthenticationSourceProvider for UserAuthenticator {
75 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
76 AuthenticationSource::User {
77 user: self.user,
78 token: self.token,
79 }
80 }
81 }
82
83 pub struct InstanceAuthenticator<'a> {
84 pub instance: Instance,
85 pub private_key: &'a str,
86 }
87
88 impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
89 fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
90 todo!()
91 }
92 }
93
94 #[repr(transparent)]
95 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
96 pub struct UserAuthenticationToken(String);
97
98 impl From<String> for UserAuthenticationToken {
99 fn from(value: String) -> Self {
100 Self(value)
101 }
102 }
103
104 impl ToString for UserAuthenticationToken {
105 fn to_string(&self) -> String {
106 self.0.clone()
107 }
108 }
109
110 impl AsRef<str> for UserAuthenticationToken {
111 fn as_ref(&self) -> &str {
112 &self.0
113 }
114 }
115
116 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
117 pub struct InstanceSignature(Vec<u8>);
118
119 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
120 pub enum AuthenticationSource {
121 User {
122 user: User,
123 token: UserAuthenticationToken,
124 },
125 Instance {
126 instance: Instance,
127 signature: InstanceSignature,
128 },
129 }
130
/// Raw, still-serialized bytes of a message.
///
/// `FromMessage` implementations parse these bytes; the ones in this file treat them as JSON.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used anywhere a byte slice is expected.
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
140
141 pub struct AuthenticatedUser(pub User);
142
143 #[derive(Debug, thiserror::Error)]
144 pub enum UserAuthenticationError {
145 #[error("user authentication missing")]
146 Missing,
147 // #[error("{0}")]
148 // InstanceAuthentication(#[from] Error),
149 #[error("user token was invalid")]
150 InvalidToken,
151 #[error("an error has occured")]
152 Other(#[from] Error),
153 }
154
155 pub struct AuthenticatedInstance(Instance);
156
157 impl AuthenticatedInstance {
158 pub fn inner(&self) -> &Instance {
159 &self.0
160 }
161 }
162
163 #[async_trait::async_trait]
164 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
165 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
166 }
167
168 #[async_trait::async_trait]
169 impl FromMessage<ConnectionState> for AuthenticatedUser {
170 async fn from_message(
171 network_message: &NetworkMessage,
172 state: &ConnectionState,
173 ) -> Result<Self, Error> {
174 let message: Authenticated<MessageKind> =
175 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
176
177 let (auth_user, auth_token) = message
178 .source
179 .iter()
180 .filter_map(|auth| {
181 if let AuthenticationSource::User { user, token } = auth {
182 Some((user, token))
183 } else {
184 None
185 }
186 })
187 .next()
188 .ok_or_else(|| UserAuthenticationError::Missing)?;
189
190 let authenticated_instance =
191 AuthenticatedInstance::from_message(network_message, state).await?;
192
193 let public_key_raw = public_key(&auth_user.instance).await?;
194 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
195
196 let data: TokenData<UserTokenMetadata> = decode(
197 auth_token.as_ref(),
198 &verification_key,
199 &Validation::new(Algorithm::RS256),
200 )
201 .unwrap();
202
203 if data.claims.user != *auth_user
204 || data.claims.generated_for != *authenticated_instance.inner()
205 {
206 Err(Error::from(UserAuthenticationError::InvalidToken))
207 } else {
208 Ok(AuthenticatedUser(data.claims.user))
209 }
210 }
211 }
212
213 #[async_trait::async_trait]
214 impl FromMessage<ConnectionState> for AuthenticatedInstance {
215 async fn from_message(
216 message: &NetworkMessage,
217 state: &ConnectionState,
218 ) -> Result<Self, Error> {
219 todo!()
220 }
221 }
222
223 #[async_trait::async_trait]
224 impl FromMessage<ConnectionState> for MessageKind {
225 async fn from_message(
226 message: &NetworkMessage,
227 state: &ConnectionState,
228 ) -> Result<Self, Error> {
229 todo!()
230 }
231 }
232
233 #[async_trait::async_trait]
234 impl<S, T> FromMessage<S> for Option<T>
235 where
236 T: FromMessage<S>,
237 S: Send + Sync + 'static,
238 {
239 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
240 Ok(T::from_message(message, state).await.ok())
241 }
242 }
243
244 #[async_trait::async_trait]
245 pub trait MessageHandler<T, S, R> {
246 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
247 }
248 #[async_trait::async_trait]
249 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
250 where
251 T: FnOnce(T1) -> F + Clone + Send + 'static,
252 F: Future<Output = Result<R, E>> + Send,
253 T1: FromMessage<S> + Send,
254 S: Send + Sync,
255 E: std::error::Error + Send + Sync + 'static,
256 {
257 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
258 let value = T1::from_message(message, state).await?;
259 self(value).await.map_err(|e| Error::from(e))
260 }
261 }
262
263 #[async_trait::async_trait]
264 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
265 where
266 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
267 F: Future<Output = Result<R, E>> + Send,
268 T1: FromMessage<S> + Send,
269 T2: FromMessage<S> + Send,
270 S: Send + Sync,
271 E: std::error::Error + Send + Sync + 'static,
272 {
273 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
274 let value = T1::from_message(message, state).await?;
275 let value_2 = T2::from_message(message, state).await?;
276 self(value, value_2).await.map_err(|e| Error::from(e))
277 }
278 }
279
280 #[async_trait::async_trait]
281 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
282 where
283 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
284 F: Future<Output = Result<R, E>> + Send,
285 T1: FromMessage<S> + Send,
286 T2: FromMessage<S> + Send,
287 T3: FromMessage<S> + Send,
288 S: Send + Sync,
289 E: std::error::Error + Send + Sync + 'static,
290 {
291 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
292 let value = T1::from_message(message, state).await?;
293 let value_2 = T2::from_message(message, state).await?;
294 let value_3 = T3::from_message(message, state).await?;
295
296 self(value, value_2, value_3)
297 .await
298 .map_err(|e| Error::from(e))
299 }
300 }
301
302 pub struct State<T>(pub T);
303
304 #[async_trait::async_trait]
305 impl<T> FromMessage<T> for State<T>
306 where
307 T: Clone + Send + Sync,
308 {
309 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
310 Ok(Self(state.clone()))
311 }
312 }
313
314 // Temp
315 #[async_trait::async_trait]
316 impl<T, S> FromMessage<S> for Message<T>
317 where
318 T: DeserializeOwned + Send + Sync + Serialize,
319 S: Clone + Send + Sync,
320 {
321 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
322 Ok(Message(serde_json::from_slice(&message)?))
323 }
324 }
325
326 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
327
328 async fn public_key(instance: &Instance) -> Result<String, Error> {
329 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
330 .await?
331 .text()
332 .await?;
333
334 Ok(key)
335 }
336