JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

User Auth Early

Amber - 2 years ago

parent: tbd commit: 8069fba

src/connection.rs - 19230 bytes
Raw
1 use std::{collections::HashMap, net::SocketAddr, sync::Arc};
2
3 use futures_util::{stream::StreamExt, SinkExt};
4 use tokio::{
5 net::TcpStream,
6 sync::{
7 broadcast::{Receiver, Sender},
8 Mutex,
9 },
10 task::JoinHandle,
11 };
12 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
13
14 use crate::{
15 authentication::AuthenticationTokenGranter,
16 backend::{IssuesBackend, RepositoryBackend},
17 handshake::{HandshakeFinalize, HandshakeMessage, HandshakeResponse},
18 listener::Listeners,
19 messages::{
20 authentication::{AuthenticationMessage, AuthenticationRequest, TokenExtensionResponse},
21 repository::{
22 RepositoryMessage, RepositoryMessageKind, RepositoryRequest, RepositoryResponse,
23 },
24 MessageKind,
25 },
26 model::{
27 instance::{Instance, InstanceMeta},
28 repository::Repository,
29 user::User,
30 },
31 };
32
33 pub struct RawConnection {
34 pub task: JoinHandle<()>,
35 }
36
37 pub struct InstanceConnection {
38 pub instance: InstanceMeta,
39 pub sender: Sender<MessageKind>,
40 pub task: JoinHandle<()>,
41 }
42
43 /// Represents a connection which hasn't finished the handshake.
44 pub struct UnestablishedConnection {
45 pub socket: WebSocketStream<TcpStream>,
46 }
47
48 #[derive(Default)]
49 pub struct Connections {
50 pub connections: Vec<RawConnection>,
51 pub instance_connections: HashMap<Instance, InstanceConnection>,
52 }
53
54 pub async fn connection_worker(
55 mut socket: WebSocketStream<TcpStream>,
56 listeners: Arc<Mutex<Listeners>>,
57 connections: Arc<Mutex<Connections>>,
58 backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
59 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
60 addr: SocketAddr,
61 ) {
62 let mut handshaked = false;
63 let this_instance = Instance {
64 url: String::from("giterated.dev"),
65 };
66
67 while let Some(message) = socket.next().await {
68 let message = match message {
69 Ok(message) => message,
70 Err(err) => {
71 error!("Error reading message: {:?}", err);
72 continue;
73 }
74 };
75
76 let payload = match message {
77 Message::Text(text) => text.into_bytes(),
78 Message::Binary(bytes) => bytes,
79 Message::Ping(_) => continue,
80 Message::Pong(_) => continue,
81 Message::Close(_) => {
82 info!("Closing connection with {}.", addr);
83
84 return;
85 }
86 _ => unreachable!(),
87 };
88
89 let message = match serde_json::from_slice::<MessageKind>(&payload) {
90 Ok(message) => message,
91 Err(err) => {
92 error!("Error deserializing message from {}: {:?}", addr, err);
93 continue;
94 }
95 };
96
97 // info!("Read payload: {}", std::str::from_utf8(&payload).unwrap());
98
99 if let MessageKind::Handshake(handshake) = message {
100 match handshake {
101 HandshakeMessage::Initiate(_) => {
102 // Send HandshakeMessage::Response
103 let message = HandshakeResponse {
104 identity: Instance {
105 url: String::from("foo.com"),
106 },
107 version: String::from("0.1.0"),
108 };
109
110 socket
111 .send(Message::Binary(
112 serde_json::to_vec(&MessageKind::Handshake(
113 HandshakeMessage::Response(message),
114 ))
115 .unwrap(),
116 ))
117 .await
118 .unwrap();
119
120 continue;
121 }
122 HandshakeMessage::Response(_) => {
123 // Send HandshakeMessage::Finalize
124 let message = HandshakeFinalize { success: true };
125
126 socket
127 .send(Message::Binary(
128 serde_json::to_vec(&MessageKind::Handshake(
129 HandshakeMessage::Finalize(message),
130 ))
131 .unwrap(),
132 ))
133 .await
134 .unwrap();
135
136 continue;
137 }
138 HandshakeMessage::Finalize(_) => {
139 handshaked = true;
140
141 // Send HandshakeMessage::Finalize
142 let message = HandshakeFinalize { success: true };
143
144 socket
145 .send(Message::Binary(
146 serde_json::to_vec(&MessageKind::Handshake(
147 HandshakeMessage::Finalize(message),
148 ))
149 .unwrap(),
150 ))
151 .await
152 .unwrap();
153
154 continue;
155 }
156 }
157 }
158
159 if !handshaked {
160 continue;
161 }
162
163 if let MessageKind::Repository(repository) = &message {
164 if repository.target.instance != this_instance {
165 info!("Forwarding command to {}", repository.target.instance.url);
166 // We need to send this command to a different instance
167
168 let mut listener = send_and_get_listener(message, &listeners, &connections).await;
169
170 // Wait for response
171 while let Ok(message) = listener.recv().await {
172 if let MessageKind::Repository(RepositoryMessage {
173 command: RepositoryMessageKind::Response(_),
174 ..
175 }) = message
176 {
177 socket
178 .send(Message::Binary(serde_json::to_vec(&message).unwrap()))
179 .await
180 .unwrap();
181 }
182 }
183 continue;
184 } else {
185 // This message is targeting this instance
186 match &repository.command {
187 RepositoryMessageKind::Request(request) => match request.clone() {
188 RepositoryRequest::CreateRepository(request) => {
189 let mut backend = backend.lock().await;
190 let request = request.validate().await.unwrap();
191 let response = backend.create_repository(&request).await;
192
193 let response = match response {
194 Ok(response) => response,
195 Err(err) => {
196 error!("Error handling request: {:?}", err);
197 continue;
198 }
199 };
200 drop(backend);
201
202 socket
203 .send(Message::Binary(
204 serde_json::to_vec(&MessageKind::Repository(
205 RepositoryMessage {
206 target: repository.target.clone(),
207 command: RepositoryMessageKind::Response(
208 RepositoryResponse::CreateRepository(response),
209 ),
210 },
211 ))
212 .unwrap(),
213 ))
214 .await
215 .unwrap();
216
217 continue;
218 }
219 RepositoryRequest::RepositoryFileInspect(request) => {
220 let mut backend = backend.lock().await;
221 let request = request.validate().await.unwrap();
222 let response = backend.repository_file_inspect(&request);
223
224 let response = match response {
225 Ok(response) => response,
226 Err(err) => {
227 error!("Error handling request: {:?}", err);
228 continue;
229 }
230 };
231 drop(backend);
232
233 socket
234 .send(Message::Binary(
235 serde_json::to_vec(&MessageKind::Repository(
236 RepositoryMessage {
237 target: repository.target.clone(),
238 command: RepositoryMessageKind::Response(
239 RepositoryResponse::RepositoryFileInspection(
240 response,
241 ),
242 ),
243 },
244 ))
245 .unwrap(),
246 ))
247 .await
248 .unwrap();
249 continue;
250 }
251 RepositoryRequest::RepositoryInfo(request) => {
252 let mut backend = backend.lock().await;
253 let request = request.validate().await.unwrap();
254 let response = backend.repository_info(&request).await;
255
256 let response = match response {
257 Ok(response) => response,
258 Err(err) => {
259 error!("Error handling request: {:?}", err);
260 continue;
261 }
262 };
263 drop(backend);
264
265 socket
266 .send(Message::Binary(
267 serde_json::to_vec(&MessageKind::Repository(
268 RepositoryMessage {
269 target: repository.target.clone(),
270 command: RepositoryMessageKind::Response(
271 RepositoryResponse::RepositoryInfo(response),
272 ),
273 },
274 ))
275 .unwrap(),
276 ))
277 .await
278 .unwrap();
279 continue;
280 }
281 RepositoryRequest::IssuesCount(request) => {
282 let request = &request.validate().await.unwrap();
283
284 let mut backend = backend.lock().await;
285 let response = backend.issues_count(request);
286
287 let response = match response {
288 Ok(response) => response,
289 Err(err) => {
290 error!("Error handling request: {:?}", err);
291 continue;
292 }
293 };
294 drop(backend);
295
296 socket
297 .send(Message::Binary(
298 serde_json::to_vec(&MessageKind::Repository(
299 RepositoryMessage {
300 target: repository.target.clone(),
301 command: RepositoryMessageKind::Response(
302 RepositoryResponse::IssuesCount(response),
303 ),
304 },
305 ))
306 .unwrap(),
307 ))
308 .await
309 .unwrap();
310 continue;
311 }
312 RepositoryRequest::IssueLabels(request) => {
313 let request = &request.validate().await.unwrap();
314
315 let mut backend = backend.lock().await;
316 let response = backend.issue_labels(&request);
317
318 let response = match response {
319 Ok(response) => response,
320 Err(err) => {
321 error!("Error handling request: {:?}", err);
322 continue;
323 }
324 };
325 drop(backend);
326 socket
327 .send(Message::Binary(
328 serde_json::to_vec(&MessageKind::Repository(
329 RepositoryMessage {
330 target: repository.target.clone(),
331 command: RepositoryMessageKind::Response(
332 RepositoryResponse::IssueLabels(response),
333 ),
334 },
335 ))
336 .unwrap(),
337 ))
338 .await
339 .unwrap();
340 continue;
341 }
342 RepositoryRequest::Issues(request) => {
343 let request = request.validate().await.unwrap();
344
345 let mut backend = backend.lock().await;
346 let response = backend.issues(&request);
347
348 let response = match response {
349 Ok(response) => response,
350 Err(err) => {
351 error!("Error handling request: {:?}", err);
352 continue;
353 }
354 };
355 drop(backend);
356
357 socket
358 .send(Message::Binary(
359 serde_json::to_vec(&MessageKind::Repository(
360 RepositoryMessage {
361 target: repository.target.clone(),
362 command: RepositoryMessageKind::Response(
363 RepositoryResponse::Issues(response),
364 ),
365 },
366 ))
367 .unwrap(),
368 ))
369 .await
370 .unwrap();
371 continue;
372 }
373 },
374 RepositoryMessageKind::Response(_response) => {
375 unreachable!()
376 }
377 }
378 }
379 }
380
381 if let MessageKind::Authentication(authentication) = &message {
382 match authentication {
383 AuthenticationMessage::Request(request) => match request {
384 AuthenticationRequest::AuthenticationToken(token) => {
385 let mut granter = auth_granter.lock().await;
386
387 let response = granter.token_request(token.clone()).await.unwrap();
388 drop(granter);
389
390 socket
391 .send(Message::Binary(
392 serde_json::to_vec(&MessageKind::Authentication(
393 AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::AuthenticationToken(response))
394 ))
395 .unwrap(),
396 ))
397 .await
398 .unwrap();
399 continue;
400 }
401 AuthenticationRequest::TokenExtension(request) => {
402 let mut granter = auth_granter.lock().await;
403
404 let response = granter
405 .extension_request(request.clone())
406 .await
407 .unwrap_or(TokenExtensionResponse { new_token: None });
408 drop(granter);
409
410 socket
411 .send(Message::Binary(
412 serde_json::to_vec(&MessageKind::Authentication(
413 AuthenticationMessage::Response(crate::messages::authentication::AuthenticationResponse::TokenExtension(response))
414 ))
415 .unwrap(),
416 ))
417 .await
418 .unwrap();
419 continue;
420 }
421 },
422 AuthenticationMessage::Response(_) => unreachable!(),
423 }
424 }
425 }
426
427 info!("Connection closed");
428 }
429
430 async fn send_and_get_listener(
431 message: MessageKind,
432 listeners: &Arc<Mutex<Listeners>>,
433 connections: &Arc<Mutex<Connections>>,
434 ) -> Receiver<MessageKind> {
435 let (instance, user, repository): (Option<Instance>, Option<User>, Option<Repository>) =
436 match &message {
437 MessageKind::Handshake(_) => {
438 todo!()
439 }
440 MessageKind::Repository(repository) => (None, None, Some(repository.target.clone())),
441 MessageKind::Authentication(_) => todo!(),
442 };
443
444 let target = match (&instance, &user, &repository) {
445 (Some(instance), _, _) => instance.clone(),
446 (_, Some(user), _) => user.instance.clone(),
447 (_, _, Some(repository)) => repository.instance.clone(),
448 _ => unreachable!(),
449 };
450
451 let mut listeners = listeners.lock().await;
452 let listener = listeners.add(instance, user, repository);
453 drop(listeners);
454
455 let connections = connections.lock().await;
456
457 if let Some(connection) = connections.instance_connections.get(&target) {
458 connection.sender.send(message);
459 } else {
460 error!("Unable to message {}, this is a bug.", target.url);
461
462 panic!();
463 }
464
465 drop(connections);
466
467 listener
468 }
469