JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add token extension

Amber - 2 years ago

parent: tbd commit: 86d028f

src/connection.rs - 18756 bytes
Raw
1 use std::{collections::HashMap, net::SocketAddr, sync::Arc};
2
3 use futures_util::{stream::StreamExt, SinkExt};
4 use tokio::{
5 net::TcpStream,
6 sync::{
7 broadcast::{Receiver, Sender},
8 Mutex,
9 },
10 task::JoinHandle,
11 };
12 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
13
14 use crate::{
15 authentication::AuthenticationTokenGranter,
16 backend::{IssuesBackend, RepositoryBackend},
17 handshake::{HandshakeFinalize, HandshakeMessage, HandshakeResponse},
18 listener::Listeners,
19 messages::{
20 authentication::{AuthenticationMessage, AuthenticationRequest, TokenExtensionResponse},
21 repository::{
22 RepositoryMessage, RepositoryMessageKind, RepositoryRequest, RepositoryResponse,
23 },
24 MessageKind,
25 },
26 model::{
27 instance::{Instance, InstanceMeta},
28 repository::Repository,
29 user::User,
30 },
31 };
32
/// A connection tracked only by the task that serves it (no instance identity attached).
33 pub struct RawConnection {
/// Handle to the task serving this connection.
34 pub task: JoinHandle<()>,
35 }
36
/// A connection to another instance, used when forwarding messages that target it.
37 pub struct InstanceConnection {
/// Metadata (`InstanceMeta`) for the remote instance.
38 pub instance: InstanceMeta,
/// Broadcast channel that `send_and_get_listener` pushes forwarded messages into.
39 pub sender: Sender<MessageKind>,
/// Handle to the task driving this connection.
40 pub task: JoinHandle<()>,
41 }
42
43 /// Represents a connection which hasn't finished the handshake.
///
/// NOTE(review): not constructed anywhere in this file; confirm where it is used.
44 pub struct UnestablishedConnection {
/// The underlying WebSocket stream over TCP.
45 pub socket: WebSocketStream<TcpStream>,
46 }
47
/// Registry of live connections; `connection_worker` and `send_and_get_listener`
/// share it behind an `Arc<Mutex<_>>`.
48 #[derive(Default)]
49 pub struct Connections {
/// Connections tracked only by their task handle.
50 pub connections: Vec<RawConnection>,
/// Connections to other instances, keyed by instance; looked up to forward messages.
51 pub instance_connections: HashMap<Instance, InstanceConnection>,
52 }
53
54 pub async fn connection_worker(
55 mut socket: WebSocketStream<TcpStream>,
56 listeners: Arc<Mutex<Listeners>>,
57 connections: Arc<Mutex<Connections>>,
58 backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
59 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
60 addr: SocketAddr,
61 ) {
62 let mut handshaked = false;
63 let this_instance = Instance {
64 url: String::from("giterated.dev"),
65 };
66
67 while let Some(message) = socket.next().await {
68 let message = match message {
69 Ok(message) => message,
70 Err(err) => {
71 error!("Error reading message: {:?}", err);
72 continue;
73 }
74 };
75
76 let payload = match message {
77 Message::Text(text) => text.into_bytes(),
78 Message::Binary(bytes) => bytes,
79 Message::Ping(_) => continue,
80 Message::Pong(_) => continue,
81 Message::Close(_) => {
82 info!("Closing connection with {}.", addr);
83
84 return;
85 }
86 _ => unreachable!(),
87 };
88
89 let message = match serde_json::from_slice::<MessageKind>(&payload) {
90 Ok(message) => message,
91 Err(err) => {
92 error!("Error deserializing message from {}: {:?}", addr, err);
93 continue;
94 }
95 };
96
97 info!("Read payload: {}", std::str::from_utf8(&payload).unwrap());
98
99 if let MessageKind::Handshake(handshake) = message {
100 match handshake {
101 HandshakeMessage::Initiate(_) => {
102 // Send HandshakeMessage::Response
103 let message = HandshakeResponse {
104 identity: Instance {
105 url: String::from("foo.com"),
106 },
107 version: String::from("0.1.0"),
108 };
109
110 socket
111 .send(Message::Binary(
112 serde_json::to_vec(&MessageKind::Handshake(
113 HandshakeMessage::Response(message),
114 ))
115 .unwrap(),
116 ))
117 .await
118 .unwrap();
119
120 continue;
121 }
122 HandshakeMessage::Response(_) => {
123 // Send HandshakeMessage::Finalize
124 let message = HandshakeFinalize { success: true };
125
126 socket
127 .send(Message::Binary(
128 serde_json::to_vec(&MessageKind::Handshake(
129 HandshakeMessage::Finalize(message),
130 ))
131 .unwrap(),
132 ))
133 .await
134 .unwrap();
135
136 continue;
137 }
138 HandshakeMessage::Finalize(_) => {
139 handshaked = true;
140
141 // Send HandshakeMessage::Finalize
142 let message = HandshakeFinalize { success: true };
143
144 socket
145 .send(Message::Binary(
146 serde_json::to_vec(&MessageKind::Handshake(
147 HandshakeMessage::Finalize(message),
148 ))
149 .unwrap(),
150 ))
151 .await
152 .unwrap();
153
154 continue;
155 }
156 }
157 }
158
159 if !handshaked {
160 continue;
161 }
162
163 if let MessageKind::Repository(repository) = &message {
164 if repository.target.instance != this_instance {
165 info!("Forwarding command to {}", repository.target.instance.url);
166 // We need to send this command to a different instance
167
168 let mut listener = send_and_get_listener(message, &listeners, &connections).await;
169
170 // Wait for response
171 while let Ok(message) = listener.recv().await {
172 if let MessageKind::Repository(RepositoryMessage {
173 command: RepositoryMessageKind::Response(_),
174 ..
175 }) = message
176 {
177 socket
178 .send(Message::Binary(serde_json::to_vec(&message).unwrap()))
179 .await
180 .unwrap();
181 }
182 }
183 continue;
184 } else {
185 // This message is targeting this instance
186 match &repository.command {
187 RepositoryMessageKind::Request(request) => match request {
188 RepositoryRequest::CreateRepository(request) => {
189 let mut backend = backend.lock().await;
190 let response = backend.create_repository(request).await;
191
192 let response = match response {
193 Ok(response) => response,
194 Err(err) => {
195 error!("Error handling request: {:?}", err);
196 continue;
197 }
198 };
199 drop(backend);
200
201 socket
202 .send(Message::Binary(
203 serde_json::to_vec(&MessageKind::Repository(
204 RepositoryMessage {
205 target: repository.target.clone(),
206 command: RepositoryMessageKind::Response(
207 RepositoryResponse::CreateRepository(response),
208 ),
209 },
210 ))
211 .unwrap(),
212 ))
213 .await
214 .unwrap();
215
216 continue;
217 }
218 RepositoryRequest::RepositoryFileInspect(request) => {
219 let mut backend = backend.lock().await;
220 let response = backend.repository_file_inspect(request);
221
222 let response = match response {
223 Ok(response) => response,
224 Err(err) => {
225 error!("Error handling request: {:?}", err);
226 continue;
227 }
228 };
229 drop(backend);
230
231 socket
232 .send(Message::Binary(
233 serde_json::to_vec(&MessageKind::Repository(
234 RepositoryMessage {
235 target: repository.target.clone(),
236 command: RepositoryMessageKind::Response(
237 RepositoryResponse::RepositoryFileInspection(
238 response,
239 ),
240 ),
241 },
242 ))
243 .unwrap(),
244 ))
245 .await
246 .unwrap();
247 continue;
248 }
249 RepositoryRequest::RepositoryInfo(request) => {
250 let mut backend = backend.lock().await;
251 let response = backend.repository_info(request).await;
252
253 let response = match response {
254 Ok(response) => response,
255 Err(err) => {
256 error!("Error handling request: {:?}", err);
257 continue;
258 }
259 };
260 drop(backend);
261
262 socket
263 .send(Message::Binary(
264 serde_json::to_vec(&MessageKind::Repository(
265 RepositoryMessage {
266 target: repository.target.clone(),
267 command: RepositoryMessageKind::Response(
268 RepositoryResponse::RepositoryInfo(response),
269 ),
270 },
271 ))
272 .unwrap(),
273 ))
274 .await
275 .unwrap();
276 continue;
277 }
278 RepositoryRequest::IssuesCount(request) => {
279 let mut backend = backend.lock().await;
280 let response = backend.issues_count(request);
281
282 let response = match response {
283 Ok(response) => response,
284 Err(err) => {
285 error!("Error handling request: {:?}", err);
286 continue;
287 }
288 };
289 drop(backend);
290
291 socket
292 .send(Message::Binary(
293 serde_json::to_vec(&MessageKind::Repository(
294 RepositoryMessage {
295 target: repository.target.clone(),
296 command: RepositoryMessageKind::Response(
297 RepositoryResponse::IssuesCount(response),
298 ),
299 },
300 ))
301 .unwrap(),
302 ))
303 .await
304 .unwrap();
305 continue;
306 }
307 RepositoryRequest::IssueLabels(request) => {
308 let mut backend = backend.lock().await;
309 let response = backend.issue_labels(request);
310
311 let response = match response {
312 Ok(response) => response,
313 Err(err) => {
314 error!("Error handling request: {:?}", err);
315 continue;
316 }
317 };
318 drop(backend);
319 socket
320 .send(Message::Binary(
321 serde_json::to_vec(&MessageKind::Repository(
322 RepositoryMessage {
323 target: repository.target.clone(),
324 command: RepositoryMessageKind::Response(
325 RepositoryResponse::IssueLabels(response),
326 ),
327 },
328 ))
329 .unwrap(),
330 ))
331 .await
332 .unwrap();
333 continue;
334 }
335 RepositoryRequest::Issues(request) => {
336 let mut backend = backend.lock().await;
337 let response = backend.issues(request);
338
339 let response = match response {
340 Ok(response) => response,
341 Err(err) => {
342 error!("Error handling request: {:?}", err);
343 continue;
344 }
345 };
346 drop(backend);
347
348 socket
349 .send(Message::Binary(
350 serde_json::to_vec(&MessageKind::Repository(
351 RepositoryMessage {
352 target: repository.target.clone(),
353 command: RepositoryMessageKind::Response(
354 RepositoryResponse::Issues(response),
355 ),
356 },
357 ))
358 .unwrap(),
359 ))
360 .await
361 .unwrap();
362 continue;
363 }
364 },
365 RepositoryMessageKind::Response(_response) => {
366 unreachable!()
367 }
368 }
369 }
370 }
371
372 if let MessageKind::Authentication(authentication) = &message {
373 match authentication {
374 AuthenticationMessage::Request(request) => match request {
375 AuthenticationRequest::AuthenticationToken(token) => {
376 let mut granter = auth_granter.lock().await;
377
378 let response = granter.token_request(token.clone()).await.unwrap();
379 drop(granter);
380
381 socket
382 .send(Message::Binary(
383 serde_json::to_vec(&MessageKind::Authentication(
384 AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::AuthenticationToken(response))
385 ))
386 .unwrap(),
387 ))
388 .await
389 .unwrap();
390 continue;
391 }
392 AuthenticationRequest::TokenExtension(request) => {
393 let mut granter = auth_granter.lock().await;
394
395 let response = granter
396 .extension_request(request.clone())
397 .await
398 .unwrap_or_else(|_| TokenExtensionResponse { new_token: None });
399 drop(granter);
400
401 socket
402 .send(Message::Binary(
403 serde_json::to_vec(&MessageKind::Authentication(
404 AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::TokenExtension(response))
405 ))
406 .unwrap(),
407 ))
408 .await
409 .unwrap();
410 continue;
411 }
412 },
413 AuthenticationMessage::Response(_) => unreachable!(),
414 }
415 }
416 }
417
418 info!("Connection closed");
419 }
420
421 async fn send_and_get_listener(
422 message: MessageKind,
423 listeners: &Arc<Mutex<Listeners>>,
424 connections: &Arc<Mutex<Connections>>,
425 ) -> Receiver<MessageKind> {
426 let (instance, user, repository): (Option<Instance>, Option<User>, Option<Repository>) =
427 match &message {
428 MessageKind::Handshake(_) => {
429 todo!()
430 }
431 MessageKind::Repository(repository) => (None, None, Some(repository.target.clone())),
432 MessageKind::Authentication(_) => todo!(),
433 };
434
435 let target = match (&instance, &user, &repository) {
436 (Some(instance), _, _) => instance.clone(),
437 (_, Some(user), _) => user.instance.clone(),
438 (_, _, Some(repository)) => repository.instance.clone(),
439 _ => unreachable!(),
440 };
441
442 let mut listeners = listeners.lock().await;
443 let listener = listeners.add(instance, user, repository);
444 drop(listeners);
445
446 let connections = connections.lock().await;
447
448 if let Some(connection) = connections.instance_connections.get(&target) {
449 connection.sender.send(message);
450 } else {
451 error!("Unable to message {}, this is a bug.", target.url);
452
453 panic!();
454 }
455
456 drop(connections);
457
458 listener
459 }
460