JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Finish unified stack refactor.

Adds support for operation state, which will be used to pass authentication information around. Adds a generic backend that uses a channel to communicate with a typed backend.

Amber - 2 years ago

parent: tbd commit: d15581c

giterated-daemon/src/connection/wrapper.rs - 10036 bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{atomic::AtomicBool, Arc},
4 };
5
6 use anyhow::Error;
7 use futures_util::{SinkExt, StreamExt};
8
9 use giterated_models::{error::OperationError, instance::Instance};
10
11 use giterated_models::object_backend::ObjectBackend;
12
13 use giterated_models::{
14 authenticated::AuthenticatedPayload, message::GiteratedMessage, object::AnyObject,
15 operation::AnyOperation,
16 };
17 use giterated_stack::{handler::GiteratedBackend, StackOperationState};
18 use serde::Serialize;
19
20 use tokio::{net::TcpStream, sync::Mutex};
21 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
22 use toml::Table;
23
24 use crate::{
25 authentication::AuthenticationTokenGranter,
26 backend::{MetadataBackend, RepositoryBackend, UserBackend},
27 database_backend::DatabaseBackend,
28 federation::connections::InstanceConnections,
29 keys::PublicKeyCache,
30 };
31
32 use super::Connections;
33
34 pub async fn connection_wrapper(
35 socket: WebSocketStream<TcpStream>,
36 connections: Arc<Mutex<Connections>>,
37 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
38 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
39 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
40 settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
41 addr: SocketAddr,
42 instance: impl ToOwned<Owned = Instance>,
43 instance_connections: Arc<Mutex<InstanceConnections>>,
44 config: Table,
45 backend: GiteratedBackend<DatabaseBackend>,
46 operation_state: StackOperationState,
47 ) {
48 let connection_state = ConnectionState {
49 socket: Arc::new(Mutex::new(socket)),
50 connections,
51 repository_backend,
52 user_backend,
53 auth_granter,
54 settings_backend,
55 addr,
56 instance: instance.to_owned(),
57 handshaked: Arc::new(AtomicBool::new(false)),
58 key_cache: Arc::default(),
59 instance_connections: instance_connections.clone(),
60 config,
61 };
62
63 let _handshaked = false;
64
65 loop {
66 let mut socket = connection_state.socket.lock().await;
67 let message = socket.next().await;
68 drop(socket);
69
70 match message {
71 Some(Ok(message)) => {
72 let payload = match message {
73 Message::Binary(payload) => payload,
74 Message::Ping(_) => {
75 let mut socket = connection_state.socket.lock().await;
76 let _ = socket.send(Message::Pong(vec![])).await;
77 drop(socket);
78 continue;
79 }
80 Message::Close(_) => return,
81 _ => continue,
82 };
83
84 let message: AuthenticatedPayload = bincode::deserialize(&payload).unwrap();
85
86 let message: GiteratedMessage<AnyObject, AnyOperation> = message.into_message();
87
88 let result = backend
89 .object_operation(
90 message.object,
91 &message.operation,
92 message.payload,
93 &operation_state,
94 )
95 .await;
96
97 // Map result to Vec<u8> on both
98 let result = match result {
99 Ok(result) => Ok(serde_json::to_vec(&result).unwrap()),
100 Err(err) => Err(match err {
101 OperationError::Operation(err) => {
102 OperationError::Operation(serde_json::to_vec(&err).unwrap())
103 }
104 OperationError::Internal(err) => OperationError::Internal(err),
105 OperationError::Unhandled => OperationError::Unhandled,
106 }),
107 };
108
109 let mut socket = connection_state.socket.lock().await;
110 let _ = socket
111 .send(Message::Binary(bincode::serialize(&result).unwrap()))
112 .await;
113
114 drop(socket);
115 }
116 _ => {
117 return;
118 }
119 }
120 }
121
122 // loop {
123 // let mut socket = connection_state.socket.lock().await;
124 // let message = socket.next().await;
125 // drop(socket);
126
127 // match message {
128 // Some(Ok(message)) => {
129 // let payload = match message {
130 // Message::Binary(payload) => payload,
131 // Message::Ping(_) => {
132 // let mut socket = connection_state.socket.lock().await;
133 // let _ = socket.send(Message::Pong(vec![])).await;
134 // drop(socket);
135 // continue;
136 // }
137 // Message::Close(_) => return,
138 // _ => continue,
139 // };
140
141 // let message = NetworkMessage(payload.clone());
142
143 // if !handshaked {
144 // if handshake_handle(&message, &connection_state).await.is_ok() {
145 // if connection_state.handshaked.load(Ordering::SeqCst) {
146 // handshaked = true;
147 // }
148 // }
149 // } else {
150 // let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
151
152 // if let Some(target_instance) = &raw.target_instance {
153 // if connection_state.instance != *target_instance {
154 // // Forward request
155 // info!("Forwarding message to {}", target_instance.url);
156 // let mut instance_connections = instance_connections.lock().await;
157 // let pool = instance_connections.get_or_open(&target_instance).unwrap();
158 // let pool_clone = pool.clone();
159 // drop(pool);
160
161 // let result = wrap_forwarded(&pool_clone, raw).await;
162
163 // let mut socket = connection_state.socket.lock().await;
164 // let _ = socket.send(result).await;
165
166 // continue;
167 // }
168 // }
169
170 // let message_type = &raw.message_type;
171
172 // match authentication_handle(message_type, &message, &connection_state).await {
173 // Err(e) => {
174 // let _ = connection_state
175 // .send_raw(ConnectionError(e.to_string()))
176 // .await;
177 // }
178 // Ok(true) => continue,
179 // Ok(false) => {}
180 // }
181
182 // match repository_handle(message_type, &message, &connection_state).await {
183 // Err(e) => {
184 // let _ = connection_state
185 // .send_raw(ConnectionError(e.to_string()))
186 // .await;
187 // }
188 // Ok(true) => continue,
189 // Ok(false) => {}
190 // }
191
192 // match user_handle(message_type, &message, &connection_state).await {
193 // Err(e) => {
194 // let _ = connection_state
195 // .send_raw(ConnectionError(e.to_string()))
196 // .await;
197 // }
198 // Ok(true) => continue,
199 // Ok(false) => {}
200 // }
201
202 // match authentication_handle(message_type, &message, &connection_state).await {
203 // Err(e) => {
204 // let _ = connection_state
205 // .send_raw(ConnectionError(e.to_string()))
206 // .await;
207 // }
208 // Ok(true) => continue,
209 // Ok(false) => {}
210 // }
211
212 // error!(
213 // "Message completely unhandled: {}",
214 // std::str::from_utf8(&payload).unwrap()
215 // );
216 // }
217 // }
218 // Some(Err(e)) => {
219 // error!("Closing connection for {:?} for {}", e, addr);
220 // return;
221 // }
222 // _ => {
223 // continue;
224 // }
225 // }
226 // }
227 }
228
229 #[derive(Clone)]
230 pub struct ConnectionState {
231 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
232 pub connections: Arc<Mutex<Connections>>,
233 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
234 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
235 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
236 pub settings_backend: Arc<Mutex<dyn MetadataBackend + Send>>,
237 pub addr: SocketAddr,
238 pub instance: Instance,
239 pub handshaked: Arc<AtomicBool>,
240 pub key_cache: Arc<Mutex<PublicKeyCache>>,
241 pub instance_connections: Arc<Mutex<InstanceConnections>>,
242 pub config: Table,
243 }
244
245 impl ConnectionState {
246 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
247 let payload = serde_json::to_string(&message)?;
248 self.socket
249 .lock()
250 .await
251 .send(Message::Binary(payload.into_bytes()))
252 .await?;
253
254 Ok(())
255 }
256
257 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
258 let payload = serde_json::to_string(&message)?;
259 self.socket
260 .lock()
261 .await
262 .send(Message::Binary(payload.into_bytes()))
263 .await?;
264
265 Ok(())
266 }
267
268 pub async fn public_key(&self, instance: &Instance) -> Result<String, Error> {
269 let mut keys = self.key_cache.lock().await;
270 keys.get(instance).await
271 }
272 }
273