1 |
use std::{any::type_name, collections::HashMap, ops::Deref};
|
2 |
|
3 |
use anyhow::Error;
|
4 |
use futures_util::Future;
|
5 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
6 |
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
7 |
use serde_json::Value;
|
8 |
|
9 |
use crate::{authentication::UserTokenMetadata, connection::wrapper::ConnectionState};
|
10 |
|
11 |
use super::{instance::Instance, user::User};
|
12 |
|
13 |
#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
14 |
pub struct Authenticated<T: Serialize> {
|
15 |
|
16 |
|
17 |
source: Vec<AuthenticationSource>,
|
18 |
message_type: String,
|
19 |
#[serde(flatten)]
|
20 |
message: T,
|
21 |
}
|
22 |
|
23 |
pub trait AuthenticationSourceProvider: Sized {
|
24 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
|
25 |
}
|
26 |
|
27 |
pub trait AuthenticationSourceProviders: Sized {
|
28 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
|
29 |
}
|
30 |
|
31 |
impl<A> AuthenticationSourceProviders for A
|
32 |
where
|
33 |
A: AuthenticationSourceProvider,
|
34 |
{
|
35 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
36 |
vec![self.authenticate(payload)]
|
37 |
}
|
38 |
}
|
39 |
|
40 |
impl<A, B> AuthenticationSourceProviders for (A, B)
|
41 |
where
|
42 |
A: AuthenticationSourceProvider,
|
43 |
B: AuthenticationSourceProvider,
|
44 |
{
|
45 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
46 |
let (first, second) = self;
|
47 |
|
48 |
vec![first.authenticate(payload), second.authenticate(payload)]
|
49 |
}
|
50 |
}
|
51 |
|
52 |
impl<T: Serialize> Authenticated<T> {
|
53 |
pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
|
54 |
let message_payload = serde_json::to_vec(&message).unwrap();
|
55 |
|
56 |
let authentication = auth_sources.authenticate_all(&message_payload);
|
57 |
|
58 |
Self {
|
59 |
source: authentication,
|
60 |
message_type: type_name::<T>().to_string(),
|
61 |
message,
|
62 |
}
|
63 |
}
|
64 |
|
65 |
pub fn new_empty(message: T) -> Self {
|
66 |
Self {
|
67 |
source: vec![],
|
68 |
message_type: type_name::<T>().to_string(),
|
69 |
message,
|
70 |
}
|
71 |
}
|
72 |
|
73 |
pub fn append_authentication(&mut self, authentication: impl AuthenticationSourceProvider) {
|
74 |
let message_payload = serde_json::to_vec(&self.message).unwrap();
|
75 |
|
76 |
self.source
|
77 |
.push(authentication.authenticate(&message_payload));
|
78 |
}
|
79 |
}
|
80 |
|
81 |
// NOTE(review): empty placeholder module; nothing in the visible source uses it.
mod verified {
}
|
82 |
|
83 |
#[derive(Clone, Debug)]
|
84 |
pub struct UserAuthenticator {
|
85 |
pub user: User,
|
86 |
pub token: UserAuthenticationToken,
|
87 |
}
|
88 |
|
89 |
impl AuthenticationSourceProvider for UserAuthenticator {
|
90 |
fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
|
91 |
AuthenticationSource::User {
|
92 |
user: self.user,
|
93 |
token: self.token,
|
94 |
}
|
95 |
}
|
96 |
}
|
97 |
|
98 |
#[derive(Clone)]
|
99 |
pub struct InstanceAuthenticator<'a> {
|
100 |
pub instance: Instance,
|
101 |
pub private_key: &'a str,
|
102 |
}
|
103 |
|
104 |
impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
|
105 |
fn authenticate(self, _payload: &Vec<u8>) -> AuthenticationSource {
|
106 |
AuthenticationSource::Instance {
|
107 |
instance: self.instance,
|
108 |
|
109 |
signature: InstanceSignature(self.private_key.as_bytes().to_vec()),
|
110 |
}
|
111 |
}
|
112 |
}
|
113 |
|
114 |
#[repr(transparent)]
|
115 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
116 |
pub struct UserAuthenticationToken(String);
|
117 |
|
118 |
impl From<String> for UserAuthenticationToken {
|
119 |
fn from(value: String) -> Self {
|
120 |
Self(value)
|
121 |
}
|
122 |
}
|
123 |
|
124 |
impl ToString for UserAuthenticationToken {
|
125 |
fn to_string(&self) -> String {
|
126 |
self.0.clone()
|
127 |
}
|
128 |
}
|
129 |
|
130 |
impl AsRef<str> for UserAuthenticationToken {
|
131 |
fn as_ref(&self) -> &str {
|
132 |
&self.0
|
133 |
}
|
134 |
}
|
135 |
|
136 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
137 |
pub struct InstanceSignature(Vec<u8>);
|
138 |
|
139 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
140 |
pub enum AuthenticationSource {
|
141 |
User {
|
142 |
user: User,
|
143 |
token: UserAuthenticationToken,
|
144 |
},
|
145 |
Instance {
|
146 |
instance: Instance,
|
147 |
signature: InstanceSignature,
|
148 |
},
|
149 |
}
|
150 |
|
151 |
/// Raw bytes of a message as received from a connection.
pub struct NetworkMessage(pub Vec<u8>);

