use std::{str::FromStr, sync::atomic::Ordering}; use anyhow::Error; use giterated_models::messages::handshake::{ HandshakeFinalize, HandshakeResponse, InitiateHandshake, }; use semver::Version; use crate::{ connection::ConnectionError, message::{Message, MessageHandler, NetworkMessage, State}, validate_version, version, }; use super::{wrapper::ConnectionState, HandlerUnhandled}; pub async fn handshake_handle( message: &NetworkMessage, state: &ConnectionState, ) -> Result<(), Error> { if initiate_handshake .handle_message(&message, state) .await .is_ok() { Ok(()) } else if handshake_response .handle_message(&message, state) .await .is_ok() { Ok(()) } else if handshake_finalize .handle_message(&message, state) .await .is_ok() { Ok(()) } else { Err(Error::from(HandlerUnhandled)) } } async fn initiate_handshake( Message(initiation): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { if !validate_version(&initiation.version) { error!( "Version compatibility failure! Our Version: {}, Their Version: {}", Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), initiation.version ); connection_state .send(HandshakeFinalize { success: false }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) } else { connection_state .send(HandshakeResponse { identity: connection_state.instance.clone(), version: version(), }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) } } async fn handshake_response( Message(response): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { if !validate_version(&response.version) { error!( "Version compatibility failure! Our Version: {}, Their Version: {}", Version::from_str(&std::env::var("CARGO_PKG_VERSION").unwrap()).unwrap(), response.version ); connection_state .send(HandshakeFinalize { success: false }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) } else { connection_state .send(HandshakeFinalize { success: true }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) } } async fn handshake_finalize( Message(finalize): Message, State(connection_state): State, ) -> Result<(), HandshakeError> { if !finalize.success { error!("Error during handshake, aborting connection"); return Err(Error::from(ConnectionError::Shutdown).into()); } else { connection_state.handshaked.store(true, Ordering::SeqCst); connection_state .send(HandshakeFinalize { success: true }) .await .map_err(|e| HandshakeError::SendError(e))?; Ok(()) } } #[derive(Debug, thiserror::Error)] pub enum HandshakeError { #[error("version mismatch during handshake, ours: {0}, theirs: {1}")] VersionMismatch(Version, Version), #[error("while sending message: {0}")] SendError(Error), #[error("{0}")] Other(#[from] Error), }