JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add more aggressive key caching

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨5bc92ad

⁨giterated-daemon/src/message.rs⁩ - ⁨7974⁩ bytes
Raw
1 use std::{collections::HashMap, fmt::Debug, ops::Deref};
2
3 use anyhow::Error;
4 use futures_util::Future;
5 use giterated_models::model::{
6 authenticated::{Authenticated, AuthenticatedPayload, AuthenticationSource, UserTokenMetadata},
7 instance::Instance,
8 user::User,
9 };
10 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
11 use rsa::{
12 pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
13 pss::{Signature, VerifyingKey},
14 sha2::Sha256,
15 signature::Verifier,
16 RsaPublicKey,
17 };
18 use serde::{de::DeserializeOwned, Serialize};
19 use serde_json::Value;
20
21 use crate::connection::wrapper::ConnectionState;
22
/// Newtype around the raw bytes of a single message.
///
/// Dereferences to `[u8]`, so a `&NetworkMessage` can be passed anywhere a
/// byte slice is expected (for example to slice-based deserializers).
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    /// Borrows the wrapped buffer as a byte slice.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
32
33 pub struct AuthenticatedUser(pub User);
34
35 #[derive(Debug, thiserror::Error)]
36 pub enum UserAuthenticationError {
37 #[error("user authentication missing")]
38 Missing,
39 // #[error("{0}")]
40 // InstanceAuthentication(#[from] Error),
41 #[error("user token was invalid")]
42 InvalidToken,
43 #[error("an error has occured")]
44 Other(#[from] Error),
45 }
46
47 pub struct AuthenticatedInstance(Instance);
48
49 impl AuthenticatedInstance {
50 pub fn inner(&self) -> &Instance {
51 &self.0
52 }
53 }
54
55 #[async_trait::async_trait]
56 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
57 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
58 }
59
60 #[async_trait::async_trait]
61 impl FromMessage<ConnectionState> for AuthenticatedUser {
62 async fn from_message(
63 network_message: &NetworkMessage,
64 state: &ConnectionState,
65 ) -> Result<Self, Error> {
66 let message: AuthenticatedPayload =
67 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
68
69 let (auth_user, auth_token) = message
70 .source
71 .iter()
72 .filter_map(|auth| {
73 if let AuthenticationSource::User { user, token } = auth {
74 Some((user, token))
75 } else {
76 None
77 }
78 })
79 .next()
80 .ok_or_else(|| UserAuthenticationError::Missing)?;
81
82 let authenticated_instance =
83 AuthenticatedInstance::from_message(network_message, state).await?;
84
85 let public_key_raw = state.public_key(&auth_user.instance).await?;
86 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
87
88 let data: TokenData<UserTokenMetadata> = decode(
89 auth_token.as_ref(),
90 &verification_key,
91 &Validation::new(Algorithm::RS256),
92 )
93 .unwrap();
94
95 if data.claims.user != *auth_user
96 || data.claims.generated_for != *authenticated_instance.inner()
97 {
98 Err(Error::from(UserAuthenticationError::InvalidToken))
99 } else {
100 Ok(AuthenticatedUser(data.claims.user))
101 }
102 }
103 }
104
105 #[async_trait::async_trait]
106 impl FromMessage<ConnectionState> for AuthenticatedInstance {
107 async fn from_message(
108 network_message: &NetworkMessage,
109 state: &ConnectionState,
110 ) -> Result<Self, Error> {
111 let message: AuthenticatedPayload =
112 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
113
114 let (instance, signature) = message
115 .source
116 .iter()
117 .filter_map(|auth: &AuthenticationSource| {
118 if let AuthenticationSource::Instance {
119 instance,
120 signature,
121 } = auth
122 {
123 Some((instance, signature))
124 } else {
125 None
126 }
127 })
128 .next()
129 // TODO: Instance authentication error
130 .ok_or_else(|| UserAuthenticationError::Missing)?;
131
132 let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap();
133
134 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
135
136 verifying_key.verify(
137 &message.payload,
138 &Signature::try_from(signature.as_ref()).unwrap(),
139 )?;
140
141 Ok(AuthenticatedInstance(instance.clone()))
142 }
143 }
144
145 #[async_trait::async_trait]
146 impl<S, T> FromMessage<S> for Option<T>
147 where
148 T: FromMessage<S>,
149 S: Send + Sync + 'static,
150 {
151 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
152 Ok(T::from_message(message, state).await.ok())
153 }
154 }
155
156 #[async_trait::async_trait]
157 pub trait MessageHandler<T, S, R> {
158 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
159 }
160 #[async_trait::async_trait]
161 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
162 where
163 T: FnOnce(T1) -> F + Clone + Send + 'static,
164 F: Future<Output = Result<R, E>> + Send,
165 T1: FromMessage<S> + Send,
166 S: Send + Sync,
167 E: std::error::Error + Send + Sync + 'static,
168 {
169 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
170 let value = T1::from_message(message, state).await?;
171 self(value).await.map_err(|e| Error::from(e))
172 }
173 }
174
175 #[async_trait::async_trait]
176 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
177 where
178 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
179 F: Future<Output = Result<R, E>> + Send,
180 T1: FromMessage<S> + Send,
181 T2: FromMessage<S> + Send,
182 S: Send + Sync,
183 E: std::error::Error + Send + Sync + 'static,
184 {
185 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
186 let value = T1::from_message(message, state).await?;
187 let value_2 = T2::from_message(message, state).await?;
188 self(value, value_2).await.map_err(|e| Error::from(e))
189 }
190 }
191
192 #[async_trait::async_trait]
193 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
194 where
195 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
196 F: Future<Output = Result<R, E>> + Send,
197 T1: FromMessage<S> + Send,
198 T2: FromMessage<S> + Send,
199 T3: FromMessage<S> + Send,
200 S: Send + Sync,
201 E: std::error::Error + Send + Sync + 'static,
202 {
203 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
204 let value = T1::from_message(message, state).await?;
205 let value_2 = T2::from_message(message, state).await?;
206 let value_3 = T3::from_message(message, state).await?;
207
208 self(value, value_2, value_3)
209 .await
210 .map_err(|e| Error::from(e))
211 }
212 }
213
214 pub struct State<T>(pub T);
215
216 #[async_trait::async_trait]
217 impl<T> FromMessage<T> for State<T>
218 where
219 T: Clone + Send + Sync,
220 {
221 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
222 Ok(Self(state.clone()))
223 }
224 }
225
226 // Temp
227 #[async_trait::async_trait]
228 impl<T, S> FromMessage<S> for Message<T>
229 where
230 T: DeserializeOwned + Send + Sync + Serialize + Debug,
231 S: Clone + Send + Sync,
232 {
233 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
234 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
235 let payload = bincode::deserialize(&payload.payload)?;
236
237 info!("Deserialized payload: {:#?}", payload);
238
239 Ok(Message(payload))
240 }
241 }
242
243 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
244
245 /// Handshake-specific message type.
246 ///
247 /// Uses basic serde_json-based deserialization to maintain the highest
248 /// level of compatibility across versions.
249 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
250
251 #[async_trait::async_trait]
252 impl<T, S> FromMessage<S> for HandshakeMessage<T>
253 where
254 T: DeserializeOwned + Send + Sync + Serialize,
255 S: Clone + Send + Sync,
256 {
257 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
258 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
259 }
260 }
261