JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Expose errors

Amber - 2 years ago

parent: tbd commit: 6b2125c

src/connection/wrapper.rs - 5344 bytes
Raw
1 use std::{
2 net::SocketAddr,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7 };
8
9 use anyhow::Error;
10 use futures_util::{SinkExt, StreamExt};
11 use serde::Serialize;
12 use serde_json::Value;
13 use tokio::{net::TcpStream, sync::Mutex};
14 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
15
16 use crate::{
17 authentication::AuthenticationTokenGranter,
18 backend::{RepositoryBackend, UserBackend},
19 messages::error::ConnectionError,
20 model::{authenticated::NetworkMessage, instance::Instance},
21 };
22
23 use super::{
24 authentication::authentication_handle, handshake::handshake_handle,
25 repository::repository_handle, user::user_handle, Connections,
26 };
27
28 pub async fn connection_wrapper(
29 socket: WebSocketStream<TcpStream>,
30 connections: Arc<Mutex<Connections>>,
31 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
32 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
33 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
34 addr: SocketAddr,
35 instance: impl ToOwned<Owned = Instance>,
36 ) {
37 let connection_state = ConnectionState {
38 socket: Arc::new(Mutex::new(socket)),
39 connections,
40 repository_backend,
41 user_backend,
42 auth_granter,
43 addr,
44 instance: instance.to_owned(),
45 handshaked: Arc::new(AtomicBool::new(false)),
46 };
47
48 let mut handshaked = false;
49
50 loop {
51 let mut socket = connection_state.socket.lock().await;
52 let message = socket.next().await;
53 drop(socket);
54
55 match message {
56 Some(Ok(message)) => {
57 let payload = match message {
58 Message::Binary(payload) => payload,
59 Message::Ping(_) => {
60 let mut socket = connection_state.socket.lock().await;
61 let _ = socket.send(Message::Pong(vec![])).await;
62 drop(socket);
63 continue;
64 }
65 Message::Close(_) => return,
66 _ => continue,
67 };
68
69 let message = NetworkMessage(payload.clone());
70
71 if !handshaked {
72 if handshake_handle(&message, &connection_state).await.is_ok() {
73 if connection_state.handshaked.load(Ordering::SeqCst) {
74 handshaked = true;
75 }
76 }
77 } else {
78 let raw = serde_json::from_slice::<Value>(&payload).unwrap();
79 let message_type = raw.get("message_type").unwrap().as_str().unwrap();
80
81 match authentication_handle(message_type, &message, &connection_state).await {
82 Err(e) => {
83 let _ = connection_state.send(ConnectionError(e.to_string())).await;
84 }
85 Ok(true) => continue,
86 Ok(false) => {}
87 }
88
89 match repository_handle(message_type, &message, &connection_state).await {
90 Err(e) => {
91 let _ = connection_state.send(ConnectionError(e.to_string())).await;
92 }
93 Ok(true) => continue,
94 Ok(false) => {}
95 }
96
97 match user_handle(message_type, &message, &connection_state).await {
98 Err(e) => {
99 let _ = connection_state.send(ConnectionError(e.to_string())).await;
100 }
101 Ok(true) => continue,
102 Ok(false) => {}
103 }
104
105 match authentication_handle(message_type, &message, &connection_state).await {
106 Err(e) => {
107 let _ = connection_state.send(ConnectionError(e.to_string())).await;
108 }
109 Ok(true) => continue,
110 Ok(false) => {}
111 }
112
113 error!(
114 "Message completely unhandled: {}",
115 std::str::from_utf8(&payload).unwrap()
116 );
117 }
118 }
119 Some(Err(e)) => {
120 error!("Closing connection for {:?} for {}", e, addr);
121 return;
122 }
123 _ => {
124 info!("Unhandled");
125 continue;
126 }
127 }
128 }
129 }
130
131 #[derive(Clone)]
132 pub struct ConnectionState {
133 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
134 pub connections: Arc<Mutex<Connections>>,
135 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
136 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
137 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
138 pub addr: SocketAddr,
139 pub instance: Instance,
140 pub handshaked: Arc<AtomicBool>,
141 }
142
143 impl ConnectionState {
144 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
145 let payload = serde_json::to_string(&message)?;
146 info!("Sending payload: {}", &payload);
147 self.socket
148 .lock()
149 .await
150 .send(Message::Binary(payload.into_bytes()))
151 .await?;
152
153 Ok(())
154 }
155 }
156