JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Implement Debug on all messages

Type: Fix

Amber - 2 years ago

parent: tbd commit: 249c88e

giterated-daemon/src/message.rs - 8544 bytes
Raw
1 use std::{collections::HashMap, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use giterated_models::model::{
6 authenticated::{Authenticated, AuthenticatedPayload, AuthenticationSource, UserTokenMetadata},
7 instance::Instance,
8 user::User,
9 };
10 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
11 use rsa::{
12 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
13 pss::{Signature, VerifyingKey},
14 sha2::Sha256,
15 signature::Verifier,
16 RsaPublicKey,
17 };
18 use serde::{de::DeserializeOwned, Serialize};
19 use serde_json::Value;
20
21 use crate::connection::wrapper::ConnectionState;
22
/// The raw bytes of a message.
///
/// Extractors implementing [`FromMessage`] decode these bytes on demand. The
/// type dereferences to `[u8]`, so it can be handed to anything expecting a
/// byte slice.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        let NetworkMessage(bytes) = self;
        bytes.as_slice()
    }
}
32
33 pub struct AuthenticatedUser(pub User);
34
35 #[derive(Debug, thiserror::Error)]
36 pub enum UserAuthenticationError {
37 #[error("user authentication missing")]
38 Missing,
39 // #[error("{0}")]
40 // InstanceAuthentication(#[from] Error),
41 #[error("user token was invalid")]
42 InvalidToken,
43 #[error("an error has occured")]
44 Other(#[from] Error),
45 }
46
47 pub struct AuthenticatedInstance(Instance);
48
49 impl AuthenticatedInstance {
50 pub fn inner(&self) -> &Instance {
51 &self.0
52 }
53 }
54
55 #[async_trait::async_trait]
56 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
57 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
58 }
59
60 #[async_trait::async_trait]
61 impl FromMessage<ConnectionState> for AuthenticatedUser {
62 async fn from_message(
63 network_message: &NetworkMessage,
64 state: &ConnectionState,
65 ) -> Result<Self, Error> {
66 let message: AuthenticatedPayload =
67 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
68
69 let (auth_user, auth_token) = message
70 .source
71 .iter()
72 .filter_map(|auth| {
73 if let AuthenticationSource::User { user, token } = auth {
74 Some((user, token))
75 } else {
76 None
77 }
78 })
79 .next()
80 .ok_or_else(|| UserAuthenticationError::Missing)?;
81
82 let authenticated_instance =
83 AuthenticatedInstance::from_message(network_message, state).await?;
84
85 let public_key_raw = public_key(&auth_user.instance).await?;
86 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
87
88 let data: TokenData<UserTokenMetadata> = decode(
89 auth_token.as_ref(),
90 &verification_key,
91 &Validation::new(Algorithm::RS256),
92 )
93 .unwrap();
94
95 if data.claims.user != *auth_user
96 || data.claims.generated_for != *authenticated_instance.inner()
97 {
98 Err(Error::from(UserAuthenticationError::InvalidToken))
99 } else {
100 Ok(AuthenticatedUser(data.claims.user))
101 }
102 }
103 }
104
105 #[async_trait::async_trait]
106 impl FromMessage<ConnectionState> for AuthenticatedInstance {
107 async fn from_message(
108 network_message: &NetworkMessage,
109 state: &ConnectionState,
110 ) -> Result<Self, Error> {
111 let message: AuthenticatedPayload =
112 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
113
114 let (instance, signature) = message
115 .source
116 .iter()
117 .filter_map(|auth: &AuthenticationSource| {
118 if let AuthenticationSource::Instance {
119 instance,
120 signature,
121 } = auth
122 {
123 Some((instance, signature))
124 } else {
125 None
126 }
127 })
128 .next()
129 // TODO: Instance authentication error
130 .ok_or_else(|| UserAuthenticationError::Missing)?;
131
132 let public_key = {
133 let cached_keys = state.cached_keys.read().await;
134
135 if let Some(key) = cached_keys.get(&instance) {
136 key.clone()
137 } else {
138 drop(cached_keys);
139 let mut cached_keys = state.cached_keys.write().await;
140 let key = public_key(&instance).await?;
141 let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
142 cached_keys.insert(instance.clone(), public_key.clone());
143 public_key
144 }
145 };
146
147 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
148
149 verifying_key.verify(
150 &message.payload,
151 &Signature::try_from(signature.as_ref()).unwrap(),
152 )?;
153
154 Ok(AuthenticatedInstance(instance.clone()))
155 }
156 }
157
158 #[async_trait::async_trait]
159 impl<S, T> FromMessage<S> for Option<T>
160 where
161 T: FromMessage<S>,
162 S: Send + Sync + 'static,
163 {
164 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
165 Ok(T::from_message(message, state).await.ok())
166 }
167 }
168
169 #[async_trait::async_trait]
170 pub trait MessageHandler<T, S, R> {
171 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
172 }
173 #[async_trait::async_trait]
174 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
175 where
176 T: FnOnce(T1) -> F + Clone + Send + 'static,
177 F: Future<Output = Result<R, E>> + Send,
178 T1: FromMessage<S> + Send,
179 S: Send + Sync,
180 E: std::error::Error + Send + Sync + 'static,
181 {
182 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
183 let value = T1::from_message(message, state).await?;
184 self(value).await.map_err(|e| Error::from(e))
185 }
186 }
187
188 #[async_trait::async_trait]
189 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
190 where
191 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
192 F: Future<Output = Result<R, E>> + Send,
193 T1: FromMessage<S> + Send,
194 T2: FromMessage<S> + Send,
195 S: Send + Sync,
196 E: std::error::Error + Send + Sync + 'static,
197 {
198 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
199 let value = T1::from_message(message, state).await?;
200 let value_2 = T2::from_message(message, state).await?;
201 self(value, value_2).await.map_err(|e| Error::from(e))
202 }
203 }
204
205 #[async_trait::async_trait]
206 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
207 where
208 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
209 F: Future<Output = Result<R, E>> + Send,
210 T1: FromMessage<S> + Send,
211 T2: FromMessage<S> + Send,
212 T3: FromMessage<S> + Send,
213 S: Send + Sync,
214 E: std::error::Error + Send + Sync + 'static,
215 {
216 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
217 let value = T1::from_message(message, state).await?;
218 let value_2 = T2::from_message(message, state).await?;
219 let value_3 = T3::from_message(message, state).await?;
220
221 self(value, value_2, value_3)
222 .await
223 .map_err(|e| Error::from(e))
224 }
225 }
226
/// Extractor wrapper that gives a handler access to the shared state value.
pub struct State<T>(pub T);
228
229 #[async_trait::async_trait]
230 impl<T> FromMessage<T> for State<T>
231 where
232 T: Clone + Send + Sync,
233 {
234 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
235 Ok(Self(state.clone()))
236 }
237 }
238
239 // Temp
240 #[async_trait::async_trait]
241 impl<T, S> FromMessage<S> for Message<T>
242 where
243 T: DeserializeOwned + Send + Sync + Serialize,
244 S: Clone + Send + Sync,
245 {
246 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
247 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
248 Ok(Message(bincode::deserialize(&payload.payload)?))
249 }
250 }
251
/// Extractor holding a value decoded from the bincode-encoded `payload` of an
/// authenticated message.
///
/// The struct itself carries no trait bounds. The serde requirements belong on
/// the [`FromMessage`] impl that needs them, which keeps every other use of the
/// type free of unnecessary bounds.
pub struct Message<T>(pub T);
253
254 async fn public_key(instance: &Instance) -> Result<String, Error> {
255 let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
256 .await?
257 .text()
258 .await?;
259
260 Ok(key)
261 }
262
/// Handshake-specific message type.
///
/// Uses basic serde_json-based deserialization to maintain the highest
/// level of compatibility across versions.
///
/// The struct itself carries no trait bounds. The serde requirements belong on
/// the [`FromMessage`] impl that needs them.
pub struct HandshakeMessage<T>(pub T);
268
269 #[async_trait::async_trait]
270 impl<T, S> FromMessage<S> for HandshakeMessage<T>
271 where
272 T: DeserializeOwned + Send + Sync + Serialize,
273 S: Clone + Send + Sync,
274 {
275 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
276 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
277 }
278 }
279