JavaScript is disabled, refresh for a better experience. ambee/giterated

ambee/giterated

Git repository hosting, collaboration, and discovery for the Fediverse.

Fix authenticated endpoints

Amber - 2 years ago

parent: tbd commit: 1400b06

giterated-daemon/src/connection/wrapper.rs - 7373 bytes
Raw
1 use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 };
9
10 use anyhow::Error;
11 use futures_util::{SinkExt, StreamExt};
12 use giterated_models::{
13 messages::error::ConnectionError,
14 model::{authenticated::{Authenticated, AuthenticatedPayload}, instance::Instance},
15 };
16 use rsa::RsaPublicKey;
17 use serde::Serialize;
18 use serde_json::Value;
19 use tokio::{
20 net::TcpStream,
21 sync::{Mutex, RwLock},
22 };
23 use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
24
25 use crate::{
26 authentication::AuthenticationTokenGranter,
27 backend::{RepositoryBackend, UserBackend},
28 connection::forwarded::wrap_forwarded,
29 federation::connections::InstanceConnections,
30 message::NetworkMessage,
31 };
32
33 use super::{
34 authentication::authentication_handle, handshake::handshake_handle,
35 repository::repository_handle, user::user_handle, Connections,
36 };
37
38 pub async fn connection_wrapper(
39 socket: WebSocketStream<TcpStream>,
40 connections: Arc<Mutex<Connections>>,
41 repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
42 user_backend: Arc<Mutex<dyn UserBackend + Send>>,
43 auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
44 addr: SocketAddr,
45 instance: impl ToOwned<Owned = Instance>,
46 instance_connections: Arc<Mutex<InstanceConnections>>,
47 ) {
48 let connection_state = ConnectionState {
49 socket: Arc::new(Mutex::new(socket)),
50 connections,
51 repository_backend,
52 user_backend,
53 auth_granter,
54 addr,
55 instance: instance.to_owned(),
56 handshaked: Arc::new(AtomicBool::new(false)),
57 cached_keys: Arc::default(),
58 };
59
60 let mut handshaked = false;
61
62 loop {
63 let mut socket = connection_state.socket.lock().await;
64 let message = socket.next().await;
65 drop(socket);
66
67 match message {
68 Some(Ok(message)) => {
69 let payload = match message {
70 Message::Binary(payload) => payload,
71 Message::Ping(_) => {
72 let mut socket = connection_state.socket.lock().await;
73 let _ = socket.send(Message::Pong(vec![])).await;
74 drop(socket);
75 continue;
76 }
77 Message::Close(_) => return,
78 _ => continue,
79 };
80
81 info!(
82 "Received payload: {}",
83 std::str::from_utf8(&payload).unwrap()
84 );
85
86 let message = NetworkMessage(payload.clone());
87
88 if !handshaked {
89 info!("im foo baring");
90 if handshake_handle(&message, &connection_state).await.is_ok() {
91 if connection_state.handshaked.load(Ordering::SeqCst) {
92 handshaked = true;
93 }
94 }
95 } else {
96 let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap();
97
98 if let Some(target_instance) = &raw.target_instance {
99 // Forward request
100 info!("Forwarding message to {}", target_instance.url);
101 let mut instance_connections = instance_connections.lock().await;
102 let pool = instance_connections.get_or_open(&target_instance).unwrap();
103 let pool_clone = pool.clone();
104 drop(pool);
105
106 let result = wrap_forwarded(&pool_clone, raw).await;
107
108 let mut socket = connection_state.socket.lock().await;
109 let _ = socket.send(result).await;
110
111 continue;
112 }
113
114 let message_type = &raw.message_type;
115
116 info!("Handling message with type: {}", message_type);
117
118 match authentication_handle(message_type, &message, &connection_state).await {
119 Err(e) => {
120 let _ = connection_state
121 .send_raw(ConnectionError(e.to_string()))
122 .await;
123 }
124 Ok(true) => continue,
125 Ok(false) => {}
126 }
127
128 match repository_handle(message_type, &message, &connection_state).await {
129 Err(e) => {
130 let _ = connection_state
131 .send_raw(ConnectionError(e.to_string()))
132 .await;
133 }
134 Ok(true) => continue,
135 Ok(false) => {}
136 }
137
138 match user_handle(message_type, &message, &connection_state).await {
139 Err(e) => {
140 let _ = connection_state
141 .send_raw(ConnectionError(e.to_string()))
142 .await;
143 }
144 Ok(true) => continue,
145 Ok(false) => {}
146 }
147
148 match authentication_handle(message_type, &message, &connection_state).await {
149 Err(e) => {
150 let _ = connection_state
151 .send_raw(ConnectionError(e.to_string()))
152 .await;
153 }
154 Ok(true) => continue,
155 Ok(false) => {}
156 }
157
158 error!(
159 "Message completely unhandled: {}",
160 std::str::from_utf8(&payload).unwrap()
161 );
162 }
163 }
164 Some(Err(e)) => {
165 error!("Closing connection for {:?} for {}", e, addr);
166 return;
167 }
168 _ => {
169 info!("Unhandled");
170 continue;
171 }
172 }
173 }
174 }
175
176 #[derive(Clone)]
177 pub struct ConnectionState {
178 socket: Arc<Mutex<WebSocketStream<TcpStream>>>,
179 pub connections: Arc<Mutex<Connections>>,
180 pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>,
181 pub user_backend: Arc<Mutex<dyn UserBackend + Send>>,
182 pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>,
183 pub addr: SocketAddr,
184 pub instance: Instance,
185 pub handshaked: Arc<AtomicBool>,
186 pub cached_keys: Arc<RwLock<HashMap<Instance, RsaPublicKey>>>,
187 }
188
189 impl ConnectionState {
190 pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> {
191 let payload = serde_json::to_string(&message)?;
192 info!("Sending payload: {}", &payload);
193 self.socket
194 .lock()
195 .await
196 .send(Message::Binary(payload.into_bytes()))
197 .await?;
198
199 Ok(())
200 }
201
202 pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> {
203 let payload = serde_json::to_string(&message)?;
204 info!("Sending payload: {}", &payload);
205 self.socket
206 .lock()
207 .await
208 .send(Message::Binary(payload.into_bytes()))
209 .await?;
210
211 Ok(())
212 }
213 }
214