JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Major post-refactor cleanup

Amber - 2 years ago

parent: tbd commit: f90d7fb

src/connection/wrapper.rs - 3851 bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 };
8
9 use anyhow::Error;
10 use futures_util::{SinkExt, StreamExt};
11 use serde::Serialize;
12 use tokio::{net::TcpStream, sync::Mutex};
13 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
14
15 use crate::{
16 authentication::AuthenticationTokenGranter,
17 backend::{RepositoryBackend, UserBackend},
18 model::{authenticated::NetworkMessage, instance::Instance},
19 };
20
21 use super::{
22 authentication::authentication_handle, handshake::handshake_handle,
23 repository::repository_handle, user::user_handle, Connections,
24 };
25
26 pub async fn connection_wrapper(
27 socket: WebSocketStream<TcpStream>,
28 connections: Arc<Mutex<Connections>>,
29 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
30 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
31 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
32 addr: SocketAddr,
33 instance: impl ToOwned<Owned = Instance>,
34 ) {
35 let connection_state = ConnectionState {
36 socket: Arc::new(Mutex::new(socket)),
37 connections,
38 repository_backend,
39 user_backend,
40 auth_granter,
41 addr,
42 instance: instance.to_owned(),
43 handshaked: Arc::new(AtomicBool::new(false)),
44 };
45
46 let mut handshaked = false;
47
48 loop {
49 let mut socket = connection_state.socket.lock().await;
50 let message = socket.next().await;
51 drop(socket);
52
53 match message {
54 Some(Ok(message)) => {
55 let payload = match message {
56 Message::Binary(payload) => payload,
57 Message::Ping(_) => {
58 let mut socket = connection_state.socket.lock().await;
59 let _ = socket.send(Message::Pong(vec![])).await;
60 drop(socket);
61 continue;
62 }
63 Message::Close(_) => return,
64 _ => continue,
65 };
66
67 let message = NetworkMessage(payload);
68
69 if !handshaked {
70 if handshake_handle(&message, &connection_state).await.is_ok() {
71 if connection_state.handshaked.load(Ordering::SeqCst) {
72 handshaked = true;
73 }
74 }
75 } else {
76 if authentication_handle(&message, &connection_state)
77 .await
78 .is_ok()
79 {
80 continue;
81 } else if repository_handle(&message, &connection_state).await.is_ok() {
82 continue;
83 } else if user_handle(&message, &connection_state).await.is_ok() {
84 continue;
85 } else {
86 error!("Message completely unhandled");
87 continue;
88 }
89 }
90 }
91 _ => {
92 error!("Closing connection for {}", addr);
93 return;
94 }
95 }
96 }
97 }
98
99 #[derive(Clone)]
100 pub struct ConnectionState {
101 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
102 pub connections: Arc<Mutex<Connections>>,
103 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
104 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
105 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
106 pub addr: SocketAddr,
107 pub instance: Instance,
108 pub handshaked: Arc<AtomicBool>,
109 }
110
111 impl ConnectionState {
112 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
113 self.socket
114 .lock()
115 .await
116 .send(Message::Binary(serde_json::to_vec(&message)?))
117 .await?;
118
119 Ok(())
120 }
121 }
122