JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fix handling stack

Amber - 2 years ago

parent: tbd commit: c53b026

giterated-daemon/src/connection/wrapper.rs - 10070 bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{atomic::AtomicBool, Arc},
4 };
5
6 use anyhow::Error;
7 use futures_util::{SinkExt, StreamExt};
8
9 use giterated_models::{error::OperationError, instance::Instance};
10
11 use giterated_models::object_backend::ObjectBackend;
12
13 use giterated_models::{
14 authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject,
15 operation::AnyOperation,
16 };
17 use serde::Serialize;
18
19 use tokio::{net::TcpStream, sync::Mutex};
20 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
21 use toml::Table;
22
23 use crate::{
24 authentication::AuthenticationTokenGranter,
25 backend::{MetadataBackend, RepositoryBackend, UserBackend},
26 database_backend::DatabaseBackend,
27 federation::connections::InstanceConnections,
28 keys::PublicKeyCache,
29 };
30
31 use super::Connections;
32
33 pub async fn connection_wrapper(
34 socket: WebSocketStream<TcpStream>,
35 connections: Arc<Mutex<Connections>>,
36 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
37 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
38 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
39 settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
40 addr: SocketAddr,
41 instance: impl ToOwned<Owned = Instance>,
42 instance_connections: Arc<Mutex<InstanceConnections>>,
43 config: Table,
44 backend: DatabaseBackend,
45 ) {
46 let connection_state = ConnectionState {
47 socket: Arc::new(Mutex::new(socket)),
48 connections,
49 repository_backend,
50 user_backend,
51 auth_granter,
52 settings_backend,
53 addr,
54 instance: instance.to_owned(),
55 handshaked: Arc::new(AtomicBool::new(false)),
56 key_cache: Arc::default(),
57 instance_connections: instance_connections.clone(),
58 config,
59 };
60
61 let _handshaked = false;
62
63 loop {
64 let mut socket = connection_state.socket.lock().await;
65 let message = socket.next().await;
66 drop(socket);
67
68 match message {
69 Some(Ok(message)) => {
70 let payload = match message {
71 Message::Binary(payload) => payload,
72 Message::Ping(_) => {
73 let mut socket = connection_state.socket.lock().await;
74 let _ = socket.send(Message::Pong(vec![])).await;
75 drop(socket);
76 continue;
77 }
78 Message::Close(_) => return,
79 _ => continue,
80 };
81
82 let message: AuthenticatedPayload = bincode::deserialize(&payload).unwrap();
83
84 let message: GiteratedMessage<AnyObject, AnyOperation> = message.into_message();
85
86 let result = backend
87 .object_operation(message.object, &message.operation, message.payload)
88 .await;
89
90 // Map result to Vec<u8> on both
91 let result = match result {
92 Ok(result) => Ok(serde_json::to_vec(&result).unwrap()),
93 Err(err) => Err(match err {
94 OperationError::Operation(err) => {
95 OperationError::Operation(serde_json::to_vec(&err).unwrap())
96 }
97 OperationError::Internal(err) => OperationError::Internal(err),
98 OperationError::Unhandled => OperationError::Unhandled,
99 }),
100 };
101
102 let mut socket = connection_state.socket.lock().await;
103 let _ = socket
104 .send(Message::Binary(bincode::serialize(&result).unwrap()))
105 .await;
106
107 drop(socket);
108 }
109 _ => {
110 return;
111 }
112 }
113 }
114
115 // loop {
116 // let mut socket = connection_state.socket.lock().await;
117 // let message = socket.next().await;
118 // drop(socket);
119
120 // match message {
121 // Some(Ok(message)) => {
122 // let payload = match message {
123 // Message::Binary(payload) => payload,
124 // Message::Ping(_) => {
125 // let mut socket = connection_state.socket.lock().await;
126 // let _ = socket.send(Message::Pong(vec![])).await;
127 // drop(socket);
128 // continue;
129 // }
130 // Message::Close(_) => return,
131 // _ => continue,
132 // };
133 // info!("one payload");
134
135 // let message = NetworkMessage(payload.clone());
136
137 // if !handshaked {
138 // info!("im foo baring");
139 // if handshake_handle(&message, &connection_state).await.is_ok() {
140 // if connection_state.handshaked.load(Ordering::SeqCst) {
141 // handshaked = true;
142 // }
143 // }
144 // } else {
145 // let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
146
147 // if let Some(target_instance) = &raw.target_instance {
148 // if connection_state.instance != *target_instance {
149 // // Forward request
150 // info!("Forwarding message to {}", target_instance.url);
151 // let mut instance_connections = instance_connections.lock().await;
152 // let pool = instance_connections.get_or_open(&target_instance).unwrap();
153 // let pool_clone = pool.clone();
154 // drop(pool);
155
156 // let result = wrap_forwarded(&pool_clone, raw).await;
157
158 // let mut socket = connection_state.socket.lock().await;
159 // let _ = socket.send(result).await;
160
161 // continue;
162 // }
163 // }
164
165 // let message_type = &raw.message_type;
166
167 // info!("Handling message with type: {}", message_type);
168
169 // match authentication_handle(message_type, &message, &connection_state).await {
170 // Err(e) => {
171 // let _ = connection_state
172 // .send_raw(ConnectionError(e.to_string()))
173 // .await;
174 // }
175 // Ok(true) => continue,
176 // Ok(false) => {}
177 // }
178
179 // match repository_handle(message_type, &message, &connection_state).await {
180 // Err(e) => {
181 // let _ = connection_state
182 // .send_raw(ConnectionError(e.to_string()))
183 // .await;
184 // }
185 // Ok(true) => continue,
186 // Ok(false) => {}
187 // }
188
189 // match user_handle(message_type, &message, &connection_state).await {
190 // Err(e) => {
191 // let _ = connection_state
192 // .send_raw(ConnectionError(e.to_string()))
193 // .await;
194 // }
195 // Ok(true) => continue,
196 // Ok(false) => {}
197 // }
198
199 // match authentication_handle(message_type, &message, &connection_state).await {
200 // Err(e) => {
201 // let _ = connection_state
202 // .send_raw(ConnectionError(e.to_string()))
203 // .await;
204 // }
205 // Ok(true) => continue,
206 // Ok(false) => {}
207 // }
208
209 // error!(
210 // "Message completely unhandled: {}",
211 // std::str::from_utf8(&payload).unwrap()
212 // );
213 // }
214 // }
215 // Some(Err(e)) => {
216 // error!("Closing connection for {:?} for {}", e, addr);
217 // return;
218 // }
219 // _ => {
220 // info!("Unhandled");
221 // continue;
222 // }
223 // }
224 // }
225 }
226
227 #[derive(Clone)]
228 pub struct ConnectionState {
229 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
230 pub connections: Arc<Mutex<Connections>>,
231 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
232 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
233 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
234 pub settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
235 pub addr: SocketAddr,
236 pub instance: Instance,
237 pub handshaked: Arc<AtomicBool>,
238 pub key_cache: Arc<Mutex<PublicKeyCache>>,
239 pub instance_connections: Arc<Mutex<InstanceConnections>>,
240 pub config: Table,
241 }
242
243 impl ConnectionState {
244 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
245 let payload = serde_json::to_string(&message)?;
246 info!("Sending payload: {}", &payload);
247 self.socket
248 .lock()
249 .await
250 .send(Message::Binary(payload.into_bytes()))
251 .await?;
252
253 Ok(())
254 }
255
256 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
257 let payload = serde_json::to_string(&message)?;
258 info!("Sending payload: {}", &payload);
259 self.socket
260 .lock()
261 .await
262 .send(Message::Binary(payload.into_bytes()))
263 .await?;
264
265 Ok(())
266 }
267
268 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
269 let mut keys = self.key_cache.lock().await;
270 keys.get(instance).await
271 }
272 }
273