JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fixes

Amber - 2 years ago

parent: tbd commit: 25c3410

giterated-daemon/src/message.rs - 8097 bytes
Raw
1 use std::{collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use giterated_models::model::{
6 authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata},
7 instance::Instance,
8 user::User,
9 };
10 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
11 use rsa::{
12 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
13 pss::{Signature, VerifyingKey},
14 sha2::Sha256,
15 signature::Verifier,
16 RsaPublicKey,
17 };
18 use serde::{de::DeserializeOwned, Serialize};
19 use serde_json::Value;
20
21 use crate::connection::wrapper::ConnectionState;
22
/// A raw message as received from the network, before any parsing.
///
/// Dereferences to its underlying byte slice, so it can be passed directly to
/// byte-oriented parsers such as `serde_json::from_slice`.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
32
33 pub struct AuthenticatedUser(pub User);
34
35 #[derive(Debug, thiserror::Error)]
36 pub enum UserAuthenticationError {
37 #[error("user authentication missing")]
38 Missing,
39 // #[error("{0}")]
40 // InstanceAuthentication(#[from] Error),
41 #[error("user token was invalid")]
42 InvalidToken,
43 #[error("an error has occured")]
44 Other(#[from] Error),
45 }
46
47 pub struct AuthenticatedInstance(Instance);
48
49 impl AuthenticatedInstance {
50 pub fn inner(&self) -> &Instance {
51 &self.0
52 }
53 }
54
55 #[async_trait::async_trait]
56 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
57 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
58 }
59
60 #[async_trait::async_trait]
61 impl FromMessage<ConnectionState> for AuthenticatedUser {
62 async fn from_message(
63 network_message: &NetworkMessage,
64 state: &ConnectionState,
65 ) -> Result<Self, Error> {
66 let message: Authenticated<HashMap<String, Value>> =
67 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
68
69 let (auth_user, auth_token) = message
70 .source
71 .iter()
72 .filter_map(|auth| {
73 if let AuthenticationSource::User { user, token } = auth {
74 Some((user, token))
75 } else {
76 None
77 }
78 })
79 .next()
80 .ok_or_else(|| UserAuthenticationError::Missing)?;
81
82 let authenticated_instance =
83 AuthenticatedInstance::from_message(network_message, state).await?;
84
85 let public_key_raw = public_key(&auth_user.instance).await?;
86 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
87
88 let data: TokenData<UserTokenMetadata> = decode(
89 auth_token.as_ref(),
90 &verification_key,
91 &Validation::new(Algorithm::RS256),
92 )
93 .unwrap();
94
95 if data.claims.user != *auth_user
96 || data.claims.generated_for != *authenticated_instance.inner()
97 {
98 Err(Error::from(UserAuthenticationError::InvalidToken))
99 } else {
100 Ok(AuthenticatedUser(data.claims.user))
101 }
102 }
103 }
104
105 #[async_trait::async_trait]
106 impl FromMessage<ConnectionState> for AuthenticatedInstance {
107 async fn from_message(
108 network_message: &NetworkMessage,
109 state: &ConnectionState,
110 ) -> Result<Self, Error> {
111 let message: Authenticated<Value> =
112 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
113
114 let (instance, signature) = message
115 .source
116 .iter()
117 .filter_map(|auth: &AuthenticationSource| {
118 if let AuthenticationSource::Instance {
119 instance,
120 signature,
121 } = auth
122 {
123 Some((instance, signature))
124 } else {
125 None
126 }
127 })
128 .next()
129 // TODO: Instance authentication error
130 .ok_or_else(|| UserAuthenticationError::Missing)?;
131
132 let public_key = {
133 let cached_keys = state.cached_keys.read().await;
134
135 if let Some(key) = cached_keys.get(&instance) {
136 key.clone()
137 } else {
138 drop(cached_keys);
139 let mut cached_keys = state.cached_keys.write().await;
140 let key = public_key(instance).await?;
141 let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
142 cached_keys.insert(instance.clone(), public_key.clone());
143 public_key
144 }
145 };
146
147 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
148
149 let message_json = serde_json::to_vec(&message.message).unwrap();
150
151 info!(
152 "Verification against: {}",
153 std::str::from_utf8(&message_json).unwrap()
154 );
155
156 verifying_key.verify(
157 &message_json,
158 &Signature::try_from(signature.as_ref()).unwrap(),
159 )?;
160
161 Ok(AuthenticatedInstance(instance.clone()))
162 }
163 }
164
165 #[async_trait::async_trait]
166 impl<S, T> FromMessage<S> for Option<T>
167 where
168 T: FromMessage<S>,
169 S: Send + Sync + 'static,
170 {
171 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
172 Ok(T::from_message(message, state).await.ok())
173 }
174 }
175
176 #[async_trait::async_trait]
177 pub trait MessageHandler<T, S, R> {
178 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
179 }
180 #[async_trait::async_trait]
181 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
182 where
183 T: FnOnce(T1) -> F + Clone + Send + 'static,
184 F: Future<Output = Result<R, E>> + Send,
185 T1: FromMessage<S> + Send,
186 S: Send + Sync,
187 E: std::error::Error + Send + Sync + 'static,
188 {
189 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
190 let value = T1::from_message(message, state).await?;
191 self(value).await.map_err(|e| Error::from(e))
192 }
193 }
194
195 #[async_trait::async_trait]
196 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
197 where
198 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
199 F: Future<Output = Result<R, E>> + Send,
200 T1: FromMessage<S> + Send,
201 T2: FromMessage<S> + Send,
202 S: Send + Sync,
203 E: std::error::Error + Send + Sync + 'static,
204 {
205 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
206 let value = T1::from_message(message, state).await?;
207 let value_2 = T2::from_message(message, state).await?;
208 self(value, value_2).await.map_err(|e| Error::from(e))
209 }
210 }
211
212 #[async_trait::async_trait]
213 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
214 where
215 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
216 F: Future<Output = Result<R, E>> + Send,
217 T1: FromMessage<S> + Send,
218 T2: FromMessage<S> + Send,
219 T3: FromMessage<S> + Send,
220 S: Send + Sync,
221 E: std::error::Error + Send + Sync + 'static,
222 {
223 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
224 let value = T1::from_message(message, state).await?;
225 let value_2 = T2::from_message(message, state).await?;
226 let value_3 = T3::from_message(message, state).await?;
227
228 self(value, value_2, value_3)
229 .await
230 .map_err(|e| Error::from(e))
231 }
232 }
233
234 pub struct State<T>(pub T);
235
236 #[async_trait::async_trait]
237 impl<T> FromMessage<T> for State<T>
238 where
239 T: Clone + Send + Sync,
240 {
241 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
242 Ok(Self(state.clone()))
243 }
244 }
245
246 // Temp
247 #[async_trait::async_trait]
248 impl<T, S> FromMessage<S> for Message<T>
249 where
250 T: DeserializeOwned + Send + Sync + Serialize,
251 S: Clone + Send + Sync,
252 {
253 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
254 Ok(Message(serde_json::from_slice(&message)?))
255 }
256 }
257
258 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
259
260 async fn public_key(instance: &Instance) -> Result<String, Error> {
261 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
262 .await?
263 .text()
264 .await?;
265
266 Ok(key)
267 }
268