JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fixed imports!

Amber - 2 years ago

parent: tbd commit: ef0e853

giterated-daemon/src/message.rs - 8116 bytes
Raw
1 use anyhow::Error;
2 use futures_util::Future;
3
4 use giterated_models::instance::Instance;
5
6 use giterated_models::user::User;
7
8 use giterated_models::authenticated::{
9 AuthenticatedPayload, AuthenticationSource, UserTokenMetadata,
10 };
11 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
12 use rsa::{
13 pkcs1::DecodeRsaPublicKey,
14 pss::{Signature, VerifyingKey},
15 sha2::Sha256,
16 signature::Verifier,
17 RsaPublicKey,
18 };
19 use serde::{de::DeserializeOwned, Serialize};
20 use std::{fmt::Debug, ops::Deref};
21
22 use crate::connection::wrapper::ConnectionState;
23
/// Undecoded message bytes.
///
/// Dereferences to `[u8]` so the raw body can be passed straight to
/// deserializers (the `FromMessage` extractors parse it with `serde_json`).
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
33
34 pub struct AuthenticatedUser(pub User);
35
36 #[derive(Debug, thiserror::Error)]
37 pub enum UserAuthenticationError {
38 #[error("user authentication missing")]
39 Missing,
40 // #[error("{0}")]
41 // InstanceAuthentication(#[from] Error),
42 #[error("user token was invalid")]
43 InvalidToken,
44 #[error("an error has occured")]
45 Other(#[from] Error),
46 }
47
48 pub struct AuthenticatedInstance(Instance);
49
50 impl AuthenticatedInstance {
51 pub fn inner(&self) -> &Instance {
52 &self.0
53 }
54 }
55
56 #[async_trait::async_trait]
57 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
58 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
59 }
60
61 #[async_trait::async_trait]
62 impl FromMessage<ConnectionState> for AuthenticatedUser {
63 async fn from_message(
64 network_message: &NetworkMessage,
65 state: &ConnectionState,
66 ) -> Result<Self, Error> {
67 let message: AuthenticatedPayload =
68 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
69
70 let (auth_user, auth_token) = message
71 .source
72 .iter()
73 .filter_map(|auth| {
74 if let AuthenticationSource::User { user, token } = auth {
75 Some((user, token))
76 } else {
77 None
78 }
79 })
80 .next()
81 .ok_or_else(|| UserAuthenticationError::Missing)?;
82
83 let authenticated_instance =
84 AuthenticatedInstance::from_message(network_message, state).await?;
85
86 let public_key_raw = state.public_key(&auth_user.instance).await?;
87 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
88
89 let data: TokenData<UserTokenMetadata> = decode(
90 auth_token.as_ref(),
91 &verification_key,
92 &Validation::new(Algorithm::RS256),
93 )
94 .unwrap();
95
96 if data.claims.user != *auth_user
97 || data.claims.generated_for != *authenticated_instance.inner()
98 {
99 Err(Error::from(UserAuthenticationError::InvalidToken))
100 } else {
101 Ok(AuthenticatedUser(data.claims.user))
102 }
103 }
104 }
105
106 #[async_trait::async_trait]
107 impl FromMessage<ConnectionState> for AuthenticatedInstance {
108 async fn from_message(
109 network_message: &NetworkMessage,
110 state: &ConnectionState,
111 ) -> Result<Self, Error> {
112 let message: AuthenticatedPayload =
113 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
114
115 info!("Authenticated payload: {:?}", message);
116
117 let (instance, signature) = message
118 .source
119 .iter()
120 .filter_map(|auth: &AuthenticationSource| {
121 if let AuthenticationSource::Instance {
122 instance,
123 signature,
124 } = auth
125 {
126 Some((instance, signature))
127 } else {
128 None
129 }
130 })
131 .next()
132 // TODO: Instance authentication error
133 .ok_or_else(|| UserAuthenticationError::Missing)?;
134
135 info!("Instance: {}", instance.clone().to_string());
136
137 info!("Instance public key: {}", state.public_key(instance).await?);
138
139 let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap();
140
141 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
142
143 verifying_key.verify(
144 &message.payload,
145 &Signature::try_from(signature.as_ref()).unwrap(),
146 )?;
147
148 Ok(AuthenticatedInstance(instance.clone()))
149 }
150 }
151
152 #[async_trait::async_trait]
153 impl<S, T> FromMessage<S> for Option<T>
154 where
155 T: FromMessage<S>,
156 S: Send + Sync + 'static,
157 {
158 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
159 Ok(T::from_message(message, state).await.ok())
160 }
161 }
162
163 #[async_trait::async_trait]
164 pub trait MessageHandler<T, S, R> {
165 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
166 }
167 #[async_trait::async_trait]
168 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
169 where
170 T: FnOnce(T1) -> F + Clone + Send + 'static,
171 F: Future<Output = Result<R, E>> + Send,
172 T1: FromMessage<S> + Send,
173 S: Send + Sync,
174 E: std::error::Error + Send + Sync + 'static,
175 {
176 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
177 let value = T1::from_message(message, state).await?;
178 self(value).await.map_err(|e| Error::from(e))
179 }
180 }
181
182 #[async_trait::async_trait]
183 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
184 where
185 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
186 F: Future<Output = Result<R, E>> + Send,
187 T1: FromMessage<S> + Send,
188 T2: FromMessage<S> + Send,
189 S: Send + Sync,
190 E: std::error::Error + Send + Sync + 'static,
191 {
192 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
193 let value = T1::from_message(message, state).await?;
194 let value_2 = T2::from_message(message, state).await?;
195 self(value, value_2).await.map_err(|e| Error::from(e))
196 }
197 }
198
199 #[async_trait::async_trait]
200 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
201 where
202 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
203 F: Future<Output = Result<R, E>> + Send,
204 T1: FromMessage<S> + Send,
205 T2: FromMessage<S> + Send,
206 T3: FromMessage<S> + Send,
207 S: Send + Sync,
208 E: std::error::Error + Send + Sync + 'static,
209 {
210 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
211 let value = T1::from_message(message, state).await?;
212 let value_2 = T2::from_message(message, state).await?;
213 let value_3 = T3::from_message(message, state).await?;
214
215 self(value, value_2, value_3)
216 .await
217 .map_err(|e| Error::from(e))
218 }
219 }
220
221 pub struct State<T>(pub T);
222
223 #[async_trait::async_trait]
224 impl<T> FromMessage<T> for State<T>
225 where
226 T: Clone + Send + Sync,
227 {
228 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
229 Ok(Self(state.clone()))
230 }
231 }
232
233 // Temp
234 #[async_trait::async_trait]
235 impl<T, S> FromMessage<S> for Message<T>
236 where
237 T: DeserializeOwned + Send + Sync + Serialize + Debug,
238 S: Clone + Send + Sync,
239 {
240 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
241 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
242 let payload = bincode::deserialize(&payload.payload)?;
243
244 info!("Deserialized payload: {:#?}", payload);
245
246 Ok(Message(payload))
247 }
248 }
249
250 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
251
252 /// Handshake-specific message type.
253 ///
254 /// Uses basic serde_json-based deserialization to maintain the highest
255 /// level of compatibility across versions.
256 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
257
258 #[async_trait::async_trait]
259 impl<T, S> FromMessage<S> for HandshakeMessage<T>
260 where
261 T: DeserializeOwned + Send + Sync + Serialize,
262 S: Clone + Send + Sync,
263 {
264 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
265 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
266 }
267 }
268