Showing 19 changed files with 1221 insertions and 595 deletions
Cargo.lock
@@ -652,6 +652,7 @@ dependencies = [ | ||
652 | 652 | "tokio", |
653 | 653 | "tokio-tungstenite", |
654 | 654 | "toml", |
655 | "tower", | |
655 | 656 | "tracing", |
656 | 657 | "tracing-subscriber", |
657 | 658 | ] |
@@ -2230,6 +2231,23 @@ dependencies = [ | ||
2230 | 2231 | ] |
2231 | 2232 | |
2232 | 2233 | [[package]] |
2234 | name = "tower" | |
2235 | version = "0.4.13" | |
2236 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
2237 | checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" | |
2238 | dependencies = [ | |
2239 | "tower-layer", | |
2240 | "tower-service", | |
2241 | "tracing", | |
2242 | ] | |
2243 | ||
2244 | [[package]] | |
2245 | name = "tower-layer" | |
2246 | version = "0.3.2" | |
2247 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
2248 | checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" | |
2249 | ||
2250 | [[package]] | |
2233 | 2251 | name = "tower-service" |
2234 | 2252 | version = "0.3.2" |
2235 | 2253 | source = "registry+https://github.com/rust-lang/crates.io-index" |
Cargo.toml
@@ -10,7 +10,7 @@ tokio-tungstenite = "*" | ||
10 | 10 | tokio = { version = "1.32.0", features = [ "full" ] } |
11 | 11 | tracing = "*" |
12 | 12 | futures-util = "*" |
13 | serde = { version = "1", features = [ "derive" ]} | |
13 | serde = { version = "1.0.188", features = [ "derive" ]} | |
14 | 14 | serde_json = "1.0" |
15 | 15 | tracing-subscriber = "0.3" |
16 | 16 | base64 = "0.21.3" |
@@ -22,6 +22,7 @@ reqwest = "*" | ||
22 | 22 | argon2 = "*" |
23 | 23 | aes-gcm = "0.10.2" |
24 | 24 | semver = {version = "*", features = ["serde"]} |
25 | tower = "*" | |
25 | 26 | |
26 | 27 | toml = { version = "0.7" } |
27 | 28 |
src/authentication.rs
@@ -13,7 +13,7 @@ use crate::{ | ||
13 | 13 | }, |
14 | 14 | InstanceAuthenticated, |
15 | 15 | }, |
16 | model::{instance::Instance, user::User}, | |
16 | model::{authenticated::UserAuthenticationToken, instance::Instance, user::User}, | |
17 | 17 | }; |
18 | 18 | |
19 | 19 | #[derive(Debug, Serialize, Deserialize)] |
@@ -74,21 +74,10 @@ impl AuthenticationTokenGranter { | ||
74 | 74 | |
75 | 75 | pub async fn token_request( |
76 | 76 | &mut self, |
77 | raw_request: InstanceAuthenticated<AuthenticationTokenRequest>, | |
77 | issued_for: impl ToOwned<Owned = Instance>, | |
78 | username: String, | |
79 | password: String, | |
78 | 80 | ) -> Result<AuthenticationTokenResponse, Error> { |
79 | let request = raw_request.inner().await; | |
80 | ||
81 | info!("Ensuring token request is from the same instance..."); | |
82 | raw_request | |
83 | .validate(&Instance { | |
84 | url: String::from("giterated.dev"), | |
85 | }) | |
86 | .await | |
87 | .unwrap(); | |
88 | ||
89 | let secret_key = self.config["authentication"]["secret_key"] | |
90 | .as_str() | |
91 | .unwrap(); | |
92 | 81 | let private_key = { |
93 | 82 | let mut file = File::open( |
94 | 83 | self.config["giterated"]["keys"]["private"] |
@@ -104,20 +93,14 @@ impl AuthenticationTokenGranter { | ||
104 | 93 | key |
105 | 94 | }; |
106 | 95 | |
107 | if request.secret_key != secret_key { | |
108 | error!("Incorrect secret key!"); | |
109 | ||
110 | panic!() | |
111 | } | |
112 | ||
113 | 96 | let encoding_key = EncodingKey::from_rsa_pem(&private_key).unwrap(); |
114 | 97 | |
115 | 98 | let claims = UserTokenMetadata { |
116 | 99 | user: User { |
117 | username: request.username.clone(), | |
100 | username, | |
118 | 101 | instance: self.instance.clone(), |
119 | 102 | }, |
120 | generated_for: raw_request.instance.clone(), | |
103 | generated_for: issued_for.to_owned(), | |
121 | 104 | exp: (SystemTime::UNIX_EPOCH.elapsed().unwrap() |
122 | 105 | + std::time::Duration::from_secs(24 * 60 * 60)) |
123 | 106 | .as_secs(), |
@@ -135,53 +118,25 @@ impl AuthenticationTokenGranter { | ||
135 | 118 | |
136 | 119 | pub async fn extension_request( |
137 | 120 | &mut self, |
138 | raw_request: InstanceAuthenticated<TokenExtensionRequest>, | |
121 | issued_for: &Instance, | |
122 | token: UserAuthenticationToken, | |
139 | 123 | ) -> Result<TokenExtensionResponse, Error> { |
140 | let request = raw_request.inner().await; | |
141 | ||
142 | // let server_public_key = { | |
143 | // let mut file = File::open(self.config["keys"]["public"].as_str().unwrap()) | |
144 | // .await | |
145 | // .unwrap(); | |
146 | ||
147 | // let mut key = String::default(); | |
148 | // file.read_to_string(&mut key).await.unwrap(); | |
149 | ||
150 | // key | |
151 | // }; | |
152 | ||
153 | let server_public_key = public_key(&Instance { | |
154 | url: String::from("giterated.dev"), | |
155 | }) | |
156 | .await | |
157 | .unwrap(); | |
124 | let server_public_key = public_key(&self.instance).await.unwrap(); | |
158 | 125 | |
159 | 126 | let verification_key = DecodingKey::from_rsa_pem(server_public_key.as_bytes()).unwrap(); |
160 | 127 | |
161 | 128 | let data: TokenData<UserTokenMetadata> = decode( |
162 | &request.token, | |
129 | token.as_ref(), | |
163 | 130 | &verification_key, |
164 | 131 | &Validation::new(Algorithm::RS256), |
165 | 132 | ) |
166 | 133 | .unwrap(); |
167 | 134 | |
168 | info!("Token Extension Request Token validated"); | |
169 | ||
170 | let secret_key = self.config["authentication"]["secret_key"] | |
171 | .as_str() | |
172 | .unwrap(); | |
173 | ||
174 | if request.secret_key != secret_key { | |
175 | error!("Incorrect secret key!"); | |
176 | ||
135 | if data.claims.generated_for != *issued_for { | |
177 | 136 | panic!() |
178 | 137 | } |
179 | // Validate request | |
180 | raw_request | |
181 | .validate(&data.claims.generated_for) | |
182 | .await | |
183 | .unwrap(); | |
184 | info!("Validated request for key extension"); | |
138 | ||
139 | info!("Token Extension Request Token validated"); | |
185 | 140 | |
186 | 141 | let private_key = { |
187 | 142 | let mut file = File::open( |
@@ -203,7 +158,7 @@ impl AuthenticationTokenGranter { | ||
203 | 158 | let claims = UserTokenMetadata { |
204 | 159 | // TODO: Probably exploitable |
205 | 160 | user: data.claims.user, |
206 | generated_for: data.claims.generated_for, | |
161 | generated_for: issued_for.clone(), | |
207 | 162 | exp: (SystemTime::UNIX_EPOCH.elapsed().unwrap() |
208 | 163 | + std::time::Duration::from_secs(24 * 60 * 60)) |
209 | 164 | .as_secs(), |
src/backend/git.rs
@@ -212,26 +212,9 @@ impl GitBackend { | ||
212 | 212 | impl RepositoryBackend for GitBackend { |
213 | 213 | async fn create_repository( |
214 | 214 | &mut self, |
215 | raw_request: &ValidatedUserAuthenticated<CreateRepositoryRequest>, | |
215 | user: &User, | |
216 | request: &CreateRepositoryRequest, | |
216 | 217 | ) -> Result<CreateRepositoryResponse, Error> { |
217 | let request = raw_request.inner().await; | |
218 | ||
219 | // let public_key = public_key(&Instance { | |
220 | // url: String::from("giterated.dev"), | |
221 | // }) | |
222 | // .await | |
223 | // .unwrap(); | |
224 | // | |
225 | // match raw_request.validate(public_key).await { | |
226 | // Ok(_) => info!("Request was validated"), | |
227 | // Err(err) => { | |
228 | // error!("Failed to validate request: {:?}", err); | |
229 | // panic!(); | |
230 | // } | |
231 | // } | |
232 | // | |
233 | // info!("Request was valid!"); | |
234 | ||
235 | 218 | // Check if repository already exists in the database |
236 | 219 | if let Ok(repository) = self |
237 | 220 | .find_by_owner_user_name(&request.owner, &request.name) |
@@ -297,11 +280,9 @@ impl RepositoryBackend for GitBackend { | ||
297 | 280 | |
298 | 281 | async fn repository_info( |
299 | 282 | &mut self, |
300 | // TODO: Allow non-authenticated??? | |
301 | raw_request: &ValidatedUserAuthenticated<RepositoryInfoRequest>, | |
283 | requester: Option<&User>, | |
284 | request: &RepositoryInfoRequest, | |
302 | 285 | ) -> Result<RepositoryView, Error> { |
303 | let request = raw_request.inner().await; | |
304 | ||
305 | 286 | let repository = match self |
306 | 287 | .find_by_owner_user_name( |
307 | 288 | // &request.owner.instance.url, |
@@ -314,7 +295,17 @@ impl RepositoryBackend for GitBackend { | ||
314 | 295 | Err(err) => return Err(Box::new(err).into()), |
315 | 296 | }; |
316 | 297 | |
317 | if !repository.can_user_view_repository(Some(&raw_request.user)) { | |
298 | if let Some(requester) = requester { | |
299 | if !repository.can_user_view_repository(Some(&requester)) { | |
300 | return Err(Box::new(GitBackendError::RepositoryNotFound { | |
301 | owner_user: request.repository.owner.to_string(), | |
302 | name: request.repository.name.clone(), | |
303 | }) | |
304 | .into()); | |
305 | } | |
306 | } else if matches!(repository.visibility, RepositoryVisibility::Private) { | |
307 | // Unauthenticated users can never view private repositories | |
308 | ||
318 | 309 | return Err(Box::new(GitBackendError::RepositoryNotFound { |
319 | 310 | owner_user: request.repository.owner.to_string(), |
320 | 311 | name: request.repository.name.clone(), |
@@ -446,9 +437,10 @@ impl RepositoryBackend for GitBackend { | ||
446 | 437 | }) |
447 | 438 | } |
448 | 439 | |
449 | fn repository_file_inspect( | |
440 | async fn repository_file_inspect( | |
450 | 441 | &mut self, |
451 | _request: &ValidatedUserAuthenticated<RepositoryFileInspectRequest>, | |
442 | requester: Option<&User>, | |
443 | _request: &RepositoryFileInspectRequest, | |
452 | 444 | ) -> Result<RepositoryFileInspectionResponse, Error> { |
453 | 445 | todo!() |
454 | 446 | } |
src/backend/mod.rs
@@ -35,15 +35,18 @@ use crate::{ | ||
35 | 35 | pub trait RepositoryBackend: IssuesBackend { |
36 | 36 | async fn create_repository( |
37 | 37 | &mut self, |
38 | request: &ValidatedUserAuthenticated<CreateRepositoryRequest>, | |
38 | user: &User, | |
39 | request: &CreateRepositoryRequest, | |
39 | 40 | ) -> Result<CreateRepositoryResponse, Error>; |
40 | 41 | async fn repository_info( |
41 | 42 | &mut self, |
42 | request: &ValidatedUserAuthenticated<RepositoryInfoRequest>, | |
43 | requester: Option<&User>, | |
44 | request: &RepositoryInfoRequest, | |
43 | 45 | ) -> Result<RepositoryView, Error>; |
44 | fn repository_file_inspect( | |
46 | async fn repository_file_inspect( | |
45 | 47 | &mut self, |
46 | request: &ValidatedUserAuthenticated<RepositoryFileInspectRequest>, | |
48 | requester: Option<&User>, | |
49 | request: &RepositoryFileInspectRequest, | |
47 | 50 | ) -> Result<RepositoryFileInspectionResponse, Error>; |
48 | 51 | async fn repositories_for_user(&mut self, user: &User) |
49 | 52 | -> Result<Vec<RepositorySummary>, Error>; |
@@ -90,6 +93,7 @@ pub trait UserBackend: AuthBackend { | ||
90 | 93 | ) -> Result<UserDisplayImageResponse, Error>; |
91 | 94 | |
92 | 95 | async fn bio(&mut self, request: UserBioRequest) -> Result<UserBioResponse, Error>; |
96 | async fn exists(&mut self, user: &User) -> Result<bool, Error>; | |
93 | 97 | } |
94 | 98 | |
95 | 99 | #[async_trait::async_trait] |
src/backend/user.rs
@@ -100,6 +100,17 @@ impl UserBackend for UserAuth { | ||
100 | 100 | |
101 | 101 | Ok(UserBioResponse { bio: db_row.bio }) |
102 | 102 | } |
103 | ||
104 | async fn exists(&mut self, user: &User) -> Result<bool, Error> { | |
105 | Ok(sqlx::query_as!( | |
106 | UserRow, | |
107 | r#"SELECT * FROM users WHERE username = $1"#, | |
108 | user.username | |
109 | ) | |
110 | .fetch_one(&self.pg_pool.clone()) | |
111 | .await | |
112 | .is_err()) | |
113 | } | |
103 | 114 | } |
104 | 115 | |
105 | 116 | #[async_trait::async_trait] |
src/connection.rs
@@ -1,9 +1,15 @@ | ||
1 | use std::{collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc}; | |
1 | pub mod authentication; | |
2 | pub mod handshake; | |
3 | pub mod repository; | |
4 | pub mod user; | |
5 | pub mod wrapper; | |
6 | ||
7 | use std::{any::type_name, collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc}; | |
2 | 8 | |
3 | 9 | use anyhow::Error; |
4 | 10 | use futures_util::{stream::StreamExt, SinkExt}; |
5 | 11 | use semver::Version; |
6 | use serde::Serialize; | |
12 | use serde::{de::DeserializeOwned, Serialize}; | |
7 | 13 | use tokio::{ |
8 | 14 | net::TcpStream, |
9 | 15 | sync::{ |
@@ -31,7 +37,7 @@ use crate::{ | ||
31 | 37 | UserMessage, UserMessageKind, UserMessageRequest, UserMessageResponse, |
32 | 38 | UserRepositoriesResponse, |
33 | 39 | }, |
34 | MessageKind, | |
40 | ErrorMessage, MessageKind, | |
35 | 41 | }, |
36 | 42 | model::{ |
37 | 43 | instance::{Instance, InstanceMeta}, |
@@ -41,6 +47,16 @@ use crate::{ | ||
41 | 47 | validate_version, version, |
42 | 48 | }; |
43 | 49 | |
50 | #[derive(Debug, thiserror::Error)] | |
51 | pub enum ConnectionError { | |
52 | #[error("connection error message {0}")] | |
53 | ErrorMessage(#[from] ErrorMessage), | |
54 | #[error("connection should close")] | |
55 | Shutdown, | |
56 | #[error("internal error {0}")] | |
57 | InternalError(#[from] Error), | |
58 | } | |
59 | ||
44 | 60 | pub struct RawConnection { |
45 | 61 | pub task: JoinHandle<()>, |
46 | 62 | } |
@@ -63,523 +79,153 @@ pub struct Connections { | ||
63 | 79 | } |
64 | 80 | |
65 | 81 | pub async fn connection_worker( |
66 | mut socket: WebSocketStream<TcpStream>, | |
67 | listeners: Arc<Mutex<Listeners>>, | |
68 | connections: Arc<Mutex<Connections>>, | |
69 | backend: Arc<Mutex<dyn RepositoryBackend + Send>>, | |
70 | user_backend: Arc<Mutex<dyn UserBackend + Send>>, | |
71 | auth_granter: Arc<Mutex<AuthenticationTokenGranter>>, | |
72 | discovery_backend: Arc<Mutex<dyn DiscoveryBackend + Send>>, | |
73 | addr: SocketAddr, | |
74 | ) { | |
75 | let mut handshaked = false; | |
82 | mut socket: &mut WebSocketStream<TcpStream>, | |
83 | handshaked: &mut bool, | |
84 | listeners: &Arc<Mutex<Listeners>>, | |
85 | connections: &Arc<Mutex<Connections>>, | |
86 | backend: &Arc<Mutex<dyn RepositoryBackend + Send>>, | |
87 | user_backend: &Arc<Mutex<dyn UserBackend + Send>>, | |
88 | auth_granter: &Arc<Mutex<AuthenticationTokenGranter>>, | |
89 | discovery_backend: &Arc<Mutex<dyn DiscoveryBackend + Send>>, | |
90 | addr: &SocketAddr, | |
91 | ) -> Result<(), ConnectionError> { | |
76 | 92 | let this_instance = Instance { |
77 | 93 | url: String::from("giterated.dev"), |
78 | 94 | }; |
79 | 95 | |
80 | while let Some(message) = socket.next().await { | |
81 | let message = match message { | |
82 | Ok(message) => message, | |
83 | Err(err) => { | |
84 | error!("Error reading message: {:?}", err); | |
85 | continue; | |
86 | } | |
87 | }; | |
96 | let message = socket | |
97 | .next() | |
98 | .await | |
99 | .ok_or_else(|| ConnectionError::Shutdown)? | |
100 | .map_err(|e| Error::from(e))?; | |
101 | ||
102 | let payload = match message { | |
103 | Message::Text(text) => text.into_bytes(), | |
104 | Message::Binary(bytes) => bytes, | |
105 | Message::Ping(_) => return Ok(()), | |
106 | Message::Pong(_) => return Ok(()), | |
107 | Message::Close(_) => { | |
108 | info!("Closing connection with {}.", addr); | |
109 | ||
110 | return Err(ConnectionError::Shutdown); | |
111 | } | |
112 | _ => unreachable!(), | |
113 | }; | |
88 | 114 | |
89 | let payload = match message { | |
90 | Message::Text(text) => text.into_bytes(), | |
91 | Message::Binary(bytes) => bytes, | |
92 | Message::Ping(_) => continue, | |
93 | Message::Pong(_) => continue, | |
94 | Message::Close(_) => { | |
95 | info!("Closing connection with {}.", addr); | |
115 | let message = serde_json::from_slice::<MessageKind>(&payload).map_err(|e| Error::from(e))?; | |
96 | 116 | |
97 | return; | |
117 | if let MessageKind::Handshake(handshake) = message { | |
118 | match handshake { | |
119 | HandshakeMessage::Initiate(request) => { | |
120 | unimplemented!() | |
98 | 121 | } |
99 | _ => unreachable!(), | |
100 | }; | |
101 | ||
102 | let message = match serde_json::from_slice::<MessageKind>(&payload) { | |
103 | Ok(message) => message, | |
104 | Err(err) => { | |
105 | error!("Error deserializing message from {}: {:?}", addr, err); | |
106 | continue; | |
122 | HandshakeMessage::Response(response) => { | |
123 | unimplemented!() | |
107 | 124 | } |
108 | }; | |
109 | ||
110 | // info!("Read payload: {}", std::str::from_utf8(&payload).unwrap()); | |
111 | ||
112 | if let MessageKind::Handshake(handshake) = message { | |
113 | match handshake { | |
114 | HandshakeMessage::Initiate(request) => { | |
115 | // Send HandshakeMessage::Response | |
116 | let message = HandshakeResponse { | |
117 | identity: this_instance.clone(), | |
118 | version: version(), | |
119 | }; | |
120 | ||
121 | let version_check = validate_version(&request.version); | |
122 | ||
123 | let _result = if !version_check { | |
124 | error!( | |
125 | "Version compatibility failure! Our Version: {}, Their Version: {}", | |
126 | Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()) | |
127 | .unwrap(), | |
128 | request.version | |
129 | ); | |
130 | ||
131 | send( | |
132 | &mut socket, | |
133 | MessageKind::Handshake(HandshakeMessage::Finalize(HandshakeFinalize { | |
134 | success: false, | |
135 | })), | |
136 | ) | |
137 | .await | |
138 | } else { | |
139 | send( | |
140 | &mut socket, | |
141 | MessageKind::Handshake(HandshakeMessage::Response(message)), | |
142 | ) | |
143 | .await | |
144 | }; | |
145 | ||
146 | continue; | |
147 | } | |
148 | HandshakeMessage::Response(response) => { | |
149 | // Check version | |
150 | let message = if validate_version(&response.version) { | |
151 | error!( | |
152 | "Version compatibility failure! Our Version: {}, Their Version: {}", | |
153 | version(), | |
154 | response.version | |
155 | ); | |
156 | ||
157 | HandshakeFinalize { success: false } | |
158 | } else { | |
159 | info!("Connected with a compatible version"); | |
160 | ||
161 | HandshakeFinalize { success: true } | |
162 | }; | |
163 | ||
164 | let _result = send( | |
165 | &mut socket, | |
166 | MessageKind::Handshake(HandshakeMessage::Finalize(message)), | |
167 | ) | |
168 | .await; | |
169 | ||
170 | continue; | |
171 | } | |
172 | HandshakeMessage::Finalize(response) => { | |
173 | if !response.success { | |
174 | error!("Error during handshake, aborting connection"); | |
175 | return; | |
176 | } | |
177 | ||
178 | handshaked = true; | |
179 | ||
180 | // Send HandshakeMessage::Finalize | |
181 | let message = HandshakeFinalize { success: true }; | |
182 | ||
183 | let _result = send( | |
184 | &mut socket, | |
185 | MessageKind::Handshake(HandshakeMessage::Finalize(message)), | |
186 | ) | |
187 | .await; | |
188 | ||
189 | continue; | |
190 | } | |
125 | HandshakeMessage::Finalize(response) => { | |
126 | unimplemented!() | |
191 | 127 | } |
192 | 128 | } |
129 | } | |
193 | 130 | |
194 | if !handshaked { | |
195 | continue; | |
196 | } | |
131 | if !*handshaked { | |
132 | return Ok(()); | |
133 | } | |
197 | 134 | |
198 | if let MessageKind::Repository(repository) = &message { | |
199 | if repository.target.instance != this_instance { | |
200 | info!("Forwarding command to {}", repository.target.instance.url); | |
201 | // We need to send this command to a different instance | |
202 | ||
203 | let mut listener = send_and_get_listener(message, &listeners, &connections).await; | |
204 | ||
205 | // Wait for response | |
206 | while let Ok(message) = listener.recv().await { | |
207 | if let MessageKind::Repository(RepositoryMessage { | |
208 | command: RepositoryMessageKind::Response(_), | |
209 | .. | |
210 | }) = message | |
211 | { | |
212 | let _result = send(&mut socket, message).await; | |
213 | } | |
214 | } | |
215 | continue; | |
216 | } else { | |
217 | // This message is targeting this instance | |
218 | match &repository.command { | |
219 | RepositoryMessageKind::Request(request) => match request.clone() { | |
220 | RepositoryRequest::CreateRepository(request) => { | |
221 | let mut backend = backend.lock().await; | |
222 | let request = request.validate().await.unwrap(); | |
223 | let response = backend.create_repository(&request).await; | |
224 | ||
225 | let response = match response { | |
226 | Ok(response) => response, | |
227 | Err(err) => { | |
228 | error!("Error handling request: {:?}", err); | |
229 | continue; | |
230 | } | |
231 | }; | |
232 | drop(backend); | |
233 | ||
234 | let _result = send( | |
235 | &mut socket, | |
236 | MessageKind::Repository(RepositoryMessage { | |
237 | target: repository.target.clone(), | |
238 | command: RepositoryMessageKind::Response( | |
239 | RepositoryResponse::CreateRepository(response), | |
240 | ), | |
241 | }), | |
242 | ) | |
243 | .await; | |
244 | ||
245 | continue; | |
246 | } | |
247 | RepositoryRequest::RepositoryFileInspect(request) => { | |
248 | let mut backend = backend.lock().await; | |
249 | let request = request.validate().await.unwrap(); | |
250 | let response = backend.repository_file_inspect(&request); | |
251 | ||
252 | let response = match response { | |
253 | Ok(response) => response, | |
254 | Err(err) => { | |
255 | error!("Error handling request: {:?}", err); | |
256 | continue; | |
257 | } | |
258 | }; | |
259 | drop(backend); | |
260 | ||
261 | let _result = send( | |
262 | &mut socket, | |
263 | MessageKind::Repository(RepositoryMessage { | |
264 | target: repository.target.clone(), | |
265 | command: RepositoryMessageKind::Response( | |
266 | RepositoryResponse::RepositoryFileInspection(response), | |
267 | ), | |
268 | }), | |
269 | ) | |
270 | .await; | |
271 | ||
272 | continue; | |
273 | } | |
274 | RepositoryRequest::RepositoryInfo(request) => { | |
275 | let mut backend = backend.lock().await; | |
276 | let request = request.validate().await.unwrap(); | |
277 | let response = backend.repository_info(&request).await; | |
278 | ||
279 | let response = match response { | |
280 | Ok(response) => response, | |
281 | Err(err) => { | |
282 | error!("Error handling request: {:?}", err); | |
283 | continue; | |
284 | } | |
285 | }; | |
286 | drop(backend); | |
287 | ||
288 | let _result = send( | |
289 | &mut socket, | |
290 | MessageKind::Repository(RepositoryMessage { | |
291 | target: repository.target.clone(), | |
292 | command: RepositoryMessageKind::Response( | |
293 | RepositoryResponse::RepositoryInfo(response), | |
294 | ), | |
295 | }), | |
296 | ) | |
297 | .await; | |
298 | ||
299 | continue; | |
300 | } | |
301 | RepositoryRequest::IssuesCount(request) => { | |
302 | let request = &request.validate().await.unwrap(); | |
303 | ||
304 | let mut backend = backend.lock().await; | |
305 | let response = backend.issues_count(request); | |
306 | ||
307 | let response = match response { | |
308 | Ok(response) => response, | |
309 | Err(err) => { | |
310 | error!("Error handling request: {:?}", err); | |
311 | continue; | |
312 | } | |
313 | }; | |
314 | drop(backend); | |
315 | ||
316 | let _result = send( | |
317 | &mut socket, | |
318 | MessageKind::Repository(RepositoryMessage { | |
319 | target: repository.target.clone(), | |
320 | command: RepositoryMessageKind::Response( | |
321 | RepositoryResponse::IssuesCount(response), | |
322 | ), | |
323 | }), | |
324 | ) | |
325 | .await; | |
326 | ||
327 | continue; | |
328 | } | |
329 | RepositoryRequest::IssueLabels(request) => { | |
330 | let request = &request.validate().await.unwrap(); | |
331 | ||
332 | let mut backend = backend.lock().await; | |
333 | let response = backend.issue_labels(request); | |
334 | ||
335 | let response = match response { | |
336 | Ok(response) => response, | |
337 | Err(err) => { | |
338 | error!("Error handling request: {:?}", err); | |
339 | continue; | |
340 | } | |
341 | }; | |
342 | drop(backend); | |
343 | ||
344 | let _result = send( | |
345 | &mut socket, | |
346 | MessageKind::Repository(RepositoryMessage { | |
347 | target: repository.target.clone(), | |
348 | command: RepositoryMessageKind::Response( | |
349 | RepositoryResponse::IssueLabels(response), | |
350 | ), | |
351 | }), | |
352 | ) | |
353 | .await; | |
354 | ||
355 | continue; | |
356 | } | |
357 | RepositoryRequest::Issues(request) => { | |
358 | let request = request.validate().await.unwrap(); | |
359 | ||
360 | let mut backend = backend.lock().await; | |
361 | let response = backend.issues(&request); | |
362 | ||
363 | let response = match response { | |
364 | Ok(response) => response, | |
365 | Err(err) => { | |
366 | error!("Error handling request: {:?}", err); | |
367 | continue; | |
368 | } | |
369 | }; | |
370 | drop(backend); | |
371 | ||
372 | let _result = send( | |
373 | &mut socket, | |
374 | MessageKind::Repository(RepositoryMessage { | |
375 | target: repository.target.clone(), | |
376 | command: RepositoryMessageKind::Response( | |
377 | RepositoryResponse::Issues(response), | |
378 | ), | |
379 | }), | |
380 | ) | |
381 | .await; | |
382 | ||
383 | continue; | |
384 | } | |
385 | }, | |
386 | RepositoryMessageKind::Response(_response) => { | |
387 | unreachable!() | |
388 | } | |
135 | if let MessageKind::Repository(repository) = &message { | |
136 | if repository.target.instance != this_instance { | |
137 | info!("Forwarding command to {}", repository.target.instance.url); | |
138 | // We need to send this command to a different instance | |
139 | ||
140 | let mut listener = send_and_get_listener(message, &listeners, &connections).await; | |
141 | ||
142 | // Wait for response | |
143 | while let Ok(message) = listener.recv().await { | |
144 | if let MessageKind::Repository(RepositoryMessage { | |
145 | command: RepositoryMessageKind::Response(_), | |
146 | .. | |
147 | }) = message | |
148 | { | |
149 | let _result = send(&mut socket, message).await; | |
389 | 150 | } |
390 | 151 | } |
391 | } | |
392 | 152 | |
393 | if let MessageKind::Authentication(authentication) = &message { | |
394 | match authentication { | |
395 | AuthenticationMessage::Request(request) => match request { | |
396 | AuthenticationRequest::AuthenticationToken(token) => { | |
397 | let mut granter = auth_granter.lock().await; | |
398 | ||
399 | let response = granter.token_request(token.clone()).await.unwrap(); | |
400 | drop(granter); | |
401 | ||
402 | let _result = send( | |
403 | &mut socket, | |
404 | MessageKind::Authentication(AuthenticationMessage::Response( | |
405 | AuthenticationResponse::AuthenticationToken(response), | |
406 | )), | |
407 | ) | |
408 | .await; | |
409 | ||
410 | continue; | |
153 | return Ok(()); | |
154 | } else { | |
155 | // This message is targeting this instance | |
156 | match &repository.command { | |
157 | RepositoryMessageKind::Request(request) => match request.clone() { | |
158 | RepositoryRequest::CreateRepository(request) => { | |
159 | unimplemented!(); | |
411 | 160 | } |
412 | AuthenticationRequest::TokenExtension(request) => { | |
413 | let mut granter = auth_granter.lock().await; | |
414 | ||
415 | let response = granter | |
416 | .extension_request(request.clone()) | |
417 | .await | |
418 | .unwrap_or(TokenExtensionResponse { new_token: None }); | |
419 | drop(granter); | |
420 | ||
421 | let _result = send( | |
422 | &mut socket, | |
423 | MessageKind::Authentication(AuthenticationMessage::Response( | |
424 | AuthenticationResponse::TokenExtension(response), | |
425 | )), | |
426 | ) | |
427 | .await; | |
428 | ||
429 | continue; | |
161 | RepositoryRequest::RepositoryFileInspect(request) => { | |
162 | unimplemented!() | |
430 | 163 | } |
431 | AuthenticationRequest::RegisterAccount(request) => { | |
432 | let request = request.inner().await.clone(); | |
433 | ||
434 | let mut user_backend = user_backend.lock().await; | |
435 | ||
436 | let response = user_backend.register(request.clone()).await.unwrap(); | |
437 | drop(user_backend); | |
438 | ||
439 | let _result = send( | |
440 | &mut socket, | |
441 | MessageKind::Authentication(AuthenticationMessage::Response( | |
442 | AuthenticationResponse::RegisterAccount(response), | |
443 | )), | |
444 | ) | |
445 | .await; | |
446 | ||
447 | continue; | |
164 | RepositoryRequest::RepositoryInfo(request) => { | |
165 | unimplemented!() | |
166 | } | |
167 | RepositoryRequest::IssuesCount(request) => { | |
168 | unimplemented!() | |
169 | } | |
170 | RepositoryRequest::IssueLabels(request) => { | |
171 | unimplemented!() | |
172 | } | |
173 | RepositoryRequest::Issues(request) => { | |
174 | unimplemented!(); | |
448 | 175 | } |
449 | 176 | }, |
450 | AuthenticationMessage::Response(_) => unreachable!(), | |
177 | RepositoryMessageKind::Response(_response) => { | |
178 | unreachable!() | |
179 | } | |
451 | 180 | } |
452 | 181 | } |
182 | } | |
453 | 183 | |
454 | if let MessageKind::Discovery(message) = &message { | |
455 | let mut backend = discovery_backend.lock().await; | |
456 | backend.try_handle(message).await.unwrap(); | |
457 | ||
458 | continue; | |
184 | if let MessageKind::Authentication(authentication) = &message { | |
185 | match authentication { | |
186 | AuthenticationMessage::Request(request) => match request { | |
187 | AuthenticationRequest::AuthenticationToken(token) => { | |
188 | unimplemented!() | |
189 | } | |
190 | AuthenticationRequest::TokenExtension(request) => { | |
191 | unimplemented!() | |
192 | } | |
193 | AuthenticationRequest::RegisterAccount(request) => { | |
194 | unimplemented!() | |
195 | } | |
196 | }, | |
197 | AuthenticationMessage::Response(_) => unreachable!(), | |
459 | 198 | } |
199 | } | |
460 | 200 | |
461 | if let MessageKind::User(message) = &message { | |
462 | match &message.message { | |
463 | UserMessageKind::Request(request) => match request { | |
464 | UserMessageRequest::DisplayName(request) => { | |
465 | let mut user_backend = user_backend.lock().await; | |
466 | ||
467 | let response = user_backend.display_name(request.clone()).await; | |
468 | ||
469 | let response = match response { | |
470 | Ok(response) => response, | |
471 | Err(err) => { | |
472 | error!("Error handling request: {:?}", err); | |
473 | continue; | |
474 | } | |
475 | }; | |
476 | drop(user_backend); | |
477 | ||
478 | let _result = send( | |
479 | &mut socket, | |
480 | MessageKind::User(UserMessage { | |
481 | instance: message.instance.clone(), | |
482 | message: UserMessageKind::Response( | |
483 | UserMessageResponse::DisplayName(response), | |
484 | ), | |
485 | }), | |
486 | ) | |
487 | .await; | |
488 | ||
489 | continue; | |
490 | } | |
491 | UserMessageRequest::DisplayImage(request) => { | |
492 | let mut user_backend = user_backend.lock().await; | |
493 | ||
494 | let response = user_backend.display_image(request.clone()).await; | |
495 | ||
496 | let response = match response { | |
497 | Ok(response) => response, | |
498 | Err(err) => { | |
499 | error!("Error handling request: {:?}", err); | |
500 | continue; | |
501 | } | |
502 | }; | |
503 | drop(user_backend); | |
504 | ||
505 | let _result = send( | |
506 | &mut socket, | |
507 | MessageKind::User(UserMessage { | |
508 | instance: message.instance.clone(), | |
509 | message: UserMessageKind::Response( | |
510 | UserMessageResponse::DisplayImage(response), | |
511 | ), | |
512 | }), | |
513 | ) | |
514 | .await; | |
515 | ||
516 | continue; | |
517 | } | |
518 | UserMessageRequest::Bio(request) => { | |
519 | let mut user_backend = user_backend.lock().await; | |
520 | ||
521 | let response = user_backend.bio(request.clone()).await; | |
522 | ||
523 | let response = match response { | |
524 | Ok(response) => response, | |
525 | Err(err) => { | |
526 | error!("Error handling request: {:?}", err); | |
527 | continue; | |
528 | } | |
529 | }; | |
530 | drop(user_backend); | |
531 | ||
532 | let _result = send( | |
533 | &mut socket, | |
534 | MessageKind::User(UserMessage { | |
535 | instance: message.instance.clone(), | |
536 | message: UserMessageKind::Response(UserMessageResponse::Bio( | |
537 | response, | |
538 | )), | |
539 | }), | |
540 | ) | |
541 | .await; | |
542 | ||
543 | continue; | |
544 | } | |
545 | UserMessageRequest::Repositories(request) => { | |
546 | let mut repository_backend = backend.lock().await; | |
547 | ||
548 | let repositories = repository_backend | |
549 | .repositories_for_user(&request.user) | |
550 | .await; | |
551 | ||
552 | let repositories = match repositories { | |
553 | Ok(repositories) => repositories, | |
554 | Err(err) => { | |
555 | error!("Error handling request: {:?}", err); | |
556 | continue; | |
557 | } | |
558 | }; | |
559 | drop(repository_backend); | |
560 | ||
561 | let response = UserRepositoriesResponse { repositories }; | |
562 | ||
563 | let _result = send( | |
564 | &mut socket, | |
565 | MessageKind::User(UserMessage { | |
566 | instance: message.instance.clone(), | |
567 | message: UserMessageKind::Response( | |
568 | UserMessageResponse::Repositories(response), | |
569 | ), | |
570 | }), | |
571 | ) | |
572 | .await; | |
573 | ||
574 | continue; | |
575 | } | |
576 | }, | |
577 | UserMessageKind::Response(_) => unreachable!(), | |
578 | } | |
201 | if let MessageKind::Discovery(message) = &message { | |
202 | let mut backend = discovery_backend.lock().await; | |
203 | backend.try_handle(message).await?; | |
204 | ||
205 | return Ok(()); | |
206 | } | |
207 | ||
208 | if let MessageKind::User(message) = &message { | |
209 | match &message.message { | |
210 | UserMessageKind::Request(request) => match request { | |
211 | UserMessageRequest::DisplayName(request) => { | |
212 | unimplemented!() | |
213 | } | |
214 | UserMessageRequest::DisplayImage(request) => { | |
215 | unimplemented!() | |
216 | } | |
217 | UserMessageRequest::Bio(request) => { | |
218 | unimplemented!() | |
219 | } | |
220 | UserMessageRequest::Repositories(request) => { | |
221 | unimplemented!() | |
222 | } | |
223 | }, | |
224 | UserMessageKind::Response(_) => unreachable!(), | |
579 | 225 | } |
580 | 226 | } |
581 | 227 | |
582 | info!("Connection closed"); | |
228 | Ok(()) | |
583 | 229 | } |
584 | 230 | |
585 | 231 | async fn send_and_get_listener( |
@@ -596,6 +242,7 @@ async fn send_and_get_listener( | ||
596 | 242 | MessageKind::Authentication(_) => todo!(), |
597 | 243 | MessageKind::Discovery(_) => todo!(), |
598 | 244 | MessageKind::User(user) => todo!(), |
245 | MessageKind::Error(_) => todo!(), | |
599 | 246 | }; |
600 | 247 | |
601 | 248 | let target = match (&instance, &user, &repository) { |
@@ -631,8 +278,56 @@ async fn send<T: Serialize>( | ||
631 | 278 | message: T, |
632 | 279 | ) -> Result<(), Error> { |
633 | 280 | socket |
634 | .send(Message::Binary(serde_json::to_vec(&message).unwrap())) | |
281 | .send(Message::Binary(serde_json::to_vec(&message)?)) | |
635 | 282 | .await?; |
636 | 283 | |
637 | 284 | Ok(()) |
638 | 285 | } |
286 | ||
287 | #[derive(Debug, thiserror::Error)] | |
288 | #[error("handler did not handle")] | |
289 | pub struct HandlerUnhandled; | |
290 | ||
291 | pub trait MessageHandling<A, M, R> { | |
292 | fn message_type() -> &'static str; | |
293 | } | |
294 | ||
295 | impl<T1, F, M, R> MessageHandling<(T1,), M, R> for F | |
296 | where | |
297 | F: FnOnce(T1) -> R, | |
298 | T1: Serialize + DeserializeOwned, | |
299 | { | |
300 | fn message_type() -> &'static str { | |
301 | type_name::<T1>() | |
302 | } | |
303 | } | |
304 | ||
305 | impl<T1, T2, F, M, R> MessageHandling<(T1, T2), M, R> for F | |
306 | where | |
307 | F: FnOnce(T1, T2) -> R, | |
308 | T1: Serialize + DeserializeOwned, | |
309 | { | |
310 | fn message_type() -> &'static str { | |
311 | type_name::<T1>() | |
312 | } | |
313 | } | |
314 | ||
315 | impl<T1, T2, T3, F, M, R> MessageHandling<(T1, T2, T3), M, R> for F | |
316 | where | |
317 | F: FnOnce(T1, T2, T3) -> R, | |
318 | T1: Serialize + DeserializeOwned, | |
319 | { | |
320 | fn message_type() -> &'static str { | |
321 | type_name::<T1>() | |
322 | } | |
323 | } | |
324 | ||
325 | impl<T1, T2, T3, T4, F, M, R> MessageHandling<(T1, T2, T3, T4), M, R> for F | |
326 | where | |
327 | F: FnOnce(T1, T2, T3, T4) -> R, | |
328 | T1: Serialize + DeserializeOwned, | |
329 | { | |
330 | fn message_type() -> &'static str { | |
331 | type_name::<T1>() | |
332 | } | |
333 | } |
src/connection/authentication.rs
@@ -0,0 +1,160 @@ | ||
1 | use anyhow::Error; | |
2 | use thiserror::Error; | |
3 | ||
4 | use crate::messages::authentication::{AuthenticationMessage, AuthenticationResponse}; | |
5 | use crate::model::authenticated::MessageHandler; | |
6 | use crate::{ | |
7 | messages::{authentication::AuthenticationRequest, MessageKind}, | |
8 | model::authenticated::{AuthenticatedInstance, NetworkMessage, State}, | |
9 | }; | |
10 | ||
11 | use super::wrapper::ConnectionState; | |
12 | use super::HandlerUnhandled; | |
13 | ||
14 | pub async fn authentication_handle( | |
15 | message: &NetworkMessage, | |
16 | state: &ConnectionState, | |
17 | ) -> Result<(), Error> { | |
18 | let message_kind: MessageKind = serde_json::from_slice(&message.0).unwrap(); | |
19 | ||
20 | match message_kind { | |
21 | MessageKind::Authentication(AuthenticationMessage::Request(request)) => match request { | |
22 | AuthenticationRequest::RegisterAccount(_) => { | |
23 | register_account_request | |
24 | .handle_message(message, state) | |
25 | .await | |
26 | } | |
27 | AuthenticationRequest::AuthenticationToken(_) => { | |
28 | authentication_token_request | |
29 | .handle_message(message, state) | |
30 | .await | |
31 | } | |
32 | AuthenticationRequest::TokenExtension(_) => { | |
33 | token_extension_request.handle_message(message, state).await | |
34 | } | |
35 | }, | |
36 | _ => Err(Error::from(HandlerUnhandled)), | |
37 | } | |
38 | } | |
39 | ||
40 | async fn register_account_request( | |
41 | State(connection_state): State<ConnectionState>, | |
42 | request: MessageKind, | |
43 | instance: AuthenticatedInstance, | |
44 | ) -> Result<(), AuthenticationConnectionError> { | |
45 | let request = if let MessageKind::Authentication(AuthenticationMessage::Request( | |
46 | AuthenticationRequest::RegisterAccount(request), | |
47 | )) = request | |
48 | { | |
49 | request | |
50 | } else { | |
51 | return Err(AuthenticationConnectionError::InvalidRequest); | |
52 | }; | |
53 | ||
54 | if *instance.inner() != connection_state.instance { | |
55 | return Err(AuthenticationConnectionError::SameInstance); | |
56 | } | |
57 | ||
58 | let mut user_backend = connection_state.user_backend.lock().await; | |
59 | ||
60 | let response = user_backend | |
61 | .register(request.clone()) | |
62 | .await | |
63 | .map_err(|e| AuthenticationConnectionError::Registration(e))?; | |
64 | drop(user_backend); | |
65 | ||
66 | connection_state | |
67 | .send(MessageKind::Authentication( | |
68 | AuthenticationMessage::Response(AuthenticationResponse::RegisterAccount(response)), | |
69 | )) | |
70 | .await | |
71 | .map_err(|e| AuthenticationConnectionError::Sending(e))?; | |
72 | ||
73 | Ok(()) | |
74 | } | |
75 | ||
76 | async fn authentication_token_request( | |
77 | State(connection_state): State<ConnectionState>, | |
78 | request: MessageKind, | |
79 | instance: AuthenticatedInstance, | |
80 | ) -> Result<(), AuthenticationConnectionError> { | |
81 | let request = if let MessageKind::Authentication(AuthenticationMessage::Request( | |
82 | AuthenticationRequest::AuthenticationToken(request), | |
83 | )) = request | |
84 | { | |
85 | request | |
86 | } else { | |
87 | return Err(AuthenticationConnectionError::InvalidRequest); | |
88 | }; | |
89 | ||
90 | let issued_for = instance.inner().clone(); | |
91 | ||
92 | let mut token_granter = connection_state.auth_granter.lock().await; | |
93 | ||
94 | let response = token_granter | |
95 | .token_request(issued_for, request.username, request.password) | |
96 | .await | |
97 | .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; | |
98 | ||
99 | connection_state | |
100 | .send(MessageKind::Authentication( | |
101 | AuthenticationMessage::Response(AuthenticationResponse::AuthenticationToken(response)), | |
102 | )) | |
103 | .await | |
104 | .map_err(|e| AuthenticationConnectionError::Sending(e))?; | |
105 | ||
106 | Ok(()) | |
107 | } | |
108 | ||
109 | async fn token_extension_request( | |
110 | State(connection_state): State<ConnectionState>, | |
111 | request: MessageKind, | |
112 | instance: AuthenticatedInstance, | |
113 | ) -> Result<(), AuthenticationConnectionError> { | |
114 | let request = if let MessageKind::Authentication(AuthenticationMessage::Request( | |
115 | AuthenticationRequest::TokenExtension(request), | |
116 | )) = request | |
117 | { | |
118 | request | |
119 | } else { | |
120 | return Err(AuthenticationConnectionError::InvalidRequest); | |
121 | }; | |
122 | ||
123 | let issued_for = instance.inner().clone(); | |
124 | ||
125 | let mut token_granter = connection_state.auth_granter.lock().await; | |
126 | ||
127 | let response = token_granter | |
128 | .extension_request(&issued_for, request.token) | |
129 | .await | |
130 | .map_err(|e| AuthenticationConnectionError::TokenIssuance(e))?; | |
131 | ||
132 | connection_state | |
133 | .send(MessageKind::Authentication( | |
134 | AuthenticationMessage::Response(AuthenticationResponse::TokenExtension(response)), | |
135 | )) | |
136 | .await | |
137 | .map_err(|e| AuthenticationConnectionError::Sending(e))?; | |
138 | ||
139 | Ok(()) | |
140 | } | |
141 | ||
142 | async fn verify(state: ConnectionState) { | |
143 | register_account_request | |
144 | .handle_message(&NetworkMessage(vec![]), &state) | |
145 | .await; | |
146 | } | |
147 | ||
148 | #[derive(Debug, Error)] | |
149 | pub enum AuthenticationConnectionError { | |
150 | #[error("the request was invalid")] | |
151 | InvalidRequest, | |
152 | #[error("request must be from the same instance")] | |
153 | SameInstance, | |
154 | #[error("issue during registration {0}")] | |
155 | Registration(Error), | |
156 | #[error("sending error")] | |
157 | Sending(Error), | |
158 | #[error("error issuing token")] | |
159 | TokenIssuance(Error), | |
160 | } |
src/connection/handshake.rs
@@ -0,0 +1,128 @@ | ||
1 | use std::{str::FromStr, sync::atomic::Ordering}; | |
2 | ||
3 | use anyhow::Error; | |
4 | use semver::Version; | |
5 | use thiserror::Error; | |
6 | ||
7 | use crate::model::authenticated::MessageHandler; | |
8 | use crate::{ | |
9 | connection::ConnectionError, | |
10 | handshake::{HandshakeFinalize, HandshakeResponse, InitiateHandshake}, | |
11 | model::authenticated::{AuthenticatedInstance, Message, NetworkMessage, State}, | |
12 | validate_version, | |
13 | }; | |
14 | ||
15 | use super::{wrapper::ConnectionState, HandlerUnhandled}; | |
16 | ||
17 | pub async fn handshake_handle( | |
18 | message: &NetworkMessage, | |
19 | state: &ConnectionState, | |
20 | ) -> Result<(), Error> { | |
21 | if initiate_handshake | |
22 | .handle_message(&message, state) | |
23 | .await | |
24 | .is_ok() | |
25 | { | |
26 | Ok(()) | |
27 | } else if handshake_response | |
28 | .handle_message(&message, state) | |
29 | .await | |
30 | .is_ok() | |
31 | { | |
32 | Ok(()) | |
33 | } else if handshake_finalize | |
34 | .handle_message(&message, state) | |
35 | .await | |
36 | .is_ok() | |
37 | { | |
38 | Ok(()) | |
39 | } else { | |
40 | Err(Error::from(HandlerUnhandled)) | |
41 | } | |
42 | } | |
43 | ||
44 | async fn initiate_handshake( | |
45 | Message(initiation): Message<InitiateHandshake>, | |
46 | State(connection_state): State<ConnectionState>, | |
47 | _instance: AuthenticatedInstance, | |
48 | ) -> Result<(), HandshakeError> { | |
49 | if !validate_version(&initiation.version) { | |
50 | error!( | |
51 | "Version compatibility failure! Our Version: {}, Their Version: {}", | |
52 | Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), | |
53 | initiation.version | |
54 | ); | |
55 | ||
56 | connection_state | |
57 | .send(HandshakeFinalize { success: false }) | |
58 | .await | |
59 | .map_err(|e| HandshakeError::SendError(e))?; | |
60 | ||
61 | Ok(()) | |
62 | } else { | |
63 | connection_state | |
64 | .send(HandshakeFinalize { success: true }) | |
65 | .await | |
66 | .map_err(|e| HandshakeError::SendError(e))?; | |
67 | ||
68 | Ok(()) | |
69 | } | |
70 | } | |
71 | ||
72 | async fn handshake_response( | |
73 | Message(response): Message<HandshakeResponse>, | |
74 | State(connection_state): State<ConnectionState>, | |
75 | _instance: AuthenticatedInstance, | |
76 | ) -> Result<(), HandshakeError> { | |
77 | if !validate_version(&response.version) { | |
78 | error!( | |
79 | "Version compatibility failure! Our Version: {}, Their Version: {}", | |
80 | Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), | |
81 | response.version | |
82 | ); | |
83 | ||
84 | connection_state | |
85 | .send(HandshakeFinalize { success: false }) | |
86 | .await | |
87 | .map_err(|e| HandshakeError::SendError(e))?; | |
88 | ||
89 | Ok(()) | |
90 | } else { | |
91 | connection_state | |
92 | .send(HandshakeFinalize { success: true }) | |
93 | .await | |
94 | .map_err(|e| HandshakeError::SendError(e))?; | |
95 | ||
96 | Ok(()) | |
97 | } | |
98 | } | |
99 | ||
100 | async fn handshake_finalize( | |
101 | Message(finalize): Message<HandshakeFinalize>, | |
102 | State(connection_state): State<ConnectionState>, | |
103 | _instance: AuthenticatedInstance, | |
104 | ) -> Result<(), HandshakeError> { | |
105 | if !finalize.success { | |
106 | error!("Error during handshake, aborting connection"); | |
107 | return Err(Error::from(ConnectionError::Shutdown).into()); | |
108 | } else { | |
109 | connection_state.handshaked.store(true, Ordering::SeqCst); | |
110 | ||
111 | connection_state | |
112 | .send(HandshakeFinalize { success: true }) | |
113 | .await | |
114 | .map_err(|e| HandshakeError::SendError(e))?; | |
115 | ||
116 | Ok(()) | |
117 | } | |
118 | } | |
119 | ||
120 | #[derive(Debug, thiserror::Error)] | |
121 | pub enum HandshakeError { | |
122 | #[error("version mismatch during handshake, ours: {0}, theirs: {1}")] | |
123 | VersionMismatch(Version, Version), | |
124 | #[error("while sending message: {0}")] | |
125 | SendError(Error), | |
126 | #[error("{0}")] | |
127 | Other(#[from] Error), | |
128 | } |
src/connection/repository.rs
@@ -0,0 +1,130 @@ | ||
1 | use anyhow::Error; | |
2 | ||
3 | use crate::{ | |
4 | messages::repository::{ | |
5 | CreateRepositoryRequest, RepositoryFileInspectRequest, RepositoryInfoRequest, | |
6 | RepositoryIssueLabelsRequest, RepositoryIssuesCountRequest, RepositoryIssuesRequest, | |
7 | RepositoryRequest, | |
8 | }, | |
9 | model::authenticated::{AuthenticatedUser, Message, MessageHandler, NetworkMessage, State}, | |
10 | }; | |
11 | ||
12 | use super::{wrapper::ConnectionState, HandlerUnhandled}; | |
13 | ||
14 | pub async fn repository_handle( | |
15 | message: &NetworkMessage, | |
16 | state: &ConnectionState, | |
17 | ) -> Result<(), Error> { | |
18 | if create_repository | |
19 | .handle_message(&message, state) | |
20 | .await | |
21 | .is_ok() | |
22 | { | |
23 | Ok(()) | |
24 | } else if repository_file_inspect | |
25 | .handle_message(&message, state) | |
26 | .await | |
27 | .is_ok() | |
28 | { | |
29 | Ok(()) | |
30 | } else if repository_info | |
31 | .handle_message(&message, state) | |
32 | .await | |
33 | .is_ok() | |
34 | { | |
35 | Ok(()) | |
36 | } else if issues_count.handle_message(&message, state).await.is_ok() { | |
37 | Ok(()) | |
38 | } else if issue_labels.handle_message(&message, state).await.is_ok() { | |
39 | Ok(()) | |
40 | } else if issues.handle_message(&message, state).await.is_ok() { | |
41 | Ok(()) | |
42 | } else { | |
43 | Err(Error::from(HandlerUnhandled)) | |
44 | } | |
45 | } | |
46 | ||
47 | async fn create_repository( | |
48 | Message(request): Message<CreateRepositoryRequest>, | |
49 | State(connection_state): State<ConnectionState>, | |
50 | AuthenticatedUser(user): AuthenticatedUser, | |
51 | ) -> Result<(), RepositoryError> { | |
52 | let mut repository_backend = connection_state.repository_backend.lock().await; | |
53 | let response = repository_backend | |
54 | .create_repository(&user, &request) | |
55 | .await?; | |
56 | ||
57 | drop(repository_backend); | |
58 | ||
59 | connection_state.send(response).await?; | |
60 | ||
61 | Ok(()) | |
62 | } | |
63 | ||
64 | async fn repository_file_inspect( | |
65 | Message(request): Message<RepositoryFileInspectRequest>, | |
66 | State(connection_state): State<ConnectionState>, | |
67 | user: Option<AuthenticatedUser>, | |
68 | ) -> Result<(), RepositoryError> { | |
69 | let user = user.map(|u| u.0); | |
70 | ||
71 | let mut repository_backend = connection_state.repository_backend.lock().await; | |
72 | let response = repository_backend | |
73 | .repository_file_inspect(user.as_ref(), &request) | |
74 | .await?; | |
75 | ||
76 | drop(repository_backend); | |
77 | ||
78 | connection_state.send(response).await?; | |
79 | ||
80 | Ok(()) | |
81 | } | |
82 | ||
83 | async fn repository_info( | |
84 | Message(request): Message<RepositoryInfoRequest>, | |
85 | State(connection_state): State<ConnectionState>, | |
86 | user: Option<AuthenticatedUser>, | |
87 | ) -> Result<(), RepositoryError> { | |
88 | let user = user.map(|u| u.0); | |
89 | ||
90 | let mut repository_backend = connection_state.repository_backend.lock().await; | |
91 | let response = repository_backend | |
92 | .repository_info(user.as_ref(), &request) | |
93 | .await?; | |
94 | ||
95 | drop(repository_backend); | |
96 | ||
97 | connection_state.send(response).await?; | |
98 | ||
99 | Ok(()) | |
100 | } | |
101 | ||
102 | async fn issues_count( | |
103 | Message(request): Message<RepositoryIssuesCountRequest>, | |
104 | State(connection_state): State<ConnectionState>, | |
105 | user: Option<AuthenticatedUser>, | |
106 | ) -> Result<(), RepositoryError> { | |
107 | unimplemented!(); | |
108 | } | |
109 | ||
110 | async fn issue_labels( | |
111 | Message(request): Message<RepositoryIssueLabelsRequest>, | |
112 | State(connection_state): State<ConnectionState>, | |
113 | user: Option<AuthenticatedUser>, | |
114 | ) -> Result<(), RepositoryError> { | |
115 | unimplemented!(); | |
116 | } | |
117 | ||
118 | async fn issues( | |
119 | Message(request): Message<RepositoryIssuesRequest>, | |
120 | State(connection_state): State<ConnectionState>, | |
121 | user: Option<AuthenticatedUser>, | |
122 | ) -> Result<(), RepositoryError> { | |
123 | unimplemented!(); | |
124 | } | |
125 | ||
126 | #[derive(Debug, thiserror::Error)] | |
127 | pub enum RepositoryError { | |
128 | #[error("{0}")] | |
129 | Other(#[from] Error), | |
130 | } |
src/connection/user.rs
@@ -0,0 +1,104 @@ | ||
1 | use anyhow::Error; | |
2 | ||
3 | use crate::{ | |
4 | messages::user::{ | |
5 | UserBioRequest, UserDisplayImageRequest, UserDisplayNameRequest, UserRepositoriesRequest, | |
6 | UserRepositoriesResponse, | |
7 | }, | |
8 | model::authenticated::{Message, MessageHandler, NetworkMessage, State}, | |
9 | }; | |
10 | ||
11 | use super::{wrapper::ConnectionState, HandlerUnhandled}; | |
12 | ||
13 | pub async fn user_handle(message: &NetworkMessage, state: &ConnectionState) -> Result<(), Error> { | |
14 | if display_name.handle_message(&message, state).await.is_ok() { | |
15 | Ok(()) | |
16 | } else if display_image.handle_message(&message, state).await.is_ok() { | |
17 | Ok(()) | |
18 | } else if bio.handle_message(&message, state).await.is_ok() { | |
19 | Ok(()) | |
20 | } else { | |
21 | Err(Error::from(HandlerUnhandled)) | |
22 | } | |
23 | } | |
24 | ||
25 | async fn display_name( | |
26 | Message(request): Message<UserDisplayNameRequest>, | |
27 | State(connection_state): State<ConnectionState>, | |
28 | ) -> Result<(), UserError> { | |
29 | let mut user_backend = connection_state.user_backend.lock().await; | |
30 | let response = user_backend.display_name(request.clone()).await?; | |
31 | ||
32 | drop(user_backend); | |
33 | ||
34 | connection_state.send(response).await?; | |
35 | ||
36 | Ok(()) | |
37 | } | |
38 | ||
39 | async fn display_image( | |
40 | Message(request): Message<UserDisplayImageRequest>, | |
41 | State(connection_state): State<ConnectionState>, | |
42 | ) -> Result<(), UserError> { | |
43 | let mut user_backend = connection_state.user_backend.lock().await; | |
44 | let response = user_backend.display_image(request.clone()).await?; | |
45 | ||
46 | drop(user_backend); | |
47 | ||
48 | connection_state.send(response).await?; | |
49 | ||
50 | Ok(()) | |
51 | } | |
52 | ||
53 | async fn bio( | |
54 | Message(request): Message<UserBioRequest>, | |
55 | State(connection_state): State<ConnectionState>, | |
56 | ) -> Result<(), UserError> { | |
57 | let mut user_backend = connection_state.user_backend.lock().await; | |
58 | let response = user_backend.bio(request.clone()).await?; | |
59 | ||
60 | drop(user_backend); | |
61 | ||
62 | connection_state.send(response).await?; | |
63 | ||
64 | Ok(()) | |
65 | } | |
66 | ||
67 | async fn repositories( | |
68 | Message(request): Message<UserRepositoriesRequest>, | |
69 | State(connection_state): State<ConnectionState>, | |
70 | ) -> Result<(), UserError> { | |
71 | let mut repository_backend = connection_state.repository_backend.lock().await; | |
72 | ||
73 | let repositories = repository_backend | |
74 | .repositories_for_user(&request.user) | |
75 | .await; | |
76 | ||
77 | let repositories = match repositories { | |
78 | Ok(repositories) => repositories, | |
79 | Err(err) => { | |
80 | error!("Error handling request: {:?}", err); | |
81 | return Ok(()); | |
82 | } | |
83 | }; | |
84 | drop(repository_backend); | |
85 | ||
86 | let mut user_backend = connection_state.user_backend.lock().await; | |
87 | let user_exists = user_backend.exists(&request.user).await; | |
88 | ||
89 | if repositories.is_empty() && !matches!(user_exists, Ok(true)) { | |
90 | panic!() | |
91 | } | |
92 | ||
93 | let response: UserRepositoriesResponse = UserRepositoriesResponse { repositories }; | |
94 | ||
95 | connection_state.send(response).await?; | |
96 | ||
97 | Ok(()) | |
98 | } | |
99 | ||
100 | #[derive(Debug, thiserror::Error)] | |
101 | pub enum UserError { | |
102 | #[error("{0}")] | |
103 | Other(#[from] Error), | |
104 | } |
src/connection/wrapper.rs
@@ -0,0 +1,81 @@ | ||
1 | use std::{ | |
2 | net::SocketAddr, | |
3 | sync::{atomic::AtomicBool, Arc}, | |
4 | }; | |
5 | ||
6 | use anyhow::Error; | |
7 | use futures_util::SinkExt; | |
8 | use serde::Serialize; | |
9 | use tokio::{net::TcpStream, sync::Mutex}; | |
10 | use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; | |
11 | ||
12 | use crate::{ | |
13 | authentication::AuthenticationTokenGranter, | |
14 | backend::{DiscoveryBackend, RepositoryBackend, UserBackend}, | |
15 | connection::ConnectionError, | |
16 | listener::Listeners, | |
17 | model::instance::Instance, | |
18 | }; | |
19 | ||
20 | use super::{connection_worker, Connections}; | |
21 | ||
22 | pub async fn connection_wrapper( | |
23 | mut socket: WebSocketStream<TcpStream>, | |
24 | listeners: Arc<Mutex<Listeners>>, | |
25 | connections: Arc<Mutex<Connections>>, | |
26 | repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>, | |
27 | user_backend: Arc<Mutex<dyn UserBackend + Send>>, | |
28 | auth_granter: Arc<Mutex<AuthenticationTokenGranter>>, | |
29 | discovery_backend: Arc<Mutex<dyn DiscoveryBackend + Send>>, | |
30 | addr: SocketAddr, | |
31 | ) { | |
32 | let mut handshaked = false; | |
33 | loop { | |
34 | if let Err(e) = connection_worker( | |
35 | &mut socket, | |
36 | &mut handshaked, | |
37 | &listeners, | |
38 | &connections, | |
39 | &repository_backend, | |
40 | &user_backend, | |
41 | &auth_granter, | |
42 | &discovery_backend, | |
43 | &addr, | |
44 | ) | |
45 | .await | |
46 | { | |
47 | error!("Error handling message: {:?}", e); | |
48 | ||
49 | if let ConnectionError::Shutdown = &e { | |
50 | info!("Closing connection {}", addr); | |
51 | return; | |
52 | } | |
53 | } | |
54 | } | |
55 | } | |
56 | ||
57 | #[derive(Clone)] | |
58 | pub struct ConnectionState { | |
59 | socket: Arc<Mutex<WebSocketStream<TcpStream>>>, | |
60 | pub listeners: Arc<Mutex<Listeners>>, | |
61 | pub connections: Arc<Mutex<Connections>>, | |
62 | pub repository_backend: Arc<Mutex<dyn RepositoryBackend + Send>>, | |
63 | pub user_backend: Arc<Mutex<dyn UserBackend + Send>>, | |
64 | pub auth_granter: Arc<Mutex<AuthenticationTokenGranter>>, | |
65 | pub discovery_backend: Arc<Mutex<dyn DiscoveryBackend + Send>>, | |
66 | pub addr: SocketAddr, | |
67 | pub instance: Instance, | |
68 | pub handshaked: Arc<AtomicBool>, | |
69 | } | |
70 | ||
71 | impl ConnectionState { | |
72 | pub async fn send<T: Serialize>(&self, message: T) -> Result<(), Error> { | |
73 | self.socket | |
74 | .lock() | |
75 | .await | |
76 | .send(Message::Binary(serde_json::to_vec(&message)?)) | |
77 | .await?; | |
78 | ||
79 | Ok(()) | |
80 | } | |
81 | } |
src/handshake.rs
@@ -6,7 +6,6 @@ use crate::model::instance::Instance; | ||
6 | 6 | /// Sent by the initiator of a new inter-daemon connection. |
7 | 7 | #[derive(Clone, Serialize, Deserialize)] |
8 | 8 | pub struct InitiateHandshake { |
9 | pub identity: Instance, | |
10 | 9 | pub version: Version, |
11 | 10 | } |
12 | 11 |
src/lib.rs
@@ -10,6 +10,8 @@ pub mod listener; | ||
10 | 10 | pub mod messages; |
11 | 11 | pub mod model; |
12 | 12 | |
13 | pub(crate) use std::error::Error as StdError; | |
14 | ||
13 | 15 | #[macro_use] |
14 | 16 | extern crate tracing; |
15 | 17 |
src/main.rs
@@ -6,7 +6,8 @@ use giterated_daemon::{ | ||
6 | 6 | discovery::GiteratedDiscoveryProtocol, git::GitBackend, user::UserAuth, DiscoveryBackend, |
7 | 7 | RepositoryBackend, UserBackend, |
8 | 8 | }, |
9 | connection, listener, | |
9 | connection::{self, wrapper::connection_wrapper}, | |
10 | listener, | |
10 | 11 | model::instance::Instance, |
11 | 12 | }; |
12 | 13 | use listener::Listeners; |
@@ -108,7 +109,7 @@ async fn main() -> Result<(), Error> { | ||
108 | 109 | info!("Websocket connection established with {}", address); |
109 | 110 | |
110 | 111 | let connection = RawConnection { |
111 | task: tokio::spawn(connection_worker( | |
112 | task: tokio::spawn(connection_wrapper( | |
112 | 113 | connection, |
113 | 114 | listeners.clone(), |
114 | 115 | connections.clone(), |
src/messages/authentication.rs
@@ -1,5 +1,7 @@ | ||
1 | 1 | use serde::{Deserialize, Serialize}; |
2 | 2 | |
3 | use crate::model::authenticated::UserAuthenticationToken; | |
4 | ||
3 | 5 | use super::InstanceAuthenticated; |
4 | 6 | |
5 | 7 | /// An authentication message. |
@@ -19,7 +21,7 @@ pub enum AuthenticationRequest { | ||
19 | 21 | /// # Authentication |
20 | 22 | /// - Instance Authentication |
21 | 23 | /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** |
22 | RegisterAccount(InstanceAuthenticated<RegisterAccountRequest>), | |
24 | RegisterAccount(RegisterAccountRequest), | |
23 | 25 | |
24 | 26 | /// An authentication token request. |
25 | 27 | /// |
@@ -27,24 +29,24 @@ pub enum AuthenticationRequest { | ||
27 | 29 | /// |
28 | 30 | /// # Authentication |
29 | 31 | /// - Instance Authentication |
30 | /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** | |
31 | 32 | /// - Identifies the Instance to issue the token for |
32 | 33 | /// # Authorization |
33 | 34 | /// - Credentials ([`crate::backend::AuthBackend`]-based) |
34 | 35 | /// - Identifies the User account to issue a token for |
35 | 36 | /// - Decrypts user private key to issue to |
36 | AuthenticationToken(InstanceAuthenticated<AuthenticationTokenRequest>), | |
37 | AuthenticationToken(AuthenticationTokenRequest), | |
37 | 38 | |
38 | 39 | /// An authentication token extension request. |
39 | 40 | /// |
40 | 41 | /// # Authentication |
41 | 42 | /// - Instance Authentication |
42 | /// - **ONLY ACCEPTED WHEN SAME-INSTANCE** | |
43 | 43 | /// - Identifies the Instance to issue the token for |
44 | /// - User Authentication | |
45 | /// - Authenticates the validity of the token | |
44 | 46 | /// # Authorization |
45 | 47 | /// - Token-based |
46 | 48 | /// - Validates authorization using token's authenticity |
47 | TokenExtension(InstanceAuthenticated<TokenExtensionRequest>), | |
49 | TokenExtension(TokenExtensionRequest), | |
48 | 50 | } |
49 | 51 | |
50 | 52 | #[derive(Clone, Serialize, Deserialize)] |
@@ -70,7 +72,6 @@ pub struct RegisterAccountResponse { | ||
70 | 72 | /// See [`AuthenticationRequest::AuthenticationToken`]'s documentation. |
71 | 73 | #[derive(Clone, Serialize, Deserialize)] |
72 | 74 | pub struct AuthenticationTokenRequest { |
73 | pub secret_key: String, | |
74 | 75 | pub username: String, |
75 | 76 | pub password: String, |
76 | 77 | } |
@@ -83,8 +84,7 @@ pub struct AuthenticationTokenResponse { | ||
83 | 84 | /// See [`AuthenticationRequest::TokenExtension`]'s documentation. |
84 | 85 | #[derive(Clone, Serialize, Deserialize)] |
85 | 86 | pub struct TokenExtensionRequest { |
86 | pub secret_key: String, | |
87 | pub token: String, | |
87 | pub token: UserAuthenticationToken, | |
88 | 88 | } |
89 | 89 | |
90 | 90 | #[derive(Clone, Serialize, Deserialize)] |
src/messages/mod.rs
@@ -34,6 +34,15 @@ pub enum MessageKind { | ||
34 | 34 | Authentication(AuthenticationMessage), |
35 | 35 | Discovery(DiscoveryMessage), |
36 | 36 | User(UserMessage), |
37 | Error(ErrorMessage), | |
38 | } | |
39 | ||
40 | #[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)] | |
41 | pub enum ErrorMessage { | |
42 | #[error("user {0} doesn't exist or isn't valid in this context")] | |
43 | InvalidUser(User), | |
44 | #[error("internal error: shutdown")] | |
45 | Shutdown, | |
37 | 46 | } |
38 | 47 | |
39 | 48 | /// An authenticated message, where the instance is authenticating |
src/model/authenticated.rs
@@ -0,0 +1,335 @@ | ||
1 | use std::{any::type_name, ops::Deref, pin::Pin, str::FromStr}; | |
2 | ||
3 | use anyhow::Error; | |
4 | use futures_util::{future::BoxFuture, Future, FutureExt}; | |
5 | use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; | |
6 | use rsa::{pkcs1::DecodeRsaPublicKey, RsaPublicKey}; | |
7 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; | |
8 | ||
9 | use crate::{ | |
10 | authentication::UserTokenMetadata, connection::wrapper::ConnectionState, messages::MessageKind, | |
11 | }; | |
12 | ||
13 | use super::{instance::Instance, user::User}; | |
14 | ||
15 | #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] | |
16 | pub struct Authenticated<T: Serialize> { | |
17 | #[serde(flatten)] | |
18 | source: Vec<AuthenticationSource>, | |
19 | message_type: String, | |
20 | #[serde(flatten)] | |
21 | message: T, | |
22 | } | |
23 | ||
24 | pub trait AuthenticationSourceProvider: Sized { | |
25 | fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource; | |
26 | } | |
27 | ||
28 | pub trait AuthenticationSourceProviders: Sized { | |
29 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>; | |
30 | } | |
31 | ||
32 | impl<A> AuthenticationSourceProviders for A | |
33 | where | |
34 | A: AuthenticationSourceProvider, | |
35 | { | |
36 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
37 | vec![self.authenticate(payload)] | |
38 | } | |
39 | } | |
40 | ||
41 | impl<A, B> AuthenticationSourceProviders for (A, B) | |
42 | where | |
43 | A: AuthenticationSourceProvider, | |
44 | B: AuthenticationSourceProvider, | |
45 | { | |
46 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
47 | let (first, second) = self; | |
48 | ||
49 | vec![first.authenticate(payload), second.authenticate(payload)] | |
50 | } | |
51 | } | |
52 | ||
53 | impl<T: Serialize> Authenticated<T> { | |
54 | pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self { | |
55 | let message_payload = serde_json::to_vec(&message).unwrap(); | |
56 | ||
57 | let authentication = auth_sources.authenticate_all(&message_payload); | |
58 | ||
59 | Self { | |
60 | source: authentication, | |
61 | message_type: type_name::<T>().to_string(), | |
62 | message, | |
63 | } | |
64 | } | |
65 | } | |
66 | ||
67 | mod verified {} | |
68 | ||
69 | pub struct UserAuthenticator { | |
70 | pub user: User, | |
71 | pub token: UserAuthenticationToken, | |
72 | } | |
73 | ||
74 | impl AuthenticationSourceProvider for UserAuthenticator { | |
75 | fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource { | |
76 | AuthenticationSource::User { | |
77 | user: self.user, | |
78 | token: self.token, | |
79 | } | |
80 | } | |
81 | } | |
82 | ||
83 | pub struct InstanceAuthenticator<'a> { | |
84 | pub instance: Instance, | |
85 | pub private_key: &'a str, | |
86 | } | |
87 | ||
88 | impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { | |
89 | fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource { | |
90 | todo!() | |
91 | } | |
92 | } | |
93 | ||
94 | #[repr(transparent)] | |
95 | #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] | |
96 | pub struct UserAuthenticationToken(String); | |
97 | ||
98 | impl From<String> for UserAuthenticationToken { | |
99 | fn from(value: String) -> Self { | |
100 | Self(value) | |
101 | } | |
102 | } | |
103 | ||
104 | impl ToString for UserAuthenticationToken { | |
105 | fn to_string(&self) -> String { | |
106 | self.0.clone() | |
107 | } | |
108 | } | |
109 | ||
110 | impl AsRef<str> for UserAuthenticationToken { | |
111 | fn as_ref(&self) -> &str { | |
112 | &self.0 | |
113 | } | |
114 | } | |
115 | ||
116 | #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] | |
117 | pub struct InstanceSignature(Vec<u8>); | |
118 | ||
119 | #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] | |
120 | pub enum AuthenticationSource { | |
121 | User { | |
122 | user: User, | |
123 | token: UserAuthenticationToken, | |
124 | }, | |
125 | Instance { | |
126 | instance: Instance, | |
127 | signature: InstanceSignature, | |
128 | }, | |
129 | } | |
130 | ||
131 | pub struct NetworkMessage(pub Vec<u8>); | |
132 | ||
133 | impl Deref for NetworkMessage { | |
134 | type Target = [u8]; | |
135 | ||
136 | fn deref(&self) -> &Self::Target { | |
137 | &self.0 | |
138 | } | |
139 | } | |
140 | ||
141 | pub struct AuthenticatedUser(pub User); | |
142 | ||
143 | #[derive(Debug, thiserror::Error)] | |
144 | pub enum UserAuthenticationError { | |
145 | #[error("user authentication missing")] | |
146 | Missing, | |
147 | // #[error("{0}")] | |
148 | // InstanceAuthentication(#[from] Error), | |
149 | #[error("user token was invalid")] | |
150 | InvalidToken, | |
151 | #[error("an error has occured")] | |
152 | Other(#[from] Error), | |
153 | } | |
154 | ||
155 | pub struct AuthenticatedInstance(Instance); | |
156 | ||
157 | impl AuthenticatedInstance { | |
158 | pub fn inner(&self) -> &Instance { | |
159 | &self.0 | |
160 | } | |
161 | } | |
162 | ||
163 | #[async_trait::async_trait] | |
164 | pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync { | |
165 | async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>; | |
166 | } | |
167 | ||
168 | #[async_trait::async_trait] | |
169 | impl FromMessage<ConnectionState> for AuthenticatedUser { | |
170 | async fn from_message( | |
171 | network_message: &NetworkMessage, | |
172 | state: &ConnectionState, | |
173 | ) -> Result<Self, Error> { | |
174 | let message: Authenticated<MessageKind> = | |
175 | serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; | |
176 | ||
177 | let (auth_user, auth_token) = message | |
178 | .source | |
179 | .iter() | |
180 | .filter_map(|auth| { | |
181 | if let AuthenticationSource::User { user, token } = auth { | |
182 | Some((user, token)) | |
183 | } else { | |
184 | None | |
185 | } | |
186 | }) | |
187 | .next() | |
188 | .ok_or_else(|| UserAuthenticationError::Missing)?; | |
189 | ||
190 | let authenticated_instance = | |
191 | AuthenticatedInstance::from_message(network_message, state).await?; | |
192 | ||
193 | let public_key_raw = public_key(&auth_user.instance).await?; | |
194 | let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap(); | |
195 | ||
196 | let data: TokenData<UserTokenMetadata> = decode( | |
197 | auth_token.as_ref(), | |
198 | &verification_key, | |
199 | &Validation::new(Algorithm::RS256), | |
200 | ) | |
201 | .unwrap(); | |
202 | ||
203 | if data.claims.user != *auth_user | |
204 | || data.claims.generated_for != *authenticated_instance.inner() | |
205 | { | |
206 | Err(Error::from(UserAuthenticationError::InvalidToken)) | |
207 | } else { | |
208 | Ok(AuthenticatedUser(data.claims.user)) | |
209 | } | |
210 | } | |
211 | } | |
212 | ||
213 | #[async_trait::async_trait] | |
214 | impl FromMessage<ConnectionState> for AuthenticatedInstance { | |
215 | async fn from_message( | |
216 | message: &NetworkMessage, | |
217 | state: &ConnectionState, | |
218 | ) -> Result<Self, Error> { | |
219 | todo!() | |
220 | } | |
221 | } | |
222 | ||
223 | #[async_trait::async_trait] | |
224 | impl FromMessage<ConnectionState> for MessageKind { | |
225 | async fn from_message( | |
226 | message: &NetworkMessage, | |
227 | state: &ConnectionState, | |
228 | ) -> Result<Self, Error> { | |
229 | todo!() | |
230 | } | |
231 | } | |
232 | ||
233 | #[async_trait::async_trait] | |
234 | impl<S, T> FromMessage<S> for Option<T> | |
235 | where | |
236 | T: FromMessage<S>, | |
237 | S: Send + Sync + 'static, | |
238 | { | |
239 | async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> { | |
240 | Ok(T::from_message(message, state).await.ok()) | |
241 | } | |
242 | } | |
243 | ||
244 | #[async_trait::async_trait] | |
245 | pub trait MessageHandler<T, S, R> { | |
246 | async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>; | |
247 | } | |
248 | #[async_trait::async_trait] | |
249 | impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T | |
250 | where | |
251 | T: FnOnce(T1) -> F + Clone + Send + 'static, | |
252 | F: Future<Output = Result<R, E>> + Send, | |
253 | T1: FromMessage<S> + Send, | |
254 | S: Send + Sync, | |
255 | E: std::error::Error + Send + Sync + 'static, | |
256 | { | |
257 | async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> { | |
258 | let value = T1::from_message(message, state).await?; | |
259 | self(value).await.map_err(|e| Error::from(e)) | |
260 | } | |
261 | } | |
262 | ||
263 | #[async_trait::async_trait] | |
264 | impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T | |
265 | where | |
266 | T: FnOnce(T1, T2) -> F + Clone + Send + 'static, | |
267 | F: Future<Output = Result<R, E>> + Send, | |
268 | T1: FromMessage<S> + Send, | |
269 | T2: FromMessage<S> + Send, | |
270 | S: Send + Sync, | |
271 | E: std::error::Error + Send + Sync + 'static, | |
272 | { | |
273 | async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> { | |
274 | let value = T1::from_message(message, state).await?; | |
275 | let value_2 = T2::from_message(message, state).await?; | |
276 | self(value, value_2).await.map_err(|e| Error::from(e)) | |
277 | } | |
278 | } | |
279 | ||
280 | #[async_trait::async_trait] | |
281 | impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T | |
282 | where | |
283 | T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static, | |
284 | F: Future<Output = Result<R, E>> + Send, | |
285 | T1: FromMessage<S> + Send, | |
286 | T2: FromMessage<S> + Send, | |
287 | T3: FromMessage<S> + Send, | |
288 | S: Send + Sync, | |
289 | E: std::error::Error + Send + Sync + 'static, | |
290 | { | |
291 | async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> { | |
292 | let value = T1::from_message(message, state).await?; | |
293 | let value_2 = T2::from_message(message, state).await?; | |
294 | let value_3 = T3::from_message(message, state).await?; | |
295 | ||
296 | self(value, value_2, value_3) | |
297 | .await | |
298 | .map_err(|e| Error::from(e)) | |
299 | } | |
300 | } | |
301 | ||
302 | pub struct State<T>(pub T); | |
303 | ||
304 | #[async_trait::async_trait] | |
305 | impl<T> FromMessage<T> for State<T> | |
306 | where | |
307 | T: Clone + Send + Sync, | |
308 | { | |
309 | async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> { | |
310 | Ok(Self(state.clone())) | |
311 | } | |
312 | } | |
313 | ||
314 | // Temp | |
315 | #[async_trait::async_trait] | |
316 | impl<T, S> FromMessage<S> for Message<T> | |
317 | where | |
318 | T: DeserializeOwned + Send + Sync + Serialize, | |
319 | S: Clone + Send + Sync, | |
320 | { | |
321 | async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> { | |
322 | Ok(Message(serde_json::from_slice(&message)?)) | |
323 | } | |
324 | } | |
325 | ||
326 | pub struct Message<T: Serialize + DeserializeOwned>(pub T); | |
327 | ||
328 | async fn public_key(instance: &Instance) -> Result<String, Error> { | |
329 | let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url)) | |
330 | .await? | |
331 | .text() | |
332 | .await?; | |
333 | ||
334 | Ok(key) | |
335 | } |