JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fixes!

Amber - 2 years ago

parent: tbd commit: 398622e

src/connection/wrapper.rs - 4577 bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 };
8
9 use anyhow::Error;
10 use futures_util::{SinkExt, StreamExt};
11 use serde::Serialize;
12 use serde_json::Value;
13 use tokio::{net::TcpStream, sync::Mutex};
14 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
15
16 use crate::{
17 authentication::AuthenticationTokenGranter,
18 backend::{RepositoryBackend, UserBackend},
19 model::{authenticated::NetworkMessage, instance::Instance},
20 };
21
22 use super::{
23 authentication::authentication_handle, handshake::handshake_handle,
24 repository::repository_handle, user::user_handle, Connections,
25 };
26
27 pub async fn connection_wrapper(
28 socket: WebSocketStream<TcpStream>,
29 connections: Arc<Mutex<Connections>>,
30 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
31 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
32 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
33 addr: SocketAddr,
34 instance: impl ToOwned<Owned = Instance>,
35 ) {
36 let connection_state = ConnectionState {
37 socket: Arc::new(Mutex::new(socket)),
38 connections,
39 repository_backend,
40 user_backend,
41 auth_granter,
42 addr,
43 instance: instance.to_owned(),
44 handshaked: Arc::new(AtomicBool::new(false)),
45 };
46
47 let mut handshaked = false;
48
49 loop {
50 let mut socket = connection_state.socket.lock().await;
51 let message = socket.next().await;
52 drop(socket);
53
54 match message {
55 Some(Ok(message)) => {
56 let payload = match message {
57 Message::Binary(payload) => payload,
58 Message::Ping(_) => {
59 let mut socket = connection_state.socket.lock().await;
60 let _ = socket.send(Message::Pong(vec![])).await;
61 drop(socket);
62 continue;
63 }
64 Message::Close(_) => return,
65 _ => continue,
66 };
67
68 let message = NetworkMessage(payload.clone());
69
70 if !handshaked {
71 if handshake_handle(&message, &connection_state).await.is_ok() {
72 if connection_state.handshaked.load(Ordering::SeqCst) {
73 handshaked = true;
74 }
75 }
76 } else {
77 let raw = serde_json::from_slice::<Value>(&payload).unwrap();
78 let message_type = raw.get("message_type").unwrap().as_str().unwrap();
79
80 if authentication_handle(message_type, &message, &connection_state)
81 .await
82 .is_ok()
83 {
84 continue;
85 } else if repository_handle(message_type, &message, &connection_state)
86 .await
87 .is_ok()
88 {
89 continue;
90 } else if user_handle(message_type, &message, &connection_state)
91 .await
92 .is_ok()
93 {
94 continue;
95 } else {
96 error!(
97 "Message completely unhandled: {}",
98 std::str::from_utf8(&payload).unwrap()
99 );
100 continue;
101 }
102 }
103 }
104 Some(Err(e)) => {
105 error!("Closing connection for {:?} for {}", e, addr);
106 return;
107 }
108 _ => {
109 info!("Unhandled");
110 continue;
111 }
112 }
113 }
114 }
115
116 #[derive(Clone)]
117 pub struct ConnectionState {
118 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
119 pub connections: Arc<Mutex<Connections>>,
120 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
121 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
122 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
123 pub addr: SocketAddr,
124 pub instance: Instance,
125 pub handshaked: Arc<AtomicBool>,
126 }
127
128 impl ConnectionState {
129 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
130 let payload = serde_json::to_string(&message)?;
131 info!("Sending payload: {}", &payload);
132 self.socket
133 .lock()
134 .await
135 .send(Message::Binary(payload.into_bytes()))
136 .await?;
137
138 Ok(())
139 }
140 }
141