JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Remove unneeded logs

Amber - 2 years ago

parent: tbd commit: cfba404

giterated-daemon/src/message.rs - 7864 bytes
Raw
1 use anyhow::Error;
2 use futures_util::Future;
3
4 use giterated_models::instance::Instance;
5
6 use giterated_models::user::User;
7
8 use giterated_models::authenticated::{
9 AuthenticatedPayload, AuthenticationSource, UserTokenMetadata,
10 };
11 use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
12 use rsa::{
13 pkcs1::DecodeRsaPublicKey,
14 pss::{Signature, VerifyingKey},
15 sha2::Sha256,
16 signature::Verifier,
17 RsaPublicKey,
18 };
19 use serde::{de::DeserializeOwned, Serialize};
20 use std::{fmt::Debug, ops::Deref};
21
22 use crate::connection::wrapper::ConnectionState;
23
/// The raw bytes of one message received from the network.
///
/// Dereferences to `[u8]`, so it can be handed directly to byte-slice
/// decoders such as `serde_json::from_slice`.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
33
34 pub struct AuthenticatedUser(pub User);
35
36 #[derive(Debug, thiserror::Error)]
37 pub enum UserAuthenticationError {
38 #[error("user authentication missing")]
39 Missing,
40 // #[error("{0}")]
41 // InstanceAuthentication(#[from] Error),
42 #[error("user token was invalid")]
43 InvalidToken,
44 #[error("an error has occured")]
45 Other(#[from] Error),
46 }
47
48 pub struct AuthenticatedInstance(Instance);
49
50 impl AuthenticatedInstance {
51 pub fn inner(&self) -> &Instance {
52 &self.0
53 }
54 }
55
56 #[async_trait::async_trait]
57 pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
58 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
59 }
60
61 #[async_trait::async_trait]
62 impl FromMessage<ConnectionState> for AuthenticatedUser {
63 async fn from_message(
64 network_message: &NetworkMessage,
65 state: &ConnectionState,
66 ) -> Result<Self, Error> {
67 let message: AuthenticatedPayload =
68 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
69
70 let (auth_user, auth_token) = message
71 .source
72 .iter()
73 .filter_map(|auth| {
74 if let AuthenticationSource::User { user, token } = auth {
75 Some((user, token))
76 } else {
77 None
78 }
79 })
80 .next()
81 .ok_or_else(|| UserAuthenticationError::Missing)?;
82
83 let authenticated_instance =
84 AuthenticatedInstance::from_message(network_message, state).await?;
85
86 let public_key_raw = state.public_key(&auth_user.instance).await?;
87 let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
88
89 let data: TokenData<UserTokenMetadata> = decode(
90 auth_token.as_ref(),
91 &verification_key,
92 &Validation::new(Algorithm::RS256),
93 )
94 .unwrap();
95
96 if data.claims.user != *auth_user
97 || data.claims.generated_for != *authenticated_instance.inner()
98 {
99 Err(Error::from(UserAuthenticationError::InvalidToken))
100 } else {
101 Ok(AuthenticatedUser(data.claims.user))
102 }
103 }
104 }
105
106 #[async_trait::async_trait]
107 impl FromMessage<ConnectionState> for AuthenticatedInstance {
108 async fn from_message(
109 network_message: &NetworkMessage,
110 state: &ConnectionState,
111 ) -> Result<Self, Error> {
112 let message: AuthenticatedPayload =
113 serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
114
115 let (instance, signature) = message
116 .source
117 .iter()
118 .filter_map(|auth: &AuthenticationSource| {
119 if let AuthenticationSource::Instance {
120 instance,
121 signature,
122 } = auth
123 {
124 Some((instance, signature))
125 } else {
126 None
127 }
128 })
129 .next()
130 // TODO: Instance authentication error
131 .ok_or_else(|| UserAuthenticationError::Missing)?;
132
133 let public_key = RsaPublicKey::from_pkcs1_pem(&state.public_key(instance).await?).unwrap();
134
135 let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
136
137 verifying_key.verify(
138 &message.payload,
139 &Signature::try_from(signature.as_ref()).unwrap(),
140 )?;
141
142 Ok(AuthenticatedInstance(instance.clone()))
143 }
144 }
145
146 #[async_trait::async_trait]
147 impl<S, T> FromMessage<S> for Option<T>
148 where
149 T: FromMessage<S>,
150 S: Send + Sync + 'static,
151 {
152 async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
153 Ok(T::from_message(message, state).await.ok())
154 }
155 }
156
157 #[async_trait::async_trait]
158 pub trait MessageHandler<T, S, R> {
159 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
160 }
161 #[async_trait::async_trait]
162 impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
163 where
164 T: FnOnce(T1) -> F + Clone + Send + 'static,
165 F: Future<Output = Result<R, E>> + Send,
166 T1: FromMessage<S> + Send,
167 S: Send + Sync,
168 E: std::error::Error + Send + Sync + 'static,
169 {
170 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
171 let value = T1::from_message(message, state).await?;
172 self(value).await.map_err(|e| Error::from(e))
173 }
174 }
175
176 #[async_trait::async_trait]
177 impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
178 where
179 T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
180 F: Future<Output = Result<R, E>> + Send,
181 T1: FromMessage<S> + Send,
182 T2: FromMessage<S> + Send,
183 S: Send + Sync,
184 E: std::error::Error + Send + Sync + 'static,
185 {
186 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
187 let value = T1::from_message(message, state).await?;
188 let value_2 = T2::from_message(message, state).await?;
189 self(value, value_2).await.map_err(|e| Error::from(e))
190 }
191 }
192
193 #[async_trait::async_trait]
194 impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
195 where
196 T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
197 F: Future<Output = Result<R, E>> + Send,
198 T1: FromMessage<S> + Send,
199 T2: FromMessage<S> + Send,
200 T3: FromMessage<S> + Send,
201 S: Send + Sync,
202 E: std::error::Error + Send + Sync + 'static,
203 {
204 async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
205 let value = T1::from_message(message, state).await?;
206 let value_2 = T2::from_message(message, state).await?;
207 let value_3 = T3::from_message(message, state).await?;
208
209 self(value, value_2, value_3)
210 .await
211 .map_err(|e| Error::from(e))
212 }
213 }
214
215 pub struct State<T>(pub T);
216
217 #[async_trait::async_trait]
218 impl<T> FromMessage<T> for State<T>
219 where
220 T: Clone + Send + Sync,
221 {
222 async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
223 Ok(Self(state.clone()))
224 }
225 }
226
227 // Temp
228 #[async_trait::async_trait]
229 impl<T, S> FromMessage<S> for Message<T>
230 where
231 T: DeserializeOwned + Send + Sync + Serialize + Debug,
232 S: Clone + Send + Sync,
233 {
234 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
235 let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
236 let payload = bincode::deserialize(&payload.payload)?;
237
238 Ok(Message(payload))
239 }
240 }
241
242 pub struct Message<T: Serialize + DeserializeOwned>(pub T);
243
244 /// Handshake-specific message type.
245 ///
246 /// Uses basic serde_json-based deserialization to maintain the highest
247 /// level of compatibility across versions.
248 pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
249
250 #[async_trait::async_trait]
251 impl<T, S> FromMessage<S> for HandshakeMessage<T>
252 where
253 T: DeserializeOwned + Send + Sync + Serialize,
254 S: Clone + Send + Sync,
255 {
256 async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
257 Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
258 }
259 }
260