use std::{str::FromStr, sync::atomic::Ordering}; use anyhow::Error; use giterated_models::messages::handshake::{ HandshakeFinalize, HandshakeResponse, InitiateHandshake, }; use semver::Version; use crate::{ connection::ConnectionError, message::{Message, MessageHandler, NetworkMessage, State}, validate_version, version, }; use super::{wrapper::ConnectionState, HandlerUnhandled}; pub async fn handshake_handle( message: &NetworkMessage, state: &ConnectionState, ) -> Result<(), Error> { if initiate_handshake .handle_message(&message, state) .await .is_ok() { Ok(()) } else if handshake_response .handle_message(&message, state) .await .is_ok() { Ok(()) } else if handshake_finalize .handle_message(&message, state) .await .is_ok() { Ok(()) } else { Err(Error::from(HandlerUnhandled)) } } async fn initiate_handshake( Message(initiation): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { connection_state .send(HandshakeResponse { identity: connection_state.instance.clone(), version: version(), }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) // if !validate_version(&initiation.version) { // error!( // "Version compatibility failure! Our Version: {}, Their Version: {}", // Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), // initiation.version // ); // connection_state // .send(HandshakeFinalize { success: false }) // .await // .map_err(|e| HandshakeError::SendError(e))?; // Ok(()) // } else { // connection_state // .send(HandshakeResponse { // identity: connection_state.instance.clone(), // version: version(), // }) // .await // .map_err(|e| HandshakeError::SendError(e))?; // Ok(()) // } } async fn handshake_response( Message(response): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { connection_state .send(HandshakeFinalize { success: true }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) // if !validate_version(&response.version) { // error!( // "Version compatibility failure! Our Version: {}, Their Version: {}", // Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), // response.version // ); // connection_state // .send(HandshakeFinalize { success: false }) // .await // .map_err(|e| HandshakeError::SendError(e))?; // Ok(()) // } else { // connection_state // .send(HandshakeFinalize { success: true }) // .await // .map_err(|e| HandshakeError::SendError(e))?; // Ok(()) // } } async fn handshake_finalize( Message(finalize): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { connection_state.handshaked.store(true, Ordering::SeqCst); connection_state .send(HandshakeFinalize { success: true }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) // if !finalize.success { // error!("Error during handshake, aborting connection"); // return Err(Error::from(ConnectionError::Shutdown).into()); // } else { // connection_state.handshaked.store(true, Ordering::SeqCst); // connection_state // .send(HandshakeFinalize { success: true }) // .await // .map_err(|e| HandshakeError::SendError(e))?; // Ok(()) // } } #[derive(Debug, thiserror::Error)] pub enum HandshakeError { #[error("version mismatch during handshake, ours: {0}, theirs: {1}")] VersionMismatch(Version, Version), #[error("while sending message: {0}")] SendError(Error), #[error("{0}")] Other(#[from] Error), }