1 |
use std::{collections::HashMap, ops::Deref};
|
2 |
|
3 |
use anyhow::Error;
|
4 |
use futures_util::Future;
|
5 |
use giterated_models::model::{
|
6 |
authenticated::{Authenticated, AuthenticatedPayload, AuthenticationSource, UserTokenMetadata},
|
7 |
instance::Instance,
|
8 |
user::User,
|
9 |
};
|
10 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
11 |
use rsa::{
|
12 |
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
|
13 |
pss::{Signature, VerifyingKey},
|
14 |
sha2::Sha256,
|
15 |
signature::Verifier,
|
16 |
RsaPublicKey,
|
17 |
};
|
18 |
use serde::{de::DeserializeOwned, Serialize};
|
19 |
use serde_json::Value;
|
20 |
|
21 |
use crate::connection::wrapper::ConnectionState;
|
22 |
|
23 |
/// Raw message bytes handed to the [`FromMessage`] extractors.
///
/// Derefs to `[u8]` so the message can be passed straight to decoders that
/// accept a byte slice.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    /// Borrows the underlying bytes.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
|
32 |
|
33 |
pub struct AuthenticatedUser(pub User);
|
34 |
|
35 |
#[derive(Debug, thiserror::Error)]
|
36 |
pub enum UserAuthenticationError {
|
37 |
#[error("user authentication missing")]
|
38 |
Missing,
|
39 |
|
40 |
|
41 |
#[error("user token was invalid")]
|
42 |
InvalidToken,
|
43 |
#[error("an error has occured")]
|
44 |
Other(#[from] Error),
|
45 |
}
|
46 |
|
47 |
pub struct AuthenticatedInstance(Instance);
|
48 |
|
49 |
impl AuthenticatedInstance {
|
50 |
pub fn inner(&self) -> &Instance {
|
51 |
&self.0
|
52 |
}
|
53 |
}
|
54 |
|
55 |
#[async_trait::async_trait]
|
56 |
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
|
57 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
|
58 |
}
|
59 |
|
60 |
#[async_trait::async_trait]
|
61 |
impl FromMessage<ConnectionState> for AuthenticatedUser {
|
62 |
async fn from_message(
|
63 |
network_message: &NetworkMessage,
|
64 |
state: &ConnectionState,
|
65 |
) -> Result<Self, Error> {
|
66 |
let message: AuthenticatedPayload =
|
67 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
68 |
|
69 |
let (auth_user, auth_token) = message
|
70 |
.source
|
71 |
.iter()
|
72 |
.filter_map(|auth| {
|
73 |
if let AuthenticationSource::User { user, token } = auth {
|
74 |
Some((user, token))
|
75 |
} else {
|
76 |
None
|
77 |
}
|
78 |
})
|
79 |
.next()
|
80 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
81 |
|
82 |
let authenticated_instance =
|
83 |
AuthenticatedInstance::from_message(network_message, state).await?;
|
84 |
|
85 |
let public_key_raw = public_key(&auth_user.instance).await?;
|
86 |
let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
|
87 |
|
88 |
let data: TokenData<UserTokenMetadata> = decode(
|
89 |
auth_token.as_ref(),
|
90 |
&verification_key,
|
91 |
&Validation::new(Algorithm::RS256),
|
92 |
)
|
93 |
.unwrap();
|
94 |
|
95 |
if data.claims.user != *auth_user
|
96 |
|| data.claims.generated_for != *authenticated_instance.inner()
|
97 |
{
|
98 |
Err(Error::from(UserAuthenticationError::InvalidToken))
|
99 |
} else {
|
100 |
Ok(AuthenticatedUser(data.claims.user))
|
101 |
}
|
102 |
}
|
103 |
}
|
104 |
|
105 |
#[async_trait::async_trait]
|
106 |
impl FromMessage<ConnectionState> for AuthenticatedInstance {
|
107 |
async fn from_message(
|
108 |
network_message: &NetworkMessage,
|
109 |
state: &ConnectionState,
|
110 |
) -> Result<Self, Error> {
|
111 |
let message: AuthenticatedPayload =
|
112 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
113 |
|
114 |
let (instance, signature) = message
|
115 |
.source
|
116 |
.iter()
|
117 |
.filter_map(|auth: &AuthenticationSource| {
|
118 |
if let AuthenticationSource::Instance {
|
119 |
instance,
|
120 |
signature,
|
121 |
} = auth
|
122 |
{
|
123 |
Some((instance, signature))
|
124 |
} else {
|
125 |
None
|
126 |
}
|
127 |
})
|
128 |
.next()
|
129 |
|
130 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
131 |
|
132 |
let public_key = {
|
133 |
let cached_keys = state.cached_keys.read().await;
|
134 |
|
135 |
if let Some(key) = cached_keys.get(&instance) {
|
136 |
key.clone()
|
137 |
} else {
|
138 |
drop(cached_keys);
|
139 |
let mut cached_keys = state.cached_keys.write().await;
|
140 |
let key = public_key(&instance).await?;
|
141 |
let public_key = RsaPublicKey::from_pkcs1_pem(&key).unwrap();
|
142 |
cached_keys.insert(instance.clone(), public_key.clone());
|
143 |
public_key
|
144 |
}
|
145 |
};
|
146 |
|
147 |
let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(public_key);
|
148 |
|
149 |
verifying_key.verify(
|
150 |
&message.payload,
|
151 |
&Signature::try_from(signature.as_ref()).unwrap(),
|
152 |
)?;
|
153 |
|
154 |
Ok(AuthenticatedInstance(instance.clone()))
|
155 |
}
|
156 |
}
|
157 |
|
158 |
#[async_trait::async_trait]
|
159 |
impl<S, T> FromMessage<S> for Option<T>
|
160 |
where
|
161 |
T: FromMessage<S>,
|
162 |
S: Send + Sync + 'static,
|
163 |
{
|
164 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
|
165 |
Ok(T::from_message(message, state).await.ok())
|
166 |
}
|
167 |
}
|
168 |
|
169 |
#[async_trait::async_trait]
|
170 |
pub trait MessageHandler<T, S, R> {
|
171 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
|
172 |
}
|
173 |
#[async_trait::async_trait]
|
174 |
impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
|
175 |
where
|
176 |
T: FnOnce(T1) -> F + Clone + Send + 'static,
|
177 |
F: Future<Output = Result<R, E>> + Send,
|
178 |
T1: FromMessage<S> + Send,
|
179 |
S: Send + Sync,
|
180 |
E: std::error::Error + Send + Sync + 'static,
|
181 |
{
|
182 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
183 |
let value = T1::from_message(message, state).await?;
|
184 |
self(value).await.map_err(|e| Error::from(e))
|
185 |
}
|
186 |
}
|
187 |
|
188 |
#[async_trait::async_trait]
|
189 |
impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
|
190 |
where
|
191 |
T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
|
192 |
F: Future<Output = Result<R, E>> + Send,
|
193 |
T1: FromMessage<S> + Send,
|
194 |
T2: FromMessage<S> + Send,
|
195 |
S: Send + Sync,
|
196 |
E: std::error::Error + Send + Sync + 'static,
|
197 |
{
|
198 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
199 |
let value = T1::from_message(message, state).await?;
|
200 |
let value_2 = T2::from_message(message, state).await?;
|
201 |
self(value, value_2).await.map_err(|e| Error::from(e))
|
202 |
}
|
203 |
}
|
204 |
|
205 |
#[async_trait::async_trait]
|
206 |
impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
|
207 |
where
|
208 |
T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
|
209 |
F: Future<Output = Result<R, E>> + Send,
|
210 |
T1: FromMessage<S> + Send,
|
211 |
T2: FromMessage<S> + Send,
|
212 |
T3: FromMessage<S> + Send,
|
213 |
S: Send + Sync,
|
214 |
E: std::error::Error + Send + Sync + 'static,
|
215 |
{
|
216 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
217 |
let value = T1::from_message(message, state).await?;
|
218 |
let value_2 = T2::from_message(message, state).await?;
|
219 |
let value_3 = T3::from_message(message, state).await?;
|
220 |
|
221 |
self(value, value_2, value_3)
|
222 |
.await
|
223 |
.map_err(|e| Error::from(e))
|
224 |
}
|
225 |
}
|
226 |
|
227 |
pub struct State<T>(pub T);
|
228 |
|
229 |
#[async_trait::async_trait]
|
230 |
impl<T> FromMessage<T> for State<T>
|
231 |
where
|
232 |
T: Clone + Send + Sync,
|
233 |
{
|
234 |
async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
|
235 |
Ok(Self(state.clone()))
|
236 |
}
|
237 |
}
|
238 |
|
239 |
|
240 |
#[async_trait::async_trait]
|
241 |
impl<T, S> FromMessage<S> for Message<T>
|
242 |
where
|
243 |
T: DeserializeOwned + Send + Sync + Serialize,
|
244 |
S: Clone + Send + Sync,
|
245 |
{
|
246 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
247 |
let payload: AuthenticatedPayload = serde_json::from_slice(&message)?;
|
248 |
Ok(Message(bincode::deserialize(&payload.payload)?))
|
249 |
}
|
250 |
}
|
251 |
|
252 |
pub struct Message<T: Serialize + DeserializeOwned>(pub T);
|
253 |
|
254 |
async fn public_key(instance: &Instance) -> Result<String, Error> {
|
255 |
let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
|
256 |
.await?
|
257 |
.text()
|
258 |
.await?;
|
259 |
|
260 |
Ok(key)
|
261 |
}
|
262 |
|
263 |
|
264 |
|
265 |
|
266 |
|
267 |
pub struct HandshakeMessage<T: Serialize + DeserializeOwned>(pub T);
|
268 |
|
269 |
#[async_trait::async_trait]
|
270 |
impl<T, S> FromMessage<S> for HandshakeMessage<T>
|
271 |
where
|
272 |
T: DeserializeOwned + Send + Sync + Serialize,
|
273 |
S: Clone + Send + Sync,
|
274 |
{
|
275 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
276 |
Ok(HandshakeMessage(serde_json::from_slice(&message.0)?))
|
277 |
}
|
278 |
}
|
279 |
|