1 |
use std::{
|
2 |
collections::HashMap,
|
3 |
net::SocketAddr,
|
4 |
sync::{
|
5 |
atomic::{AtomicBool, Ordering},
|
6 |
Arc,
|
7 |
},
|
8 |
};
|
9 |
|
10 |
use anyhow::Error;
|
11 |
use futures_util::{SinkExt, StreamExt};
|
12 |
use giterated_models::{
|
13 |
messages::error::ConnectionError,
|
14 |
model::{
|
15 |
authenticated::{Authenticated, AuthenticatedPayload},
|
16 |
instance::Instance,
|
17 |
},
|
18 |
};
|
19 |
use rsa::RsaPublicKey;
|
20 |
use serde::Serialize;
|
21 |
use serde_json::Value;
|
22 |
use tokio::{
|
23 |
net::TcpStream,
|
24 |
sync::{Mutex, RwLock},
|
25 |
};
|
26 |
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
|
27 |
|
28 |
use crate::{
|
29 |
authentication::AuthenticationTokenGranter,
|
30 |
backend::{RepositoryBackend, UserBackend},
|
31 |
connection::forwarded::wrap_forwarded,
|
32 |
federation::connections::InstanceConnections,
|
33 |
message::NetworkMessage,
|
34 |
};
|
35 |
|
36 |
use super::{
|
37 |
authentication::authentication_handle, handshake::handshake_handle,
|
38 |
repository::repository_handle, user::user_handle, Connections,
|
39 |
};
|
40 |
|
41 |
pub async fn connection_wrapper(
|
42 |
socket: WebSocketStream<TcpStream>,
|
43 |
connections: Arc<Mutex<Connections>>,
|
44 |
repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
|
45 |
user_backend: Arc<Mutex<dyn UserBackend + Send>>,
|
46 |
auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
|
47 |
addr: SocketAddr,
|
48 |
instance: impl ToOwned<Owned = Instance>,
|
49 |
instance_connections: Arc<Mutex<InstanceConnections>>,
|
50 |
) {
|
51 |
let connection_state = ConnectionState {
|
52 |
socket: Arc::new(Mutex::new(socket)),
|
53 |
connections,
|
54 |
repository_backend,
|
55 |
user_backend,
|
56 |
auth_granter,
|
57 |
addr,
|
58 |
instance: instance.to_owned(),
|
59 |
handshaked: Arc::new(AtomicBool::new(false)),
|
60 |
cached_keys: Arc::default(),
|
61 |
};
|
62 |
|
63 |
let mut handshaked = false;
|
64 |
|
65 |
loop {
|
66 |
let mut socket = connection_state.socket.lock().await;
|
67 |
let message = socket.next().await;
|
68 |
drop(socket);
|
69 |
|
70 |
match message {
|
71 |
Some(Ok(message)) => {
|
72 |
let payload = match message {
|
73 |
Message::Binary(payload) => payload,
|
74 |
Message::Ping(_) => {
|
75 |
let mut socket = connection_state.socket.lock().await;
|
76 |
let _ = socket.send(Message::Pong(vec![])).await;
|
77 |
drop(socket);
|
78 |
continue;
|
79 |
}
|
80 |
Message::Close(_) => return,
|
81 |
_ => continue,
|
82 |
};
|
83 |
|
84 |
info!(
|
85 |
"Received payload: {}",
|
86 |
std::str::from_utf8(&payload).unwrap()
|
87 |
);
|
88 |
|
89 |
let message = NetworkMessage(payload.clone());
|
90 |
|
91 |
if !handshaked {
|
92 |
info!("im foo baring");
|
93 |
if handshake_handle(&message, &connection_state).await.is_ok() {
|
94 |
if connection_state.handshaked.load(Ordering::SeqCst) {
|
95 |
handshaked = true;
|
96 |
}
|
97 |
}
|
98 |
} else {
|
99 |
let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
|
100 |
|
101 |
if let Some(target_instance) = &raw.target_instance {
|
102 |
|
103 |
info!("Forwarding message to {}", target_instance.url);
|
104 |
let mut instance_connections = instance_connections.lock().await;
|
105 |
let pool = instance_connections.get_or_open(&target_instance).unwrap();
|
106 |
let pool_clone = pool.clone();
|
107 |
drop(pool);
|
108 |
|
109 |
let result = wrap_forwarded(&pool_clone, raw).await;
|
110 |
|
111 |
let mut socket = connection_state.socket.lock().await;
|
112 |
let _ = socket.send(result).await;
|
113 |
|
114 |
continue;
|
115 |
}
|
116 |
|
117 |
let message_type = &raw.message_type;
|
118 |
|
119 |
info!("Handling message with type: {}", message_type);
|
120 |
|
121 |
match authentication_handle(message_type, &message, &connection_state).await {
|
122 |
Err(e) => {
|
123 |
let _ = connection_state
|
124 |
.send_raw(ConnectionError(e.to_string()))
|
125 |
.await;
|
126 |
}
|
127 |
Ok(true) => continue,
|
128 |
Ok(false) => {}
|
129 |
}
|
130 |
|
131 |
match repository_handle(message_type, &message, &connection_state).await {
|
132 |
Err(e) => {
|
133 |
let _ = connection_state
|
134 |
.send_raw(ConnectionError(e.to_string()))
|
135 |
.await;
|
136 |
}
|
137 |
Ok(true) => continue,
|
138 |
Ok(false) => {}
|
139 |
}
|
140 |
|
141 |
match user_handle(message_type, &message, &connection_state).await {
|
142 |
Err(e) => {
|
143 |
let _ = connection_state
|
144 |
.send_raw(ConnectionError(e.to_string()))
|
145 |
.await;
|
146 |
}
|
147 |
Ok(true) => continue,
|
148 |
Ok(false) => {}
|
149 |
}
|
150 |
|
151 |
match authentication_handle(message_type, &message, &connection_state).await {
|
152 |
Err(e) => {
|
153 |
let _ = connection_state
|
154 |
.send_raw(ConnectionError(e.to_string()))
|
155 |
.await;
|
156 |
}
|
157 |
Ok(true) => continue,
|
158 |
Ok(false) => {}
|
159 |
}
|
160 |
|
161 |
error!(
|
162 |
"Message completely unhandled: {}",
|
163 |
std::str::from_utf8(&payload).unwrap()
|
164 |
);
|
165 |
}
|
166 |
}
|
167 |
Some(Err(e)) => {
|
168 |
error!("Closing connection for {:?} for {}", e, addr);
|
169 |
return;
|
170 |
}
|
171 |
_ => {
|
172 |
info!("Unhandled");
|
173 |
continue;
|
174 |
}
|
175 |
}
|
176 |
}
|
177 |
}
|
178 |
|
179 |
#[derive(Clone)]
|
180 |
pub struct ConnectionState {
|
181 |
socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
|
182 |
pub connections: Arc<Mutex<Connections>>,
|
183 |
pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
|
184 |
pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
|
185 |
pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
|
186 |
pub addr: SocketAddr,
|
187 |
pub instance: Instance,
|
188 |
pub handshaked: Arc<AtomicBool>,
|
189 |
pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>,
|
190 |
}
|
191 |
|
192 |
impl ConnectionState {
|
193 |
pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
|
194 |
let payload = serde_json::to_string(&message)?;
|
195 |
info!("Sending payload: {}", &payload);
|
196 |
self.socket
|
197 |
.lock()
|
198 |
.await
|
199 |
.send(Message::Binary(payload.into_bytes()))
|
200 |
.await?;
|
201 |
|
202 |
Ok(())
|
203 |
}
|
204 |
|
205 |
pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
|
206 |
let payload = serde_json::to_string(&message)?;
|
207 |
info!("Sending payload: {}", &payload);
|
208 |
self.socket
|
209 |
.lock()
|
210 |
.await
|
211 |
.send(Message::Binary(payload.into_bytes()))
|
212 |
.await?;
|
213 |
|
214 |
Ok(())
|
215 |
}
|
216 |
}
|
217 |
|