JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Automatically populate the target instance field from the request using MessageTarget trait

Amber - 2 years ago

parent: tbd commit: 37da513

giterated-daemon/src/message.rs - 8087 bytes
Raw
1 use std::{fmt::Debug, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use giterated_models::model::{
6 authenticated::{AuthenticatedPayload, AuthenticationSource, UserTokenMetadata},
7 instance::Instance,
8 user::User,
9 };
10 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
11 use rsa::{
12 pkcs1::DecodeRsaPublicKey,
13 pss::{Signature, VerifyingKey},
14 sha2::Sha256,
15 signature::Verifier,
16 RsaPublicKey,
17 };
18 use serde::{de::DeserializeOwned, Serialize};
19
20 use crate::connection::wrapper::ConnectionState;
21
/// Raw, undecoded message bytes.
///
/// Dereferences to `[u8]` so the message can be handed directly to
/// byte-oriented decoders such as `serde_json::from_slice`.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
31
32 pub struct AuthenticatedUser(pub User);
33
34 #[derive(Debug, thiserror::Error)]
35 pub enum UserAuthenticationError {
36 #[error("user authentication missing")]
37 Missing,
38 // #[error("{0}")]
39 // InstanceAuthentication(#[from] Error),
40 #[error("user token was invalid")]
41 InvalidToken,
42 #[error("an error has occured")]
43 Other(#[from] Error),
44 }
45
46 pub struct AuthenticatedInstance(Instance);
47
48 impl AuthenticatedInstance {
49 pub fn inner(&self) -> &Instance {
50 &self.0
51 }
52 }
53
54 #[async_trait::async_trait]
55 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
56 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
57 }
58
59 #[async_trait::async_trait]
60 impl FromMessage<ConnectionState> for AuthenticatedUser {
61 async fn from_message(
62 network_message: &NetworkMessage,
63 state: &ConnectionState,
64 ) -> Result<Self, Error> {
65 let message: AuthenticatedPayload =
66 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
67
68 let (auth_user, auth_token) = message
69 .source
70 .iter()
71 .filter_map(|auth| {
72 if let AuthenticationSource::User { user, token } = auth {
73 Some((user, token))
74 } else {
75 None
76 }
77 })
78 .next()
79 .ok_or_else(|| UserAuthenticationError::Missing)?;
80
81 let authenticated_instance =
82 AuthenticatedInstance::from_message(network_message, state).await?;
83
84 let public_key_raw = state.public_key(&auth_user.instance).await?;
85 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
86
87 let data: TokenData<UserTokenMetadata> = decode(
88 auth_token.as_ref(),
89 &verification_key,
90 &Validation::new(Algorithm::RS256),
91 )
92 .unwrap();
93
94 if data.claims.user != *auth_user
95 || data.claims.generated_for != *authenticated_instance.inner()
96 {
97 Err(Error::from(UserAuthenticationError::InvalidToken))
98 } else {
99 Ok(AuthenticatedUser(data.claims.user))
100 }
101 }
102 }
103
104 #[async_trait::async_trait]
105 impl FromMessage<ConnectionState> for AuthenticatedInstance {
106 async fn from_message(
107 network_message: &NetworkMessage,
108 state: &ConnectionState,
109 ) -> Result<Self, Error> {
110 let message: AuthenticatedPayload =
111 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
112
113 info!("Authenticated payload: {:?}", message);
114
115 let (instance, signature) = message
116 .source
117 .iter()
118 .filter_map(|auth: &AuthenticationSource| {
119 if let AuthenticationSource::Instance {
120 instance,
121 signature,
122 } = auth
123 {
124 Some((instance, signature))
125 } else {
126 None
127 }
128 })
129 .next()
130 // TODO: Instance authentication error
131 .ok_or_else(|| UserAuthenticationError::Missing)?;
132
133 info!("Instance: {}", instance.clone().to_string());
134
135 info!("Instance public key: {}", state.public_key(instance).await?);
136
137 let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap();
138
139 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
140
141 verifying_key.verify(
142 &message.payload,
143 &Signature::try_from(signature.as_ref()).unwrap(),
144 )?;
145
146 Ok(AuthenticatedInstance(instance.clone()))
147 }
148 }
149
150 #[async_trait::async_trait]
151 impl<S, T> FromMessage<S> for Option<T>
152 where
153 T: FromMessage<S>,
154 S: Send + Sync + 'static,
155 {
156 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
157 Ok(T::from_message(message, state).await.ok())
158 }
159 }
160
161 #[async_trait::async_trait]
162 pub trait MessageHandler<T, S, R> {
163 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
164 }
165 #[async_trait::async_trait]
166 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
167 where
168 T: FnOnce(T1) -> F + Clone + Send + 'static,
169 F: Future<Output = Result<R, E>> + Send,
170 T1: FromMessage<S> + Send,
171 S: Send + Sync,
172 E: std::error::Error + Send + Sync + 'static,
173 {
174 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
175 let value = T1::from_message(message, state).await?;
176 self(value).await.map_err(|e| Error::from(e))
177 }
178 }
179
180 #[async_trait::async_trait]
181 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
182 where
183 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
184 F: Future<Output = Result<R, E>> + Send,
185 T1: FromMessage<S> + Send,
186 T2: FromMessage<S> + Send,
187 S: Send + Sync,
188 E: std::error::Error + Send + Sync + 'static,
189 {
190 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
191 let value = T1::from_message(message, state).await?;
192 let value_2 = T2::from_message(message, state).await?;
193 self(value, value_2).await.map_err(|e| Error::from(e))
194 }
195 }
196
197 #[async_trait::async_trait]
198 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
199 where
200 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
201 F: Future<Output = Result<R, E>> + Send,
202 T1: FromMessage<S> + Send,
203 T2: FromMessage<S> + Send,
204 T3: FromMessage<S> + Send,
205 S: Send + Sync,
206 E: std::error::Error + Send + Sync + 'static,
207 {
208 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
209 let value = T1::from_message(message, state).await?;
210 let value_2 = T2::from_message(message, state).await?;
211 let value_3 = T3::from_message(message, state).await?;
212
213 self(value, value_2, value_3)
214 .await
215 .map_err(|e| Error::from(e))
216 }
217 }
218
219 pub struct State<T>(pub T);
220
221 #[async_trait::async_trait]
222 impl<T> FromMessage<T> for State<T>
223 where
224 T: Clone + Send + Sync,
225 {
226 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
227 Ok(Self(state.clone()))
228 }
229 }
230
231 // Temp
232 #[async_trait::async_trait]
233 impl<T, S> FromMessage<S> for Message<T>
234 where
235 T: DeserializeOwned + Send + Sync + Serialize + Debug,
236 S: Clone + Send + Sync,
237 {
238 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
239 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
240 let payload = bincode::deserialize(&payload.payload)?;
241
242 info!("Deserialized payload: {:#?}", payload);
243
244 Ok(Message(payload))
245 }
246 }
247
248 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
249
250 /// Handshake-specific message type.
251 ///
252 /// Uses basic serde_json-based deserialization to maintain the highest
253 /// level of compatibility across versions.
254 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
255
256 #[async_trait::async_trait]
257 impl<T, S> FromMessage<S> for HandshakeMessage<T>
258 where
259 T: DeserializeOwned + Send + Sync + Serialize,
260 S: Clone + Send + Sync,
261 {
262 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
263 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
264 }
265 }
266