1 |
use std::{any::type_name, ops::Deref, pin::Pin, str::FromStr};
|
2 |
|
3 |
use anyhow::Error;
|
4 |
use futures_util::{future::BoxFuture, Future, FutureExt};
|
5 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
6 |
use rsa::{pkcs1::DecodeRsaPublicKey, RsaPublicKey};
|
7 |
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
8 |
|
9 |
use crate::{
|
10 |
authentication::UserTokenMetadata, connection::wrapper::ConnectionState, messages::MessageKind,
|
11 |
};
|
12 |
|
13 |
use super::{instance::Instance, user::User};
|
14 |
|
15 |
#[derive(Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
16 |
pub struct Authenticated<T: Serialize> {
|
17 |
#[serde(flatten)]
|
18 |
source: Vec<AuthenticationSource>,
|
19 |
message_type: String,
|
20 |
#[serde(flatten)]
|
21 |
message: T,
|
22 |
}
|
23 |
|
24 |
pub trait AuthenticationSourceProvider: Sized {
|
25 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource;
|
26 |
}
|
27 |
|
28 |
pub trait AuthenticationSourceProviders: Sized {
|
29 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource>;
|
30 |
}
|
31 |
|
32 |
impl<A> AuthenticationSourceProviders for A
|
33 |
where
|
34 |
A: AuthenticationSourceProvider,
|
35 |
{
|
36 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
37 |
vec![self.authenticate(payload)]
|
38 |
}
|
39 |
}
|
40 |
|
41 |
impl<A, B> AuthenticationSourceProviders for (A, B)
|
42 |
where
|
43 |
A: AuthenticationSourceProvider,
|
44 |
B: AuthenticationSourceProvider,
|
45 |
{
|
46 |
fn authenticate_all(self, payload: &Vec<u8>) -> Vec<AuthenticationSource> {
|
47 |
let (first, second) = self;
|
48 |
|
49 |
vec![first.authenticate(payload), second.authenticate(payload)]
|
50 |
}
|
51 |
}
|
52 |
|
53 |
impl<T: Serialize> Authenticated<T> {
|
54 |
pub fn new(message: T, auth_sources: impl AuthenticationSourceProvider) -> Self {
|
55 |
let message_payload = serde_json::to_vec(&message).unwrap();
|
56 |
|
57 |
let authentication = auth_sources.authenticate_all(&message_payload);
|
58 |
|
59 |
Self {
|
60 |
source: authentication,
|
61 |
message_type: type_name::<T>().to_string(),
|
62 |
message,
|
63 |
}
|
64 |
}
|
65 |
}
|
66 |
|
67 |
// NOTE(review): empty module; nothing is declared in it yet. Presumably a placeholder
// for verified-message types — TODO confirm or remove.
mod verified {}
|
68 |
|
69 |
pub struct UserAuthenticator {
|
70 |
pub user: User,
|
71 |
pub token: UserAuthenticationToken,
|
72 |
}
|
73 |
|
74 |
impl AuthenticationSourceProvider for UserAuthenticator {
|
75 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
|
76 |
AuthenticationSource::User {
|
77 |
user: self.user,
|
78 |
token: self.token,
|
79 |
}
|
80 |
}
|
81 |
}
|
82 |
|
83 |
pub struct InstanceAuthenticator<'a> {
|
84 |
pub instance: Instance,
|
85 |
pub private_key: &'a str,
|
86 |
}
|
87 |
|
88 |
impl AuthenticationSourceProvider for InstanceAuthenticator<'_> {
|
89 |
fn authenticate(self, payload: &Vec<u8>) -> AuthenticationSource {
|
90 |
todo!()
|
91 |
}
|
92 |
}
|
93 |
|
94 |
#[repr(transparent)]
|
95 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
96 |
pub struct UserAuthenticationToken(String);
|
97 |
|
98 |
impl From<String> for UserAuthenticationToken {
|
99 |
fn from(value: String) -> Self {
|
100 |
Self(value)
|
101 |
}
|
102 |
}
|
103 |
|
104 |
impl ToString for UserAuthenticationToken {
|
105 |
fn to_string(&self) -> String {
|
106 |
self.0.clone()
|
107 |
}
|
108 |
}
|
109 |
|
110 |
impl AsRef<str> for UserAuthenticationToken {
|
111 |
fn as_ref(&self) -> &str {
|
112 |
&self.0
|
113 |
}
|
114 |
}
|
115 |
|
116 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
117 |
pub struct InstanceSignature(Vec<u8>);
|
118 |
|
119 |
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
120 |
pub enum AuthenticationSource {
|
121 |
User {
|
122 |
user: User,
|
123 |
token: UserAuthenticationToken,
|
124 |
},
|
125 |
Instance {
|
126 |
instance: Instance,
|
127 |
signature: InstanceSignature,
|
128 |
},
|
129 |
}
|
130 |
|
131 |
/// Raw bytes of an incoming message; the extractors in this file parse them as JSON.
pub struct NetworkMessage(pub Vec<u8>);

