Fix authenticated endpoints
parent: tbd commit: 1400b06
Showing 8 changed files with 124 insertions and 68 deletions
Cargo.lock
@@ -163,6 +163,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" | ||
163 | 163 | checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" |
164 | 164 | |
165 | 165 | [[package]] |
166 | name = "bincode" | |
167 | version = "1.3.3" | |
168 | source = "registry+https://github.com/rust-lang/crates.io-index" | |
169 | checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" | |
170 | dependencies = [ | |
171 | "serde", | |
172 | ] | |
173 | ||
174 | [[package]] | |
166 | 175 | name = "bitflags" |
167 | 176 | version = "1.3.2" |
168 | 177 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -678,6 +687,7 @@ dependencies = [ | ||
678 | 687 | "argon2", |
679 | 688 | "async-trait", |
680 | 689 | "base64 0.21.3", |
690 | "bincode", | |
681 | 691 | "chrono", |
682 | 692 | "deadpool", |
683 | 693 | "futures-util", |
@@ -711,6 +721,7 @@ dependencies = [ | ||
711 | 721 | "argon2", |
712 | 722 | "async-trait", |
713 | 723 | "base64 0.21.3", |
724 | "bincode", | |
714 | 725 | "chrono", |
715 | 726 | "futures-util", |
716 | 727 | "git2", |
giterated-daemon/Cargo.toml
@@ -26,6 +26,7 @@ tower = "*" | ||
26 | 26 | giterated-models = { path = "../giterated-models" } |
27 | 27 | giterated-api = { path = "../../giterated-api" } |
28 | 28 | deadpool = "*" |
29 | bincode = "*" | |
29 | 30 | |
30 | 31 | toml = { version = "0.7" } |
31 | 32 |
giterated-daemon/src/connection/forwarded.rs
@@ -1,12 +1,12 @@ | ||
1 | 1 | use futures_util::{SinkExt, StreamExt}; |
2 | 2 | use giterated_api::DaemonConnectionPool; |
3 | use giterated_models::{messages::error::ConnectionError, model::authenticated::Authenticated}; | |
3 | use giterated_models::{messages::error::ConnectionError, model::authenticated::{Authenticated, AuthenticatedPayload}}; | |
4 | 4 | use serde::Serialize; |
5 | 5 | use tokio_tungstenite::tungstenite::Message; |
6 | 6 | |
7 | pub async fn wrap_forwarded<T: Serialize>( | |
7 | pub async fn wrap_forwarded( | |
8 | 8 | pool: &DaemonConnectionPool, |
9 | message: Authenticated<T>, | |
9 | message: AuthenticatedPayload, | |
10 | 10 | ) -> Message { |
11 | 11 | let connection = pool.get().await; |
12 | 12 |
giterated-daemon/src/connection/handshake.rs
@@ -8,7 +8,7 @@ use semver::Version; | ||
8 | 8 | |
9 | 9 | use crate::{ |
10 | 10 | connection::ConnectionError, |
11 | message::{Message, MessageHandler, NetworkMessage, State}, | |
11 | message::{Message, MessageHandler, NetworkMessage, State, HandshakeMessage}, | |
12 | 12 | validate_version, version, |
13 | 13 | }; |
14 | 14 | |
@@ -42,9 +42,10 @@ pub async fn handshake_handle( | ||
42 | 42 | } |
43 | 43 | |
44 | 44 | async fn initiate_handshake( |
45 | Message(initiation): Message<InitiateHandshake>, | |
45 | HandshakeMessage(initiation): HandshakeMessage<InitiateHandshake>, | |
46 | 46 | State(connection_state): State<ConnectionState>, |
47 | 47 | ) -> Result<(), HandshakeError> { |
48 | info!("meow!"); | |
48 | 49 | connection_state |
49 | 50 | .send(HandshakeResponse { |
50 | 51 | identity: connection_state.instance.clone(), |
@@ -81,7 +82,7 @@ async fn initiate_handshake( | ||
81 | 82 | } |
82 | 83 | |
83 | 84 | async fn handshake_response( |
84 | Message(response): Message<HandshakeResponse>, | |
85 | HandshakeMessage(initiation): HandshakeMessage<HandshakeResponse>, | |
85 | 86 | State(connection_state): State<ConnectionState>, |
86 | 87 | ) -> Result<(), HandshakeError> { |
87 | 88 | connection_state |
@@ -114,7 +115,7 @@ async fn handshake_response( | ||
114 | 115 | } |
115 | 116 | |
116 | 117 | async fn handshake_finalize( |
117 | Message(finalize): Message<HandshakeFinalize>, | |
118 | HandshakeMessage(finalize): HandshakeMessage<HandshakeFinalize>, | |
118 | 119 | State(connection_state): State<ConnectionState>, |
119 | 120 | ) -> Result<(), HandshakeError> { |
120 | 121 | connection_state.handshaked.store(true, Ordering::SeqCst); |
giterated-daemon/src/connection/wrapper.rs
@@ -11,7 +11,7 @@ use anyhow::Error; | ||
11 | 11 | use futures_util::{SinkExt, StreamExt}; |
12 | 12 | use giterated_models::{ |
13 | 13 | messages::error::ConnectionError, |
14 | model::{authenticated::Authenticated, instance::Instance}, | |
14 | model::{authenticated::{Authenticated, AuthenticatedPayload}, instance::Instance}, | |
15 | 15 | }; |
16 | 16 | use rsa::RsaPublicKey; |
17 | 17 | use serde::Serialize; |
@@ -86,13 +86,14 @@ pub async fn connection_wrapper( | ||
86 | 86 | let message = NetworkMessage(payload.clone()); |
87 | 87 | |
88 | 88 | if !handshaked { |
89 | info!("im foo baring"); | |
89 | 90 | if handshake_handle(&message, &connection_state).await.is_ok() { |
90 | 91 | if connection_state.handshaked.load(Ordering::SeqCst) { |
91 | 92 | handshaked = true; |
92 | 93 | } |
93 | 94 | } |
94 | 95 | } else { |
95 | let raw = serde_json::from_slice::<Authenticated<Value>>(&payload).unwrap(); | |
96 | let raw = serde_json::from_slice::<AuthenticatedPayload>(&payload).unwrap(); | |
96 | 97 | |
97 | 98 | if let Some(target_instance) = &raw.target_instance { |
98 | 99 | // Forward request |
@@ -116,7 +117,9 @@ pub async fn connection_wrapper( | ||
116 | 117 | |
117 | 118 | match authentication_handle(message_type, &message, &connection_state).await { |
118 | 119 | Err(e) => { |
119 | let _ = connection_state.send(ConnectionError(e.to_string())).await; | |
120 | let _ = connection_state | |
121 | .send_raw(ConnectionError(e.to_string())) | |
122 | .await; | |
120 | 123 | } |
121 | 124 | Ok(true) => continue, |
122 | 125 | Ok(false) => {} |
@@ -124,7 +127,9 @@ pub async fn connection_wrapper( | ||
124 | 127 | |
125 | 128 | match repository_handle(message_type, &message, &connection_state).await { |
126 | 129 | Err(e) => { |
127 | let _ = connection_state.send(ConnectionError(e.to_string())).await; | |
130 | let _ = connection_state | |
131 | .send_raw(ConnectionError(e.to_string())) | |
132 | .await; | |
128 | 133 | } |
129 | 134 | Ok(true) => continue, |
130 | 135 | Ok(false) => {} |
@@ -132,7 +137,9 @@ pub async fn connection_wrapper( | ||
132 | 137 | |
133 | 138 | match user_handle(message_type, &message, &connection_state).await { |
134 | 139 | Err(e) => { |
135 | let _ = connection_state.send(ConnectionError(e.to_string())).await; | |
140 | let _ = connection_state | |
141 | .send_raw(ConnectionError(e.to_string())) | |
142 | .await; | |
136 | 143 | } |
137 | 144 | Ok(true) => continue, |
138 | 145 | Ok(false) => {} |
@@ -140,7 +147,9 @@ pub async fn connection_wrapper( | ||
140 | 147 | |
141 | 148 | match authentication_handle(message_type, &message, &connection_state).await { |
142 | 149 | Err(e) => { |
143 | let _ = connection_state.send(ConnectionError(e.to_string())).await; | |
150 | let _ = connection_state | |
151 | .send_raw(ConnectionError(e.to_string())) | |
152 | .await; | |
144 | 153 | } |
145 | 154 | Ok(true) => continue, |
146 | 155 | Ok(false) => {} |
@@ -189,4 +198,16 @@ impl ConnectionState { | ||
189 | 198 | |
190 | 199 | Ok(()) |
191 | 200 | } |
201 | ||
202 | pub async fn send_raw<T: Serialize>(&self, message: T) -> Result<(), Error> { | |
203 | let payload = serde_json::to_string(&message)?; | |
204 | info!("Sending payload: {}", &payload); | |
205 | self.socket | |
206 | .lock() | |
207 | .await | |
208 | .send(Message::Binary(payload.into_bytes())) | |
209 | .await?; | |
210 | ||
211 | Ok(()) | |
212 | } | |
192 | 213 | } |
giterated-daemon/src/message.rs
@@ -3,7 +3,7 @@ use std::{collections::HashMap, ops::Deref}; | ||
3 | 3 | use anyhow::Error; |
4 | 4 | use futures_util::Future; |
5 | 5 | use giterated_models::model::{ |
6 | authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata}, | |
6 | authenticated::{Authenticated, AuthenticationSource, UserTokenMetadata, AuthenticatedPayload}, | |
7 | 7 | instance::Instance, |
8 | 8 | user::User, |
9 | 9 | }; |
@@ -63,7 +63,7 @@ impl FromMessage<ConnectionState> for AuthenticatedUser { | ||
63 | 63 | network_message: &NetworkMessage, |
64 | 64 | state: &ConnectionState, |
65 | 65 | ) -> Result<Self, Error> { |
66 | let message: Authenticated<HashMap<String, Value>> = | |
66 | let message: AuthenticatedPayload = | |
67 | 67 | serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; |
68 | 68 | |
69 | 69 | let (auth_user, auth_token) = message |
@@ -108,9 +108,9 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance { | ||
108 | 108 | network_message: &NetworkMessage, |
109 | 109 | state: &ConnectionState, |
110 | 110 | ) -> Result<Self, Error> { |
111 | let message: Authenticated<HashMap<String, Value>> = | |
111 | let message: AuthenticatedPayload = | |
112 | 112 | serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?; |
113 | ||
113 | ||
114 | 114 | let (instance, signature) = message |
115 | 115 | .source |
116 | 116 | .iter() |
@@ -137,7 +137,7 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance { | ||
137 | 137 | } else { |
138 | 138 | drop(cached_keys); |
139 | 139 | let mut cached_keys = state.cached_keys.write().await; |
140 | let key = public_key(instance).await?; | |
140 | let key = public_key(&instance).await?; | |
141 | 141 | let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap(); |
142 | 142 | cached_keys.insert(instance.clone(), public_key.clone()); |
143 | 143 | public_key |
@@ -146,15 +146,8 @@ impl FromMessage<ConnectionState> for AuthenticatedInstance { | ||
146 | 146 | |
147 | 147 | let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key); |
148 | 148 | |
149 | let message_json = serde_json::to_vec(&message.message).unwrap(); | |
150 | ||
151 | info!( | |
152 | "Verification against: {}", | |
153 | std::str::from_utf8(&message_json).unwrap() | |
154 | ); | |
155 | ||
156 | 149 | verifying_key.verify( |
157 | &message_json, | |
150 | &message.payload, | |
158 | 151 | &Signature::try_from(signature.as_ref()).unwrap(), |
159 | 152 | )?; |
160 | 153 | |
@@ -251,7 +244,8 @@ where | ||
251 | 244 | S: Clone + Send + Sync, |
252 | 245 | { |
253 | 246 | async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> { |
254 | Ok(Message(serde_json::from_slice(&message)?)) | |
247 | let payload: AuthenticatedPayload = serde_json::from_slice(&message)?; | |
248 | Ok(Message(bincode::deserialize(&payload.payload)?)) | |
255 | 249 | } |
256 | 250 | } |
257 | 251 | |
@@ -265,3 +259,20 @@ async fn public_key(instance: &Instance) -> Result<String, Error> { | ||
265 | 259 | |
266 | 260 | Ok(key) |
267 | 261 | } |
262 | ||
263 | /// Handshake-specific message type. | |
264 | /// | |
265 | /// Uses basic serde_json-based deserialization to maintain the highest | |
266 | /// level of compatibility across versions. | |
267 | pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T); | |
268 | ||
269 | #[async_trait::async_trait] | |
270 | impl<T, S> FromMessage<S> for HandshakeMessage<T> | |
271 | where | |
272 | T: DeserializeOwned + Send + Sync + Serialize, | |
273 | S: Clone + Send + Sync, | |
274 | { | |
275 | async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> { | |
276 | Ok(HandshakeMessage(serde_json::from_slice(&message.0)?)) | |
277 | } | |
278 | } | |
278 | \ No newline at end of file |
giterated-models/Cargo.toml
@@ -23,6 +23,7 @@ argon2 = "*" | ||
23 | 23 | aes-gcm = "0.10.2" |
24 | 24 | semver = {version = "*", features = ["serde"]} |
25 | 25 | tower = "*" |
26 | bincode = "*" | |
26 | 27 | |
27 | 28 | toml = { version = "0.7" } |
28 | 29 |
giterated-models/src/model/authenticated.rs
@@ -1,4 +1,4 @@ | ||
1 | use std::any::type_name; | |
1 | use std::{any::type_name, fmt::Debug}; | |
2 | 2 | |
3 | 3 | use rsa::{ |
4 | 4 | pkcs1::DecodeRsaPrivateKey, |
@@ -19,28 +19,52 @@ pub struct UserTokenMetadata { | ||
19 | 19 | pub exp: u64, |
20 | 20 | } |
21 | 21 | |
22 | #[derive(Debug)] | |
23 | pub struct Authenticated<'a, T: Serialize> { | |
24 | pub target_instance: Option<Instance>, | |
25 | pub source: Vec<&'a dyn AuthenticationSourceProvider>, | |
26 | pub message_type: String, | |
27 | pub message: T, | |
28 | } | |
29 | ||
22 | 30 | #[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] |
23 | pub struct Authenticated<T: Serialize> { | |
31 | pub struct AuthenticatedPayload { | |
24 | 32 | pub target_instance: Option<Instance>, |
25 | 33 | pub source: Vec<AuthenticationSource>, |
26 | 34 | pub message_type: String, |
27 | #[serde(flatten)] | |
28 | pub message: T, | |
35 | pub payload: Vec<u8>, | |
36 | } | |
37 | ||
38 | impl<'a, T: Serialize> From<Authenticated<'a, T>> for AuthenticatedPayload { | |
39 | fn from(mut value: Authenticated<'a, T>) -> Self { | |
40 | let payload = bincode::serialize(&value.message).unwrap(); | |
41 | ||
42 | AuthenticatedPayload { | |
43 | target_instance: value.target_instance, | |
44 | source: value | |
45 | .source | |
46 | .drain(..) | |
47 | .map(|provider| provider.authenticate(&payload)) | |
48 | .collect::<Vec<_>>(), | |
49 | message_type: value.message_type, | |
50 | payload, | |
51 | } | |
52 | } | |
29 | 53 | } |
30 | 54 | |
31 | pub trait AuthenticationSourceProvider: Sized { | |
32 | fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource; | |
55 | pub trait AuthenticationSourceProvider: Debug { | |
56 | fn authenticate(&self, payload: &Vec<u8>) -> AuthenticationSource; | |
33 | 57 | } |
34 | 58 | |
35 | pub trait AuthenticationSourceProviders: Sized { | |
36 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>; | |
59 | pub trait AuthenticationSourceProviders: Debug { | |
60 | fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource>; | |
37 | 61 | } |
38 | 62 | |
39 | 63 | impl<A> AuthenticationSourceProviders for A |
40 | 64 | where |
41 | 65 | A: AuthenticationSourceProvider, |
42 | 66 | { |
43 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
67 | fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
44 | 68 | vec![self.authenticate(payload)] |
45 | 69 | } |
46 | 70 | } |
@@ -50,43 +74,26 @@ where | ||
50 | 74 | A: AuthenticationSourceProvider, |
51 | 75 | B: AuthenticationSourceProvider, |
52 | 76 | { |
53 | fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
77 | fn authenticate_all(&self, payload: &Vec<u8>) -> Vec<AuthenticationSource> { | |
54 | 78 | let (first, second) = self; |
55 | 79 | |
56 | 80 | vec![first.authenticate(payload), second.authenticate(payload)] |
57 | 81 | } |
58 | 82 | } |
59 | 83 | |
60 | impl<T: Serialize> Authenticated<T> { | |
61 | pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self { | |
62 | let message_payload = serde_json::to_vec(&message).unwrap(); | |
63 | ||
64 | let authentication = auth_sources.authenticate_all(&message_payload); | |
65 | ||
84 | impl<'a, T: Serialize + Debug> Authenticated<'a, T> { | |
85 | pub fn new(message: T) -> Self { | |
66 | 86 | Self { |
67 | source: authentication, | |
87 | source: vec![], | |
68 | 88 | message_type: type_name::<T>().to_string(), |
69 | 89 | message, |
70 | 90 | target_instance: None, |
71 | 91 | } |
72 | 92 | } |
73 | 93 | |
74 | pub fn new_for( | |
75 | instance: impl ToOwned<Owned = Instance>, | |
76 | message: T, | |
77 | auth_sources: impl AuthenticationSourceProvider, | |
78 | ) -> Self { | |
79 | let message_payload = serde_json::to_vec(&message).unwrap(); | |
80 | ||
81 | info!( | |
82 | "Verifying payload: {}", | |
83 | std::str::from_utf8(&message_payload).unwrap() | |
84 | ); | |
85 | ||
86 | let authentication = auth_sources.authenticate_all(&message_payload); | |
87 | ||
94 | pub fn new_for(instance: impl ToOwned<Owned = Instance>, message: T) -> Self { | |
88 | 95 | Self { |
89 | source: authentication, | |
96 | source: vec![], | |
90 | 97 | message_type: type_name::<T>().to_string(), |
91 | 98 | message, |
92 | 99 | target_instance: Some(instance.to_owned()), |
@@ -102,7 +109,7 @@ impl<T: Serialize> Authenticated<T> { | ||
102 | 109 | } |
103 | 110 | } |
104 | 111 | |
105 | pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) { | |
112 | pub fn append_authentication(&mut self, authentication: &'a dyn AuthenticationSourceProvider) { | |
106 | 113 | let message_payload = serde_json::to_vec(&self.message).unwrap(); |
107 | 114 | |
108 | 115 | info!( |
@@ -110,8 +117,11 @@ impl<T: Serialize> Authenticated<T> { | ||
110 | 117 | std::str::from_utf8(&message_payload).unwrap() |
111 | 118 | ); |
112 | 119 | |
113 | self.source | |
114 | .push(authentication.authenticate(&message_payload)); | |
120 | self.source.push(authentication); | |
121 | } | |
122 | ||
123 | pub fn into_payload(self) -> AuthenticatedPayload { | |
124 | self.into() | |
115 | 125 | } |
116 | 126 | } |
117 | 127 | |
@@ -124,22 +134,22 @@ pub struct UserAuthenticator { | ||
124 | 134 | } |
125 | 135 | |
126 | 136 | impl AuthenticationSourceProvider for UserAuthenticator { |
127 | fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource { | |
137 | fn authenticate(&self, _payload: &Vec<u8>) -> AuthenticationSource { | |
128 | 138 | AuthenticationSource::User { |
129 | user: self.user, | |
130 | token: self.token, | |
139 | user: self.user.clone(), | |
140 | token: self.token.clone(), | |
131 | 141 | } |
132 | 142 | } |
133 | 143 | } |
134 | 144 | |
135 | #[derive(Clone)] | |
145 | #[derive(Debug, Clone)] | |
136 | 146 | pub struct InstanceAuthenticator<'a> { |
137 | 147 | pub instance: Instance, |
138 | 148 | pub private_key: &'a str, |
139 | 149 | } |
140 | 150 | |
141 | 151 | impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { |
142 | fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource { | |
152 | fn authenticate(&self, payload: &Vec<u8>) -> AuthenticationSource { | |
143 | 153 | let mut rng = rand::thread_rng(); |
144 | 154 | |
145 | 155 | let private_key = RsaPrivateKey::from_pkcs1_pem(self.private_key).unwrap(); |
@@ -147,7 +157,7 @@ impl AuthenticationSourceProvider for InstanceAuthenticator<'_> { | ||
147 | 157 | let signature = signing_key.sign_with_rng(&mut rng, &payload); |
148 | 158 | |
149 | 159 | AuthenticationSource::Instance { |
150 | instance: self.instance, | |
160 | instance: self.instance.clone(), | |
151 | 161 | // TODO: Actually parse signature from private key |
152 | 162 | signature: InstanceSignature(signature.to_bytes().into_vec()), |
153 | 163 | } |