JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add authentication back into the operation states

Amber - ⁨2⁩ years ago

parent: tbd commit: ⁨97a26fd

⁨giterated-daemon/src/connection/wrapper.rs⁩ - ⁨13507⁩ bytes
Raw
1 use std::{
2 net::SocketAddr,
3 ops::Deref,
4 sync::{atomic::AtomicBool, Arc},
5 };
6
7 use anyhow::Error;
8 use futures_util::{SinkExt, StreamExt};
9
10 use giterated_models::{
11 authenticated::{AuthenticationSource, UserTokenMetadata},
12 error::OperationError,
13 instance::Instance,
14 };
15
16 use giterated_models::object_backend::ObjectBackend;
17
18 use giterated_models::{
19 authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject,
20 operation::AnyOperation,
21 };
22 use giterated_stack::{
23 handler::GiteratedBackend, AuthenticatedInstance, AuthenticatedUser, StackOperationState,
24 };
25 use jsonwebtoken::{DecodingKey, TokenData, Validation};
26 use rsa::{
27 pkcs1::DecodeRsaPublicKey,
28 pss::{Signature, VerifyingKey},
29 sha2::Sha256,
30 signature::Verifier,
31 RsaPublicKey,
32 };
33 use serde::Serialize;
34
35 use tokio::{net::TcpStream, sync::Mutex};
36 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
37 use toml::Table;
38
39 use crate::{
40 authentication::AuthenticationTokenGranter,
41 backend::{MetadataBackend, RepositoryBackend, UserBackend},
42 database_backend::DatabaseBackend,
43 federation::connections::InstanceConnections,
44 keys::PublicKeyCache,
45 };
46
47 use super::Connections;
48
49 pub async fn connection_wrapper(
50 socket: WebSocketStream<TcpStream>,
51 connections: Arc<Mutex<Connections>>,
52 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
53 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
54 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
55 settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
56 addr: SocketAddr,
57 instance: impl ToOwned<Owned = Instance>,
58 instance_connections: Arc<Mutex<InstanceConnections>>,
59 config: Table,
60 backend: GiteratedBackend<DatabaseBackend>,
61 mut operation_state: StackOperationState,
62 ) {
63 let connection_state = ConnectionState {
64 socket: Arc::new(Mutex::new(socket)),
65 connections,
66 repository_backend,
67 user_backend,
68 auth_granter,
69 settings_backend,
70 addr,
71 instance: instance.to_owned(),
72 handshaked: Arc::new(AtomicBool::new(false)),
73 key_cache: Arc::default(),
74 instance_connections: instance_connections.clone(),
75 config,
76 };
77
78 let _handshaked = false;
79 let mut key_cache = PublicKeyCache::default();
80
81 loop {
82 let mut socket = connection_state.socket.lock().await;
83 let message = socket.next().await;
84 drop(socket);
85
86 match message {
87 Some(Ok(message)) => {
88 let payload = match message {
89 Message::Binary(payload) => payload,
90 Message::Ping(_) => {
91 let mut socket = connection_state.socket.lock().await;
92 let _ = socket.send(Message::Pong(vec![])).await;
93 drop(socket);
94 continue;
95 }
96 Message::Close(_) => return,
97 _ => continue,
98 };
99
100 let message: AuthenticatedPayload = bincode::deserialize(&payload).unwrap();
101
102 // Get authentication
103 let instance = {
104 let mut verified_instance: Option<AuthenticatedInstance> = None;
105 for source in &message.source {
106 if let AuthenticationSource::Instance {
107 instance,
108 signature,
109 } = source
110 {
111 let public_key = key_cache.get(&instance).await.unwrap();
112 let public_key = RsaPublicKey::from_pkcs1_pem(&public_key).unwrap();
113 let verifying_key = VerifyingKey::<Sha256>::new(public_key);
114
115 if verifying_key
116 .verify(
117 &message.payload,
118 &Signature::try_from(signature.as_ref()).unwrap(),
119 )
120 .is_ok()
121 {
122 verified_instance =
123 Some(AuthenticatedInstance::new(instance.clone()));
124
125 break;
126 }
127 }
128 }
129
130 verified_instance
131 };
132
133 let user = {
134 let mut verified_user = None;
135 if let Some(verified_instance) = &instance {
136 for source in &message.source {
137 if let AuthenticationSource::User { user, token } = source {
138 // Get token
139 let public_key = key_cache.get(&verified_instance).await.unwrap();
140
141 let token: TokenData<UserTokenMetadata> = jsonwebtoken::decode(
142 token.as_ref(),
143 &DecodingKey::from_rsa_pem(public_key.as_bytes()).unwrap(),
144 &Validation::new(jsonwebtoken::Algorithm::RS256),
145 )
146 .unwrap();
147
148 if token.claims.generated_for != *verified_instance.deref() {
149 // Nope!
150 break;
151 }
152
153 if token.claims.user != *user {
154 // Nope!
155 break;
156 }
157
158 verified_user = Some(AuthenticatedUser::new(user.clone()));
159 break;
160 }
161 }
162 }
163
164 verified_user
165 };
166
167 let message: GiteratedMessage<AnyObject, AnyOperation> = message.into_message();
168
169 operation_state.user = user;
170 operation_state.instance = instance;
171
172 let result = backend
173 .object_operation(
174 message.object,
175 &message.operation,
176 message.payload,
177 &operation_state,
178 )
179 .await;
180
181 // Asking for exploits here
182 operation_state.user = None;
183 operation_state.instance = None;
184
185 // Map result to Vec<u8> on both
186 let result = match result {
187 Ok(result) => Ok(serde_json::to_vec(&result).unwrap()),
188 Err(err) => Err(match err {
189 OperationError::Operation(err) => {
190 OperationError::Operation(serde_json::to_vec(&err).unwrap())
191 }
192 OperationError::Internal(err) => OperationError::Internal(err),
193 OperationError::Unhandled => OperationError::Unhandled,
194 }),
195 };
196
197 let mut socket = connection_state.socket.lock().await;
198 let _ = socket
199 .send(Message::Binary(bincode::serialize(&result).unwrap()))
200 .await;
201
202 drop(socket);
203 }
204 _ => {
205 return;
206 }
207 }
208 }
209
210 // loop {
211 // let mut socket = connection_state.socket.lock().await;
212 // let message = socket.next().await;
213 // drop(socket);
214
215 // match message {
216 // Some(Ok(message)) => {
217 // let payload = match message {
218 // Message::Binary(payload) => payload,
219 // Message::Ping(_) => {
220 // let mut socket = connection_state.socket.lock().await;
221 // let _ = socket.send(Message::Pong(vec![])).await;
222 // drop(socket);
223 // continue;
224 // }
225 // Message::Close(_) => return,
226 // _ => continue,
227 // };
228
229 // let message = NetworkMessage(payload.clone());
230
231 // if !handshaked {
232 // if handshake_handle(&message, &connection_state).await.is_ok() {
233 // if connection_state.handshaked.load(Ordering::SeqCst) {
234 // handshaked = true;
235 // }
236 // }
237 // } else {
238 // let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
239
240 // if let Some(target_instance) = &raw.target_instance {
241 // if connection_state.instance != *target_instance {
242 // // Forward request
243 // info!("Forwarding message to {}", target_instance.url);
244 // let mut instance_connections = instance_connections.lock().await;
245 // let pool = instance_connections.get_or_open(&target_instance).unwrap();
246 // let pool_clone = pool.clone();
247 // drop(pool);
248
249 // let result = wrap_forwarded(&pool_clone, raw).await;
250
251 // let mut socket = connection_state.socket.lock().await;
252 // let _ = socket.send(result).await;
253
254 // continue;
255 // }
256 // }
257
258 // let message_type = &raw.message_type;
259
260 // match authentication_handle(message_type, &message, &connection_state).await {
261 // Err(e) => {
262 // let _ = connection_state
263 // .send_raw(ConnectionError(e.to_string()))
264 // .await;
265 // }
266 // Ok(true) => continue,
267 // Ok(false) => {}
268 // }
269
270 // match repository_handle(message_type, &message, &connection_state).await {
271 // Err(e) => {
272 // let _ = connection_state
273 // .send_raw(ConnectionError(e.to_string()))
274 // .await;
275 // }
276 // Ok(true) => continue,
277 // Ok(false) => {}
278 // }
279
280 // match user_handle(message_type, &message, &connection_state).await {
281 // Err(e) => {
282 // let _ = connection_state
283 // .send_raw(ConnectionError(e.to_string()))
284 // .await;
285 // }
286 // Ok(true) => continue,
287 // Ok(false) => {}
288 // }
289
290 // match authentication_handle(message_type, &message, &connection_state).await {
291 // Err(e) => {
292 // let _ = connection_state
293 // .send_raw(ConnectionError(e.to_string()))
294 // .await;
295 // }
296 // Ok(true) => continue,
297 // Ok(false) => {}
298 // }
299
300 // error!(
301 // "Message completely unhandled: {}",
302 // std::str::from_utf8(&payload).unwrap()
303 // );
304 // }
305 // }
306 // Some(Err(e)) => {
307 // error!("Closing connection for {:?} for {}", e, addr);
308 // return;
309 // }
310 // _ => {
311 // continue;
312 // }
313 // }
314 // }
315 }
316
317 #[derive(Clone)]
318 pub struct ConnectionState {
319 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
320 pub connections: Arc<Mutex<Connections>>,
321 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
322 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
323 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
324 pub settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
325 pub addr: SocketAddr,
326 pub instance: Instance,
327 pub handshaked: Arc<AtomicBool>,
328 pub key_cache: Arc<Mutex<PublicKeyCache>>,
329 pub instance_connections: Arc<Mutex<InstanceConnections>>,
330 pub config: Table,
331 }
332
333 impl ConnectionState {
334 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
335 let payload = serde_json::to_string(&message)?;
336 self.socket
337 .lock()
338 .await
339 .send(Message::Binary(payload.into_bytes()))
340 .await?;
341
342 Ok(())
343 }
344
345 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
346 let payload = serde_json::to_string(&message)?;
347 self.socket
348 .lock()
349 .await
350 .send(Message::Binary(payload.into_bytes()))
351 .await?;
352
353 Ok(())
354 }
355
356 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
357 let mut keys = self.key_cache.lock().await;
358 keys.get(instance).await
359 }
360 }
361