impl Deref for NetworkMessage {
    type Target = [u8];

    /// Exposes the underlying bytes as a slice.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
|
140 |
|
141 |
/// Extractor yielding the [`User`] whose token on the message was verified.
pub struct AuthenticatedUser(pub User);
|
142 |
|
143 |
#[derive(Debug, thiserror::Error)]
|
144 |
pub enum UserAuthenticationError {
|
145 |
#[error("user authentication missing")]
|
146 |
Missing,
|
147 |
|
148 |
|
149 |
#[error("user token was invalid")]
|
150 |
InvalidToken,
|
151 |
#[error("an error has occured")]
|
152 |
Other(#[from] Error),
|
153 |
}
|
154 |
|
155 |
/// Extractor wrapping the [`Instance`] a message is authenticated as; read it with
/// `inner()`.
pub struct AuthenticatedInstance(Instance);
|
156 |
|
157 |
impl AuthenticatedInstance {
|
158 |
pub fn inner(&self) -> &Instance {
|
159 |
&self.0
|
160 |
}
|
161 |
}
|
162 |
|
163 |
#[async_trait::async_trait]
|
164 |
pub trait FromMessage<S: Send + Sync>: Sized + Send + Sync {
|
165 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error>;
|
166 |
}
|
167 |
|
168 |
#[async_trait::async_trait]
|
169 |
impl FromMessage<ConnectionState> for AuthenticatedUser {
|
170 |
async fn from_message(
|
171 |
network_message: &NetworkMessage,
|
172 |
state: &ConnectionState,
|
173 |
) -> Result<Self, Error> {
|
174 |
let message: Authenticated<MessageKind> =
|
175 |
serde_json::from_slice(&network_message).map_err(|e| Error::from(e))?;
|
176 |
|
177 |
let (auth_user, auth_token) = message
|
178 |
.source
|
179 |
.iter()
|
180 |
.filter_map(|auth| {
|
181 |
if let AuthenticationSource::User { user, token } = auth {
|
182 |
Some((user, token))
|
183 |
} else {
|
184 |
None
|
185 |
}
|
186 |
})
|
187 |
.next()
|
188 |
.ok_or_else(|| UserAuthenticationError::Missing)?;
|
189 |
|
190 |
let authenticated_instance =
|
191 |
AuthenticatedInstance::from_message(network_message, state).await?;
|
192 |
|
193 |
let public_key_raw = public_key(&auth_user.instance).await?;
|
194 |
let verification_key = DecodingKey::from_rsa_pem(public_key_raw.as_bytes()).unwrap();
|
195 |
|
196 |
let data: TokenData<UserTokenMetadata> = decode(
|
197 |
auth_token.as_ref(),
|
198 |
&verification_key,
|
199 |
&Validation::new(Algorithm::RS256),
|
200 |
)
|
201 |
.unwrap();
|
202 |
|
203 |
if data.claims.user != *auth_user
|
204 |
|| data.claims.generated_for != *authenticated_instance.inner()
|
205 |
{
|
206 |
Err(Error::from(UserAuthenticationError::InvalidToken))
|
207 |
} else {
|
208 |
Ok(AuthenticatedUser(data.claims.user))
|
209 |
}
|
210 |
}
|
211 |
}
|
212 |
|
213 |
#[async_trait::async_trait]
|
214 |
impl FromMessage<ConnectionState> for AuthenticatedInstance {
|
215 |
async fn from_message(
|
216 |
message: &NetworkMessage,
|
217 |
state: &ConnectionState,
|
218 |
) -> Result<Self, Error> {
|
219 |
todo!()
|
220 |
}
|
221 |
}
|
222 |
|
223 |
#[async_trait::async_trait]
|
224 |
impl FromMessage<ConnectionState> for MessageKind {
|
225 |
async fn from_message(
|
226 |
message: &NetworkMessage,
|
227 |
state: &ConnectionState,
|
228 |
) -> Result<Self, Error> {
|
229 |
todo!()
|
230 |
}
|
231 |
}
|
232 |
|
233 |
#[async_trait::async_trait]
|
234 |
impl<S, T> FromMessage<S> for Option<T>
|
235 |
where
|
236 |
T: FromMessage<S>,
|
237 |
S: Send + Sync + 'static,
|
238 |
{
|
239 |
async fn from_message(message: &NetworkMessage, state: &S) -> Result<Self, Error> {
|
240 |
Ok(T::from_message(message, state).await.ok())
|
241 |
}
|
242 |
}
|
243 |
|
244 |
#[async_trait::async_trait]
|
245 |
pub trait MessageHandler<T, S, R> {
|
246 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error>;
|
247 |
}
|
248 |
#[async_trait::async_trait]
|
249 |
impl<F, R, S, T1, T, E> MessageHandler<(T1,), S, R> for T
|
250 |
where
|
251 |
T: FnOnce(T1) -> F + Clone + Send + 'static,
|
252 |
F: Future<Output = Result<R, E>> + Send,
|
253 |
T1: FromMessage<S> + Send,
|
254 |
S: Send + Sync,
|
255 |
E: std::error::Error + Send + Sync + 'static,
|
256 |
{
|
257 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
258 |
let value = T1::from_message(message, state).await?;
|
259 |
self(value).await.map_err(|e| Error::from(e))
|
260 |
}
|
261 |
}
|
262 |
|
263 |
#[async_trait::async_trait]
|
264 |
impl<F, R, S, T1, T2, T, E> MessageHandler<(T1, T2), S, R> for T
|
265 |
where
|
266 |
T: FnOnce(T1, T2) -> F + Clone + Send + 'static,
|
267 |
F: Future<Output = Result<R, E>> + Send,
|
268 |
T1: FromMessage<S> + Send,
|
269 |
T2: FromMessage<S> + Send,
|
270 |
S: Send + Sync,
|
271 |
E: std::error::Error + Send + Sync + 'static,
|
272 |
{
|
273 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
274 |
let value = T1::from_message(message, state).await?;
|
275 |
let value_2 = T2::from_message(message, state).await?;
|
276 |
self(value, value_2).await.map_err(|e| Error::from(e))
|
277 |
}
|
278 |
}
|
279 |
|
280 |
#[async_trait::async_trait]
|
281 |
impl<F, R, S, T1, T2, T3, T, E> MessageHandler<(T1, T2, T3), S, R> for T
|
282 |
where
|
283 |
T: FnOnce(T1, T2, T3) -> F + Clone + Send + 'static,
|
284 |
F: Future<Output = Result<R, E>> + Send,
|
285 |
T1: FromMessage<S> + Send,
|
286 |
T2: FromMessage<S> + Send,
|
287 |
T3: FromMessage<S> + Send,
|
288 |
S: Send + Sync,
|
289 |
E: std::error::Error + Send + Sync + 'static,
|
290 |
{
|
291 |
async fn handle_message(self, message: &NetworkMessage, state: &S) -> Result<R, Error> {
|
292 |
let value = T1::from_message(message, state).await?;
|
293 |
let value_2 = T2::from_message(message, state).await?;
|
294 |
let value_3 = T3::from_message(message, state).await?;
|
295 |
|
296 |
self(value, value_2, value_3)
|
297 |
.await
|
298 |
.map_err(|e| Error::from(e))
|
299 |
}
|
300 |
}
|
301 |
|
302 |
/// Extractor that hands a handler its own clone of the shared handler state.
pub struct State<T>(pub T);
|
303 |
|
304 |
#[async_trait::async_trait]
|
305 |
impl<T> FromMessage<T> for State<T>
|
306 |
where
|
307 |
T: Clone + Send + Sync,
|
308 |
{
|
309 |
async fn from_message(_: &NetworkMessage, state: &T) -> Result<Self, Error> {
|
310 |
Ok(Self(state.clone()))
|
311 |
}
|
312 |
}
|
313 |
|
314 |
|
315 |
#[async_trait::async_trait]
|
316 |
impl<T, S> FromMessage<S> for Message<T>
|
317 |
where
|
318 |
T: DeserializeOwned + Send + Sync + Serialize,
|
319 |
S: Clone + Send + Sync,
|
320 |
{
|
321 |
async fn from_message(message: &NetworkMessage, _: &S) -> Result<Self, Error> {
|
322 |
Ok(Message(serde_json::from_slice(&message)?))
|
323 |
}
|
324 |
}
|
325 |
|
326 |
/// Extractor holding the message body deserialized as `T`.
pub struct Message<T>(pub T)
where
    T: Serialize + DeserializeOwned;
|
327 |
|
328 |
async fn public_key(instance: &Instance) -> Result<String, Error> {
|
329 |
let key = reqwest::get(format!("https://{}/.giterated/pubkey.pem", instance.url))
|
330 |
.await?
|
331 |
.text()
|
332 |
.await?;
|
333 |
|
334 |
Ok(key)
|
335 |
}
|
336 |
|