JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Add settings

Amber - 2 years ago

parent: tbd commit: 0448edb

src/connection/wrapper.rs - 5518 bytes
Raw
1 use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 };
9
10 use anyhow::Error;
11 use futures_util::{SinkExt, StreamExt};
12 use rsa::RsaPublicKey;
13 use serde::Serialize;
14 use serde_json::Value;
15 use tokio::{
16 net::TcpStream,
17 sync::{Mutex, RwLock},
18 };
19 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
20
21 use crate::{
22 authentication::AuthenticationTokenGranter,
23 backend::{RepositoryBackend, UserBackend},
24 messages::error::ConnectionError,
25 model::{authenticated::NetworkMessage, instance::Instance},
26 };
27
28 use super::{
29 authentication::authentication_handle, handshake::handshake_handle,
30 repository::repository_handle, user::user_handle, Connections,
31 };
32
33 pub async fn connection_wrapper(
34 socket: WebSocketStream<TcpStream>,
35 connections: Arc<Mutex<Connections>>,
36 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
37 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
38 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
39 addr: SocketAddr,
40 instance: impl ToOwned<Owned = Instance>,
41 ) {
42 let connection_state = ConnectionState {
43 socket: Arc::new(Mutex::new(socket)),
44 connections,
45 repository_backend,
46 user_backend,
47 auth_granter,
48 addr,
49 instance: instance.to_owned(),
50 handshaked: Arc::new(AtomicBool::new(false)),
51 cached_keys: Arc::default(),
52 };
53
54 let mut handshaked = false;
55
56 loop {
57 let mut socket = connection_state.socket.lock().await;
58 let message = socket.next().await;
59 drop(socket);
60
61 match message {
62 Some(Ok(message)) => {
63 let payload = match message {
64 Message::Binary(payload) => payload,
65 Message::Ping(_) => {
66 let mut socket = connection_state.socket.lock().await;
67 let _ = socket.send(Message::Pong(vec![])).await;
68 drop(socket);
69 continue;
70 }
71 Message::Close(_) => return,
72 _ => continue,
73 };
74
75 let message = NetworkMessage(payload.clone());
76
77 if !handshaked {
78 if handshake_handle(&message, &connection_state).await.is_ok() {
79 if connection_state.handshaked.load(Ordering::SeqCst) {
80 handshaked = true;
81 }
82 }
83 } else {
84 let raw = serde_json::from_slice::<Value>(&payload).unwrap();
85 let message_type = raw.get("message_type").unwrap().as_str().unwrap();
86
87 match authentication_handle(message_type, &message, &connection_state).await {
88 Err(e) => {
89 let _ = connection_state.send(ConnectionError(e.to_string())).await;
90 }
91 Ok(true) => continue,
92 Ok(false) => {}
93 }
94
95 match repository_handle(message_type, &message, &connection_state).await {
96 Err(e) => {
97 let _ = connection_state.send(ConnectionError(e.to_string())).await;
98 }
99 Ok(true) => continue,
100 Ok(false) => {}
101 }
102
103 match user_handle(message_type, &message, &connection_state).await {
104 Err(e) => {
105 let _ = connection_state.send(ConnectionError(e.to_string())).await;
106 }
107 Ok(true) => continue,
108 Ok(false) => {}
109 }
110
111 match authentication_handle(message_type, &message, &connection_state).await {
112 Err(e) => {
113 let _ = connection_state.send(ConnectionError(e.to_string())).await;
114 }
115 Ok(true) => continue,
116 Ok(false) => {}
117 }
118
119 error!(
120 "Message completely unhandled: {}",
121 std::str::from_utf8(&payload).unwrap()
122 );
123 }
124 }
125 Some(Err(e)) => {
126 error!("Closing connection for {:?} for {}", e, addr);
127 return;
128 }
129 _ => {
130 info!("Unhandled");
131 continue;
132 }
133 }
134 }
135 }
136
137 #[derive(Clone)]
138 pub struct ConnectionState {
139 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
140 pub connections: Arc<Mutex<Connections>>,
141 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
142 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
143 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
144 pub addr: SocketAddr,
145 pub instance: Instance,
146 pub handshaked: Arc<AtomicBool>,
147 pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>,
148 }
149
150 impl ConnectionState {
151 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
152 let payload = serde_json::to_string(&message)?;
153 info!("Sending payload: {}", &payload);
154 self.socket
155 .lock()
156 .await
157 .send(Message::Binary(payload.into_bytes()))
158 .await?;
159
160 Ok(())
161 }
162 }
163