/// Lets a `NetworkMessage` be used wherever a byte slice is expected
/// (e.g. passed straight to `serde_json::from_slice`).
impl Deref for NetworkMessage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
|
160 |
|
161 |
/// A user produced by the `FromMessage<ConnectionState>` extractor after the user's token
/// has been checked.
///
/// NOTE(review): the field is public, so the type on its own does not guarantee that
/// verification happened.
pub struct AuthenticatedUser(pub User);
|
162 |
|
163 |
#[derive(Debug, thiserror::Error)]
|
164 |
pub enum UserAuthenticationError {
|
165 |
#[error("user authentication missing")]
|
166 |
Missing,
|
167 |
|
168 |
|
169 |
#[error("user token was invalid")]
|
170 |
InvalidToken,
|
171 |
#[error("an error has occured")]
|
172 |
Other(#[from] Error),
|
173 |
}
|
174 |
|
175 |
pub struct AuthenticatedInstance(Instance);
|
176 |
|
177 |
impl AuthenticatedInstance {
|
178 |
pub fn inner(&self) -> &Instance {
|
179 |
&self.0
|
180 |
}
|
181 |
}
|
182 |
|
183 |
#[async_trait::async_trait]
|
184 |
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
|
185 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
|
186 |
}
|
187 |
|
188 |
#[async_trait::async_trait]
|
189 |
impl FromMessage<ConnectionState> for AuthenticatedUser {
|
190 |
async fn from_message(
|
191 |
network_message: &NetworkMessage,
|
192 |
state: &ConnectionState,
|
193 |
) -> Result<Self, Error> {
|
194 |
let message: Authenticated<HashMap<String, Value>> =
|
195 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
196 |
|
197 |
let (auth_user, auth_token) = message
|
198 |
.source
|
199 |
.iter()
|
200 |
.filter_map(|auth| {
|
201 |
if let AuthenticationSource::User { user, token } = auth {
|
202 |
Some((user, token))
|
203 |
} else {
|
204 |
None
|
205 |
}
|
206 |
})
|
207 |
.next()
|
208 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
209 |
|
210 |
let authenticated_instance =
|
211 |
AuthenticatedInstance::from_message(network_message, state).await?;
|
212 |
|
213 |
let public_key_raw = public_key(&auth_user.instance).await?;
|
214 |
let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
|
215 |
|
216 |
let data: TokenData<UserTokenMetadata> = decode(
|
217 |
auth_token.as_ref(),
|
218 |
&verification_key,
|
219 |
&Validation::new(Algorithm::RS256),
|
220 |
)
|
221 |
.unwrap();
|
222 |
|
223 |
if data.claims.user != *auth_user
|
224 |
|| data.claims.generated_for != *authenticated_instance.inner()
|
225 |
{
|
226 |
Err(Error::from(UserAuthenticationError::InvalidToken))
|
227 |
} else {
|
228 |
Ok(AuthenticatedUser(data.claims.user))
|
229 |
}
|
230 |
}
|
231 |
}
|
232 |
|
233 |
#[async_trait::async_trait]
|
234 |
impl FromMessage<ConnectionState> for AuthenticatedInstance {
|
235 |
async fn from_message(
|
236 |
network_message: &NetworkMessage,
|
237 |
_state: &ConnectionState,
|
238 |
) -> Result<Self, Error> {
|
239 |
let message: Authenticated<HashMap<String, Value>> =
|
240 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
241 |
|
242 |
let (instance, signature) = message
|
243 |
.source
|
244 |
.iter()
|
245 |
.filter_map(|auth| {
|
246 |
if let AuthenticationSource::Instance {
|
247 |
instance,
|
248 |
signature,
|
249 |
} = auth
|
250 |
{
|
251 |
Some((instance, signature))
|
252 |
} else {
|
253 |
None
|
254 |
}
|
255 |
})
|
256 |
.next()
|
257 |
|
258 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
259 |
|
260 |
|
261 |
|
262 |
Ok(Self(instance.clone()))
|
263 |
}
|
264 |
}
|
265 |
|
266 |
#[async_trait::async_trait]
|
267 |
impl<S, T> FromMessage<S> for Option<T>
|
268 |
where
|
269 |
T: FromMessage<S>,
|
270 |
S: Send + Sync + 'static,
|
271 |
{
|
272 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
|
273 |
Ok(T::from_message(message, state).await.ok())
|
274 |
}
|
275 |
}
|
276 |
|
277 |
#[async_trait::async_trait]
|
278 |
pub trait MessageHandler<T, S, R> {
|
279 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
|
280 |
}
|
281 |
#[async_trait::async_trait]
|
282 |
impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
|
283 |
where
|
284 |
T: FnOnce(T1) -> F + Clone + Send + 'static,
|
285 |
F: Future<Output = Result<R, E>> + Send,
|
286 |
T1: FromMessage<S> + Send,
|
287 |
S: Send + Sync,
|
288 |
E: std::error::Error + Send + Sync + 'static,
|
289 |
{
|
290 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
291 |
let value = T1::from_message(message, state).await?;
|
292 |
self(value).await.map_err(|e| Error::from(e))
|
293 |
}
|
294 |
}
|
295 |
|
296 |
#[async_trait::async_trait]
|
297 |
impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
|
298 |
where
|
299 |
T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
|
300 |
F: Future<Output = Result<R, E>> + Send,
|
301 |
T1: FromMessage<S> + Send,
|
302 |
T2: FromMessage<S> + Send,
|
303 |
S: Send + Sync,
|
304 |
E: std::error::Error + Send + Sync + 'static,
|
305 |
{
|
306 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
307 |
let value = T1::from_message(message, state).await?;
|
308 |
let value_2 = T2::from_message(message, state).await?;
|
309 |
self(value, value_2).await.map_err(|e| Error::from(e))
|
310 |
}
|
311 |
}
|
312 |
|
313 |
#[async_trait::async_trait]
|
314 |
impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
|
315 |
where
|
316 |
T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
|
317 |
F: Future<Output = Result<R, E>> + Send,
|
318 |
T1: FromMessage<S> + Send,
|
319 |
T2: FromMessage<S> + Send,
|
320 |
T3: FromMessage<S> + Send,
|
321 |
S: Send + Sync,
|
322 |
E: std::error::Error + Send + Sync + 'static,
|
323 |
{
|
324 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
325 |
let value = T1::from_message(message, state).await?;
|
326 |
let value_2 = T2::from_message(message, state).await?;
|
327 |
let value_3 = T3::from_message(message, state).await?;
|
328 |
|
329 |
self(value, value_2, value_3)
|
330 |
.await
|
331 |
.map_err(|e| Error::from(e))
|
332 |
}
|
333 |
}
|
334 |
|
335 |
pub struct State<T>(pub T);
|
336 |
|
337 |
#[async_trait::async_trait]
|
338 |
impl<T> FromMessage<T> for State<T>
|
339 |
where
|
340 |
T: Clone + Send + Sync,
|
341 |
{
|
342 |
async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
|
343 |
Ok(Self(state.clone()))
|
344 |
}
|
345 |
}
|
346 |
|
347 |
|
348 |
#[async_trait::async_trait]
|
349 |
impl<T, S> FromMessage<S> for Message<T>
|
350 |
where
|
351 |
T: DeserializeOwned + Send + Sync + Serialize,
|
352 |
S: Clone + Send + Sync,
|
353 |
{
|
354 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
355 |
Ok(Message(serde_json::from_slice(&message)?))
|
356 |
}
|
357 |
}
|
358 |
|
359 |
/// Extractor carrying the message payload deserialized as `T`.
///
/// The serde bounds are required only by its `FromMessage` impl, not by the struct, so
/// `Message<T>` can be named and constructed for any `T`.
pub struct Message<T>(pub T);
|
360 |
|
361 |
async fn public_key(instance: &Instance) -> Result<String, Error> {
|
362 |
let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
|
363 |
.await?
|
364 |
.text()
|
365 |
.await?;
|
366 |
|
367 |
Ok(key)
|
368 |
}
|
369 |
|