diff --git a/Cargo.toml b/Cargo.toml index 3e0be60..080cb26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,6 @@ rand = "*" jsonwebtoken = { version = "*", features = ["use_pem"]} chrono = { version = "0.4", features = [ "serde", "std" ] } reqwest = { version = "0.11" } -anyhow = "*" \ No newline at end of file +anyhow = "*" +deadpool = "*" +async-trait = "*" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 1b52437..ed73847 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,15 @@ use std::convert::Infallible; +use std::net::SocketAddr; use std::str::FromStr; -use std::{error::Error, net::SocketAddr}; +use std::sync::Arc; +use anyhow::Error; +use deadpool::managed::{BuildError, Manager, Pool, RecycleResult}; use futures_util::{SinkExt, StreamExt}; use giterated_daemon::messages::authentication::{RegisterAccountRequest, RegisterAccountResponse}; use giterated_daemon::messages::UnvalidatedUserAuthenticated; use giterated_daemon::model::repository::RepositoryVisibility; use giterated_daemon::model::user::User; -use giterated_daemon::{version, validate_version}; use giterated_daemon::{ handshake::{HandshakeFinalize, HandshakeMessage, InitiateHandshake}, messages::{ @@ -26,8 +28,10 @@ use giterated_daemon::{ repository::{Repository, RepositoryView}, }, }; +use giterated_daemon::{validate_version, version}; use serde::Serialize; use tokio::net::TcpStream; +use tokio::sync::broadcast::{Receiver, Sender}; use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream}; type Socket = WebSocketStream>; @@ -43,7 +47,7 @@ pub struct GiteratedApiBuilder { } pub trait AsInstance { - type Error: Error + Send + Sync + 'static; + type Error: std::error::Error + Send + Sync + 'static; fn into_instance(self) -> Result; } @@ -99,74 +103,78 @@ impl GiteratedApiBuilder { } pub async fn build(&mut self) -> Result { - Ok(GiteratedApi::new( - self.our_instance.clone(), - self.our_private_key.clone().unwrap(), - self.our_public_key.clone().unwrap(), - self.target_instance.clone(), - ) - .await?) + Ok(GiteratedApi { + configuration: Arc::new(GiteratedApiConfiguration { + our_private_key: self.our_private_key.take().unwrap(), + our_public_key: self.our_public_key.take().unwrap(), + target_instance: self.target_instance.take(), + // todo + target_public_key: None, + }), + }) } } -pub struct GiteratedApi { - pub connection: Socket, - pub our_instance: Instance, +struct GiteratedConnectionPool { + target_instance: Instance, +} + +#[async_trait::async_trait] +impl Manager for GiteratedConnectionPool { + type Type = Socket; + type Error = anyhow::Error; + + async fn create(&self) -> Result { + info!("Creating new Daemon connection"); + let mut connection = GiteratedApi::connect_to(self.target_instance.clone()).await?; + + // Handshake first! + GiteratedApi::handle_handshake(&mut connection, &self.target_instance).await?; + + Ok(connection) + } + + async fn recycle(&self, _: &mut Socket) -> RecycleResult { + Ok(()) + } +} + +pub struct GiteratedApiConfiguration { pub our_private_key: String, pub our_public_key: String, pub target_instance: Option, pub target_public_key: Option, } -impl GiteratedApi { - pub async fn new( - local_instance: Instance, - private_key: String, - public_key: String, - target_instance: Option, - ) -> Result { - let connection = Self::connect_to( - target_instance - .clone() - .unwrap_or_else(|| local_instance.clone()), - ) - .await?; - - let mut api = GiteratedApi { - connection, - our_instance: local_instance, - our_private_key: private_key, - our_public_key: public_key, - target_instance, - target_public_key: None, - }; - - // Handle handshake - api.handle_handshake().await?; - - Ok(api) +#[derive(Clone)] +pub struct DaemonConnectionPool(Pool); + +impl DaemonConnectionPool { + pub fn connect( + instance: impl ToOwned, + ) -> Result> { + Ok(Self( + Pool::builder(GiteratedConnectionPool { + target_instance: instance.to_owned(), + }) + .build()?, + )) } +} +#[derive(Clone)] +pub struct GiteratedApi { + configuration: Arc, +} + +impl GiteratedApi { pub async fn public_key(&mut self) -> String { - if let Some(public_key) = &self.target_public_key { + if let Some(public_key) = &self.configuration.target_public_key { public_key.clone() } else { - let key = reqwest::get(format!( - "https://{}/.giterated/pubkey.pem", - self.target_instance - .as_ref() - .unwrap_or_else(|| &self.our_instance) - .url - )) - .await - .unwrap() - .text() - .await - .unwrap(); - - self.target_public_key = Some(key.clone()); - - key + assert!(self.configuration.target_instance.is_none()); + + self.configuration.our_public_key.clone() } } @@ -175,28 +183,34 @@ impl GiteratedApi { /// # Authorization /// - Must be made by the same instance its being sent to pub async fn register( - &mut self, + &self, username: String, email: Option, password: String, - ) -> Result> { + pool: &DaemonConnectionPool, + ) -> Result { + let mut connection = pool.0.get().await.unwrap(); + let message = InstanceAuthenticated::new( RegisterAccountRequest { username, email, password, }, - self.our_instance.clone(), - self.our_private_key.clone(), + pool.0.manager().target_instance.clone(), + self.configuration.our_private_key.clone(), ) .unwrap(); - self.send_message(&MessageKind::Authentication( - AuthenticationMessage::Request(AuthenticationRequest::RegisterAccount(message)), - )) + Self::send_message( + &MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::RegisterAccount(message), + )), + &mut connection, + ) .await?; - while let Ok(payload) = self.next_payload().await { + while let Ok(payload) = self.next_payload(&mut connection).await { if let Ok(MessageKind::Authentication(AuthenticationMessage::Response( AuthenticationResponse::RegisterAccount(response), ))) = serde_json::from_slice(&payload) @@ -210,21 +224,25 @@ impl GiteratedApi { /// Create repository on the target instance. pub async fn create_repository( - &mut self, + &self, user_token: String, name: String, description: Option, visibility: RepositoryVisibility, default_branch: String, owner: User, - ) -> Result> { + pool: &DaemonConnectionPool, + ) -> Result { + let mut connection = pool.0.get().await.unwrap(); + let target_respository = Repository { owner: owner.clone(), name: name.clone(), instance: self + .configuration .target_instance .as_ref() - .unwrap_or(&self.our_instance) + .unwrap_or(&pool.0.manager().target_instance) .clone(), }; @@ -236,17 +254,25 @@ impl GiteratedApi { owner, }; - let message = - UnvalidatedUserAuthenticated::new(request, user_token, self.our_private_key.clone()) - .unwrap(); + let message = UnvalidatedUserAuthenticated::new( + request, + user_token, + self.configuration.our_private_key.clone(), + ) + .unwrap(); - self.send_message(&MessageKind::Repository(RepositoryMessage { - target: target_respository, - command: RepositoryMessageKind::Request(RepositoryRequest::CreateRepository(message)), - })) + Self::send_message( + &MessageKind::Repository(RepositoryMessage { + target: target_respository, + command: RepositoryMessageKind::Request(RepositoryRequest::CreateRepository( + message, + )), + }), + &mut connection, + ) .await?; - while let Ok(payload) = self.next_payload().await { + while let Ok(payload) = self.next_payload(&mut connection).await { if let Ok(MessageKind::Repository(RepositoryMessage { command: RepositoryMessageKind::Response(RepositoryResponse::CreateRepository(_response)), @@ -264,7 +290,10 @@ impl GiteratedApi { &mut self, token: &str, repository: Repository, - ) -> Result> { + pool: &DaemonConnectionPool, + ) -> Result { + let mut connection = pool.0.get().await.unwrap(); + let message = UnvalidatedUserAuthenticated::new( RepositoryInfoRequest { repository: repository.clone(), @@ -273,19 +302,22 @@ impl GiteratedApi { path: None, }, token.to_string(), - self.our_private_key.clone(), + self.configuration.our_private_key.clone(), ) .unwrap(); - self.send_message(&MessageKind::Repository(RepositoryMessage { - target: repository.clone(), - command: RepositoryMessageKind::Request(RepositoryRequest::RepositoryInfo(message)), - })) + Self::send_message( + &MessageKind::Repository(RepositoryMessage { + target: repository.clone(), + command: RepositoryMessageKind::Request(RepositoryRequest::RepositoryInfo(message)), + }), + &mut connection, + ) .await?; loop { // while let Ok(payload) = Self::next_payload(&mut socket).await { - let payload = match self.next_payload().await { + let payload = match self.next_payload(&mut connection).await { Ok(payload) => payload, Err(err) => { error!("Error while fetching next payload: {:?}", err); @@ -314,26 +346,32 @@ impl GiteratedApi { secret_key: String, username: String, password: String, - ) -> Result> { + pool: &DaemonConnectionPool, + ) -> Result { + let mut connection = pool.0.get().await.unwrap(); + let request = InstanceAuthenticated::new( AuthenticationTokenRequest { secret_key, username, password, }, - self.our_instance.clone(), + pool.0.manager().target_instance.clone(), include_str!("example_keys/giterated.key").to_string(), ) .unwrap(); - self.send_message(&MessageKind::Authentication( - AuthenticationMessage::Request(AuthenticationRequest::AuthenticationToken(request)), - )) + Self::send_message( + &MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::AuthenticationToken(request), + )), + &mut connection, + ) .await?; loop { // while let Ok(payload) = Self::next_payload(&mut socket).await { - let payload = match self.next_payload().await { + let payload = match self.next_payload(&mut connection).await { Ok(payload) => payload, Err(err) => { error!("Error while fetching next payload: {:?}", err); @@ -359,20 +397,26 @@ impl GiteratedApi { &mut self, secret_key: String, token: String, - ) -> Result, Box> { + pool: &DaemonConnectionPool, + ) -> Result, Error> { + let mut connection = pool.0.get().await.unwrap(); + let request = InstanceAuthenticated::new( TokenExtensionRequest { secret_key, token }, - self.our_instance.clone(), - self.our_private_key.clone(), + pool.0.manager().target_instance.clone(), + self.configuration.our_private_key.clone(), ) .unwrap(); - self.send_message(&MessageKind::Authentication( - AuthenticationMessage::Request(AuthenticationRequest::TokenExtension(request)), - )) + Self::send_message( + &MessageKind::Authentication(AuthenticationMessage::Request( + AuthenticationRequest::TokenExtension(request), + )), + &mut connection, + ) .await?; - while let Ok(payload) = self.next_payload().await { + while let Ok(payload) = self.next_payload(&mut connection).await { if let Ok(MessageKind::Authentication(AuthenticationMessage::Response( AuthenticationResponse::TokenExtension(response), ))) = serde_json::from_slice(&payload) @@ -397,18 +441,22 @@ impl GiteratedApi { Ok(websocket) } - async fn handle_handshake(&mut self) -> Result<(), anyhow::Error> { + async fn handle_handshake( + socket: &mut Socket, + instance: &Instance, + ) -> Result<(), anyhow::Error> { // Send handshake initiation - self.send_message(&MessageKind::Handshake(HandshakeMessage::Initiate( - InitiateHandshake { - identity: self.our_instance.clone(), + Self::send_message( + &MessageKind::Handshake(HandshakeMessage::Initiate(InitiateHandshake { + identity: instance.clone(), version: version(), - }, - ))) + })), + socket, + ) .await?; - while let Some(message) = self.connection.next().await { + while let Some(message) = socket.next().await { let message = match message { Ok(message) => message, Err(err) => { @@ -442,7 +490,6 @@ impl GiteratedApi { match handshake { HandshakeMessage::Initiate(_) => unimplemented!(), HandshakeMessage::Response(response) => { - let message = if !validate_version(&response.version) { error!( "Version compatibility failure! Our Version: {}, Their Version: {}", @@ -457,9 +504,10 @@ impl GiteratedApi { HandshakeFinalize { success: true } }; // Send HandshakeMessage::Finalize - self.send_message(&MessageKind::Handshake(HandshakeMessage::Finalize( - message - ))) + Self::send_message( + &MessageKind::Handshake(HandshakeMessage::Finalize(message)), + socket, + ) .await?; } HandshakeMessage::Finalize(finalize) => { @@ -476,15 +524,19 @@ impl GiteratedApi { Ok(()) } - async fn send_message(&mut self, message: &T) -> Result<(), anyhow::Error> { - self.connection + async fn send_message( + message: &T, + socket: &mut Socket, + ) -> Result<(), anyhow::Error> { + socket .send(Message::Binary(serde_json::to_vec(&message).unwrap())) .await?; + Ok(()) } - async fn next_payload(&mut self) -> Result, Box> { - while let Some(message) = self.connection.next().await { + async fn next_payload(&self, socket: &mut Socket) -> Result, Error> { + while let Some(message) = socket.next().await { let message = message?; match message { diff --git a/src/main.rs b/src/main.rs index a1565b3..ece87dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,12 @@ -use std::{collections::BTreeMap, str::FromStr, time::SystemTime}; - -use chrono::{DateTime, NaiveDateTime, Utc}; -use giterated_api::{GiteratedApi, GiteratedApiBuilder}; -use giterated_daemon::{ - messages::repository::CreateRepositoryRequest, - model::{ - instance::Instance, - repository::{Repository, RepositoryVisibility}, - user::User, - }, +use std::str::FromStr; + +use giterated_api::{DaemonConnectionPool, GiteratedApiBuilder}; +use giterated_daemon::model::{ + instance::Instance, + repository::{Repository, RepositoryVisibility}, + user::User, }; -use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, TokenData, Validation}; +use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; use serde::{Deserialize, Serialize}; // use jwt::SignWithKey; @@ -18,46 +14,15 @@ use serde::{Deserialize, Serialize}; extern crate tracing; #[tokio::main] -async fn main() { +async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt::init(); - // info!( - // "Response from Daemon: {:?}", - // GiteratedApi::repository_info(Repository { - // name: String::from("foo"), - // instance: Instance { - // url: String::from("127.0.0.1:8080") - // } - // }) - // .await - // ); - - // let encoding_key = - // EncodingKey::from_rsa_pem(include_bytes!("example_keys/giterated.key")).unwrap(); - - // let claims = UserTokenMetadata { - // user: User { - // username: String::from("ambee"), - // instance: Instance { - // url: String::from("giterated.dev"), - // }, - // }, - // generated_for: Instance { - // url: String::from("giterated.dev"), - // }, - // exp: SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs(), - // }; - - // let token = encode( - // &jsonwebtoken::Header::new(Algorithm::RS256), - // &claims, - // &encoding_key, - // ) - // .unwrap(); + + let pool = DaemonConnectionPool::connect(Instance::from_str("giterated.dev")?).unwrap(); let mut api = GiteratedApiBuilder::from_local("giterated.dev") .unwrap() .private_key(include_str!("example_keys/giterated.key")) - .public_key(include_str!("example_keys/giterated.key")) + .public_key(include_str!("example_keys/giterated.key.pub")) .build() .await .unwrap(); @@ -69,6 +34,7 @@ async fn main() { String::from("ambee"), None, String::from("lolthisisinthecommithistory"), + &pool, ) .await; @@ -79,6 +45,7 @@ async fn main() { String::from("foobar"), String::from("ambee"), String::from("password"), + &pool, ) .await .unwrap(); @@ -101,7 +68,7 @@ async fn main() { info!("Lets extend that token!"); let new_token = api - .extend_token(String::from("foobar"), token.clone()) + .extend_token(String::from("foobar"), token.clone(), &pool) .await .unwrap(); info!("New Token Returned:\n{:?}", new_token); @@ -116,6 +83,7 @@ async fn main() { RepositoryVisibility::Public, String::from("master"), User::from_str("ambee:giterated.dev").unwrap(), + &pool, ) .await .unwrap(); @@ -127,12 +95,15 @@ async fn main() { let view = api .repository_info( &token, - Repository::from_str("ambee:giterated.dev/hello@giterated.dev").unwrap(), + Repository::from_str("ambee:giterated.dev/super-repository@giterated.dev").unwrap(), + &pool, ) .await .unwrap(); info!("Repository Info:\n{:#?}", view); + + Ok(()) } #[derive(Debug, Serialize, Deserialize)]