From ebd7755731ad91624503d737bbdee8bb23b106c4 Mon Sep 17 00:00:00 2001 From: Sandro Eiler Date: Sun, 3 Mar 2024 21:48:44 +0100 Subject: [PATCH] refactor: remove code duplications for main and testing --- Cargo.lock | 23 +++++----- Cargo.toml | 2 +- src/configuration.rs | 38 +++++++++------ src/main.rs | 34 ++------------ src/startup.rs | 107 ++++++++++++++++++++++++++++--------------- tests/api/helpers.rs | 53 +++++++++++---------- 6 files changed, 143 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b8afafd..14bb877 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,7 +119,7 @@ dependencies = [ "http 1.0.0", "http-body 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.2.0", "hyper-util", "itoa", "matchit", @@ -736,9 +736,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d308f63daf4181410c242d34c11f928dcb3aa105852019e043c9d1f4e4368a" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" dependencies = [ "bytes", "fnv", @@ -926,20 +926,21 @@ dependencies = [ [[package]] name = "hyper" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.0", + "h2 0.4.2", "http 1.0.0", "http-body 1.0.0", "httparse", "httpdate", "itoa", "pin-project-lite", + "smallvec", "tokio", "want", ] @@ -969,7 +970,7 @@ dependencies = [ "futures-util", "http 1.0.0", "http-body 1.0.0", - "hyper 1.1.0", + "hyper 1.2.0", "pin-project-lite", "socket2 0.5.5", "tokio", @@ -1110,7 +1111,7 @@ dependencies = [ "claims", "config", "fake", - "hyper 1.1.0", + "hyper 1.2.0", "once_cell", "quickcheck", "quickcheck_macros", @@ -1981,9 +1982,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.10.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" @@ -3077,7 +3078,7 @@ dependencies = [ "futures", "http 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.2.0", "hyper-util", "log", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 96143d8..5c621aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ name = "learn_axum" [dependencies] tokio = { version = "1.36.0", features = ["full"] } -hyper = { version = "1.1.0", features = ["full"] } +hyper = { version = "1.2.0", features = ["full"] } # Serde / json serde = { version = "1.0", features = ["derive"] } serde_json = "1" diff --git a/src/configuration.rs b/src/configuration.rs index bdfe58f..6c8fab1 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -5,8 +5,9 @@ use sqlx::postgres::PgSslMode; use sqlx::ConnectOptions; use crate::domain::SubscriberEmail; +use crate::email_client::EmailClient; -#[derive(serde::Deserialize)] +#[derive(serde::Deserialize, Clone)] /// The setting collection. /// /// * `database`: database settings @@ -17,7 +18,7 @@ pub struct Settings { pub email_client: EmailClientSettings, } -#[derive(serde::Deserialize)] +#[derive(serde::Deserialize, Clone)] pub struct EmailClientSettings { pub base_url: String, pub sender_email: String, @@ -25,16 +26,7 @@ pub struct EmailClientSettings { pub timeout_milliseconds: u64, } -impl EmailClientSettings { - pub fn sender(&self) -> Result { - SubscriberEmail::parse(self.sender_email.clone()) - } - pub fn timeout(&self) -> std::time::Duration { - std::time::Duration::from_millis(self.timeout_milliseconds) - } -} - -#[derive(serde::Deserialize)] +#[derive(serde::Deserialize, Clone)] /// The application settings. /// /// * `port`: The port to listen on @@ -45,7 +37,7 @@ pub struct ApplicationSettings { pub host: String, } -#[derive(serde::Deserialize)] +#[derive(serde::Deserialize, Clone)] /// The database settings. /// /// * `username`: the DB username @@ -70,6 +62,26 @@ pub enum Environment { Production, } +impl EmailClientSettings { + pub fn client(self) -> EmailClient { + let sender_email = self.sender().expect("Invalid sender email address."); + let timeout = self.timeout(); + EmailClient::new( + self.base_url, + sender_email, + self.authorization_token, + timeout, + ) + } + + pub fn sender(&self) -> Result { + SubscriberEmail::parse(self.sender_email.clone()) + } + pub fn timeout(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.timeout_milliseconds) + } +} + impl Environment { pub fn as_str(&self) -> &'static str { match self { diff --git a/src/main.rs b/src/main.rs index 388e19b..4e2fc64 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,38 +1,14 @@ use learn_axum::configuration::get_configuration; -use learn_axum::email_client::EmailClient; -use learn_axum::startup; +use learn_axum::startup::Application; use learn_axum::telemetry::{get_subscriber, init_subscriber}; -use sqlx::postgres::PgPoolOptions; -use tokio::net::TcpListener; #[tokio::main] -/// Entry point for the application. -/// Log level default can be overridden with the RUST_LOG environment variable. -async fn main() { +async fn main() -> std::io::Result<()> { let subscriber = get_subscriber("learn_axum".into(), "info".into(), std::io::stdout); init_subscriber(subscriber); let configuration = get_configuration().expect("Failed to read configuration."); - let addr = format!( - "{}:{}", - configuration.application.host, configuration.application.port - ); - let listener = TcpListener::bind(addr).await.unwrap(); //.expect("Unable to bind to port"); - let connection_pool = PgPoolOptions::new() - .acquire_timeout(std::time::Duration::from_secs(2)) - .connect_lazy_with(configuration.database.with_db()); - let sender_email = configuration - .email_client - .sender() - .expect("Invalid sender email address."); - let timeout = configuration.email_client.timeout(); - let email_client = EmailClient::new( - configuration.email_client.base_url, - sender_email, - configuration.email_client.authorization_token, - timeout, - ); - startup::run(listener, connection_pool, email_client) - .await - .unwrap(); + let application = Application::build(configuration).await?; + application.run().await.unwrap(); + Ok(()) } diff --git a/src/startup.rs b/src/startup.rs index 0b495e7..c491757 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -1,8 +1,10 @@ +use crate::configuration::{DatabaseSettings, Settings}; use crate::email_client::EmailClient; use axum::http::Request; use axum::routing::IntoMakeService; use axum::serve::Serve; use axum::Router; +use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -14,6 +16,11 @@ use tower_http::{ use tracing::Level; use uuid::Uuid; +pub struct Application { + app: Router, + listener: TcpListener, +} + #[derive(Clone)] pub struct AppState { pub db_pool: PgPool, @@ -32,43 +39,69 @@ impl MakeRequestId for MakeRequestUuid { } } -/// API routing -/// -/// * `connection`: The postgres connection pool -pub fn app(db_connection: PgPool, email_client: EmailClient) -> Router { - let state = AppState { - db_pool: db_connection.clone(), - email_client: email_client.clone(), - }; - Router::new() - .merge(crate::routes::routes_health_check()) - .merge(crate::routes::routes_subscriptions(state)) - .layer( - // from https://docs.rs/tower-http/0.2.5/tower_http/request_id/index.html#using-trace - ServiceBuilder::new() - .set_x_request_id(MakeRequestUuid) - .layer( - TraceLayer::new_for_http() - .make_span_with( - DefaultMakeSpan::new() - .include_headers(true) - .level(Level::INFO), - ) - .on_response(DefaultOnResponse::new().include_headers(true)), - ) - .propagate_x_request_id(), - ) +impl Application { + pub async fn build(configuration: Settings) -> Result { + let connection_pool = get_connection_pool(&configuration.database); + let sender_email = configuration + .email_client + .sender() + .expect("Invalid sender email address."); + let timeout = configuration.email_client.timeout(); + let email_client = EmailClient::new( + configuration.email_client.base_url, + sender_email, + configuration.email_client.authorization_token, + timeout, + ); + let address = format!( + "{}:{}", + configuration.application.host, configuration.application.port + ); + let listener = TcpListener::bind(&address).await?; + + let state = AppState { + db_pool: connection_pool.clone(), + email_client: email_client.clone(), + }; + let app = Router::new() + .merge(crate::routes::routes_health_check()) + .merge(crate::routes::routes_subscriptions(state)) + .layer( + // from https://docs.rs/tower-http/0.2.5/tower_http/request_id/index.html#using-trace + ServiceBuilder::new() + .set_x_request_id(MakeRequestUuid) + .layer( + TraceLayer::new_for_http() + .make_span_with( + DefaultMakeSpan::new() + .include_headers(true) + .level(Level::INFO), + ) + .on_response(DefaultOnResponse::new().include_headers(true)), + ) + .propagate_x_request_id(), + ); + Ok(Self { app, listener }) + } + + /// Start the server + pub fn run(self) -> Serve, Router> { + axum::serve(self.listener, self.app.into_make_service()) + } + + /// Get the address of the server + pub fn address(&self) -> String { + format!("{}", self.listener.local_addr().unwrap()) + } + + /// Get the port of the server + pub fn port(&self) -> u16 { + self.listener.local_addr().unwrap().port() + } } -/// Start the server -/// -/// * `listener`: The TCP listener -/// * `connection`: The postgres connection pool -/// * `email_client`: The email client -pub fn run( - listener: TcpListener, - connection: PgPool, - email_client: EmailClient, -) -> Serve, Router> { - axum::serve(listener, app(connection, email_client).into_make_service()) +pub fn get_connection_pool(configuration: &DatabaseSettings) -> PgPool { + PgPoolOptions::new() + .acquire_timeout(std::time::Duration::from_secs(2)) + .connect_lazy_with(configuration.with_db()) } diff --git a/tests/api/helpers.rs b/tests/api/helpers.rs index 5b6f493..c80a827 100644 --- a/tests/api/helpers.rs +++ b/tests/api/helpers.rs @@ -1,9 +1,8 @@ use learn_axum::configuration::{get_configuration, DatabaseSettings}; -use learn_axum::email_client::EmailClient; +use learn_axum::startup::{get_connection_pool, Application}; use learn_axum::telemetry::{get_subscriber, init_subscriber}; use once_cell::sync::Lazy; use sqlx::{Connection, Executor, PgConnection, PgPool}; -use tokio::net::TcpListener; use uuid::Uuid; /// Ensure that the `tracing` stack is only initialised once using `once_cell` @@ -33,33 +32,41 @@ pub async fn spawn_app() -> TestApp { // All other invocations will instead skip execution. Lazy::force(&TRACING); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let address = format!("http://{}", listener.local_addr().unwrap()); + // TODO: + // // Launch a mock server to stand in for Postmark's API + // let email_server = MockServer::start().await; - let mut configuration = get_configuration().expect("Failed to read configuration."); - configuration.database.name = Uuid::new_v4().to_string(); - let connection_pool = configure_database(&configuration.database).await; + // Randomise configuration to ensure test isolation + let configuration = { + let mut c = get_configuration().expect("Failed to read configuration."); + // Use a different database for each test case + c.database.name = Uuid::new_v4().to_string(); + // Use a random OS port + c.application.port = 0; + c + }; - // TODO: remove code duplication - let sender_email = configuration - .email_client - .sender() - .expect("Invalid sender email address."); - let timeout = configuration.email_client.timeout(); - let email_client = EmailClient::new( - configuration.email_client.base_url, - sender_email, - configuration.email_client.authorization_token, - timeout, - ); + // Create and migrate the database + configure_database(&configuration.database).await; + let connection_pool = get_connection_pool(&configuration.database); + let application = Application::build(configuration.clone()) + .await + .expect("Failed to build application."); + // Get the port before spawning the application + let address = format!("http://127.0.0.1:{}", application.port()); + + // Launch the application as a background task + tokio::spawn(async move { application.run().await.expect("Failed to run the server") }); - let service = learn_axum::startup::app(connection_pool.clone(), email_client); - tokio::spawn(async move { - axum::serve(listener, service).await.unwrap(); - }); TestApp { address, + // address: format!("http://localhost:{}", application_port), + // port: application_port, db_pool: connection_pool, + // email_server, + // test_user: TestUser::generate(), + // api_client: client, + // email_client: configuration.email_client.client(), } }