JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add repository settings

Amber - 2 years ago

parent: tbd commit: f8eaf38

giterated-daemon/src/message.rs - 7914 bytes
Raw
1 use std::{fmt::Debug, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use giterated_models::model::{
6 authenticated::{AuthenticatedPayload, AuthenticationSource, UserTokenMetadata},
7 instance::Instance,
8 user::User,
9 };
10 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
11 use rsa::{
12 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
13 pss::{Signature, VerifyingKey},
14 sha2::Sha256,
15 signature::Verifier,
16 RsaPublicKey,
17 };
18 use serde::{de::DeserializeOwned, Serialize};
19
20 use crate::connection::wrapper::ConnectionState;
21
/// Raw, undecoded message bytes as handed to the [`FromMessage`] extractors.
///
/// Dereferences to `[u8]` so it can be passed directly wherever a byte
/// slice is expected (e.g. `serde_json::from_slice`).
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    /// Borrows the whole underlying buffer as a byte slice.
    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
31
32 pub struct AuthenticatedUser(pub User);
33
34 #[derive(Debug, thiserror::Error)]
35 pub enum UserAuthenticationError {
36 #[error("user authentication missing")]
37 Missing,
38 // #[error("{0}")]
39 // InstanceAuthentication(#[from] Error),
40 #[error("user token was invalid")]
41 InvalidToken,
42 #[error("an error has occured")]
43 Other(#[from] Error),
44 }
45
46 pub struct AuthenticatedInstance(Instance);
47
48 impl AuthenticatedInstance {
49 pub fn inner(&self) -> &Instance {
50 &self.0
51 }
52 }
53
54 #[async_trait::async_trait]
55 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
56 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
57 }
58
59 #[async_trait::async_trait]
60 impl FromMessage<ConnectionState> for AuthenticatedUser {
61 async fn from_message(
62 network_message: &NetworkMessage,
63 state: &ConnectionState,
64 ) -> Result<Self, Error> {
65 let message: AuthenticatedPayload =
66 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
67
68 let (auth_user, auth_token) = message
69 .source
70 .iter()
71 .filter_map(|auth| {
72 if let AuthenticationSource::User { user, token } = auth {
73 Some((user, token))
74 } else {
75 None
76 }
77 })
78 .next()
79 .ok_or_else(|| UserAuthenticationError::Missing)?;
80
81 let authenticated_instance =
82 AuthenticatedInstance::from_message(network_message, state).await?;
83
84 let public_key_raw = state.public_key(&auth_user.instance).await?;
85 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
86
87 let data: TokenData<UserTokenMetadata> = decode(
88 auth_token.as_ref(),
89 &verification_key,
90 &Validation::new(Algorithm::RS256),
91 )
92 .unwrap();
93
94 if data.claims.user != *auth_user
95 || data.claims.generated_for != *authenticated_instance.inner()
96 {
97 Err(Error::from(UserAuthenticationError::InvalidToken))
98 } else {
99 Ok(AuthenticatedUser(data.claims.user))
100 }
101 }
102 }
103
104 #[async_trait::async_trait]
105 impl FromMessage<ConnectionState> for AuthenticatedInstance {
106 async fn from_message(
107 network_message: &NetworkMessage,
108 state: &ConnectionState,
109 ) -> Result<Self, Error> {
110 let message: AuthenticatedPayload =
111 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
112
113 let (instance, signature) = message
114 .source
115 .iter()
116 .filter_map(|auth: &AuthenticationSource| {
117 if let AuthenticationSource::Instance {
118 instance,
119 signature,
120 } = auth
121 {
122 Some((instance, signature))
123 } else {
124 None
125 }
126 })
127 .next()
128 // TODO: Instance authentication error
129 .ok_or_else(|| UserAuthenticationError::Missing)?;
130
131 let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap();
132
133 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
134
135 verifying_key.verify(
136 &message.payload,
137 &Signature::try_from(signature.as_ref()).unwrap(),
138 )?;
139
140 Ok(AuthenticatedInstance(instance.clone()))
141 }
142 }
143
144 #[async_trait::async_trait]
145 impl<S, T> FromMessage<S> for Option<T>
146 where
147 T: FromMessage<S>,
148 S: Send + Sync + 'static,
149 {
150 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
151 Ok(T::from_message(message, state).await.ok())
152 }
153 }
154
155 #[async_trait::async_trait]
156 pub trait MessageHandler<T, S, R> {
157 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
158 }
159 #[async_trait::async_trait]
160 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
161 where
162 T: FnOnce(T1) -> F + Clone + Send + 'static,
163 F: Future<Output = Result<R, E>> + Send,
164 T1: FromMessage<S> + Send,
165 S: Send + Sync,
166 E: std::error::Error + Send + Sync + 'static,
167 {
168 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
169 let value = T1::from_message(message, state).await?;
170 self(value).await.map_err(|e| Error::from(e))
171 }
172 }
173
174 #[async_trait::async_trait]
175 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
176 where
177 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
178 F: Future<Output = Result<R, E>> + Send,
179 T1: FromMessage<S> + Send,
180 T2: FromMessage<S> + Send,
181 S: Send + Sync,
182 E: std::error::Error + Send + Sync + 'static,
183 {
184 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
185 let value = T1::from_message(message, state).await?;
186 let value_2 = T2::from_message(message, state).await?;
187 self(value, value_2).await.map_err(|e| Error::from(e))
188 }
189 }
190
191 #[async_trait::async_trait]
192 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
193 where
194 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
195 F: Future<Output = Result<R, E>> + Send,
196 T1: FromMessage<S> + Send,
197 T2: FromMessage<S> + Send,
198 T3: FromMessage<S> + Send,
199 S: Send + Sync,
200 E: std::error::Error + Send + Sync + 'static,
201 {
202 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
203 let value = T1::from_message(message, state).await?;
204 let value_2 = T2::from_message(message, state).await?;
205 let value_3 = T3::from_message(message, state).await?;
206
207 self(value, value_2, value_3)
208 .await
209 .map_err(|e| Error::from(e))
210 }
211 }
212
213 pub struct State<T>(pub T);
214
215 #[async_trait::async_trait]
216 impl<T> FromMessage<T> for State<T>
217 where
218 T: Clone + Send + Sync,
219 {
220 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
221 Ok(Self(state.clone()))
222 }
223 }
224
225 // Temp
226 #[async_trait::async_trait]
227 impl<T, S> FromMessage<S> for Message<T>
228 where
229 T: DeserializeOwned + Send + Sync + Serialize + Debug,
230 S: Clone + Send + Sync,
231 {
232 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
233 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
234 let payload = bincode::deserialize(&payload.payload)?;
235
236 info!("Deserialized payload: {:#?}", payload);
237
238 Ok(Message(payload))
239 }
240 }
241
242 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
243
244 /// Handshake-specific message type.
245 ///
246 /// Uses basic serde_json-based deserialization to maintain the highest
247 /// level of compatibility across versions.
248 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
249
250 #[async_trait::async_trait]
251 impl<T, S> FromMessage<S> for HandshakeMessage<T>
252 where
253 T: DeserializeOwned + Send + Sync + Serialize,
254 S: Clone + Send + Sync,
255 {
256 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
257 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
258 }
259 }
260