refactor: remove code duplications for main and testing

This commit is contained in:
Sandro Eiler 2024-03-03 21:48:44 +01:00
parent da1a508616
commit ebd7755731
6 changed files with 143 additions and 114 deletions

23
Cargo.lock generated
View file

@ -119,7 +119,7 @@ dependencies = [
"http 1.0.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.1.0",
"hyper 1.2.0",
"hyper-util",
"itoa",
"matchit",
@ -736,9 +736,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.4.0"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d308f63daf4181410c242d34c11f928dcb3aa105852019e043c9d1f4e4368a"
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
dependencies = [
"bytes",
"fnv",
@ -926,20 +926,21 @@ dependencies = [
[[package]]
name = "hyper"
version = "1.1.0"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75"
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.0",
"h2 0.4.2",
"http 1.0.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
"want",
]
@ -969,7 +970,7 @@ dependencies = [
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"hyper 1.1.0",
"hyper 1.2.0",
"pin-project-lite",
"socket2 0.5.5",
"tokio",
@ -1110,7 +1111,7 @@ dependencies = [
"claims",
"config",
"fake",
"hyper 1.1.0",
"hyper 1.2.0",
"once_cell",
"quickcheck",
"quickcheck_macros",
@ -1981,9 +1982,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.10.0"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
[[package]]
name = "socket2"
@ -3077,7 +3078,7 @@ dependencies = [
"futures",
"http 1.0.0",
"http-body-util",
"hyper 1.1.0",
"hyper 1.2.0",
"hyper-util",
"log",
"once_cell",

View file

@ -14,7 +14,7 @@ name = "learn_axum"
[dependencies]
tokio = { version = "1.36.0", features = ["full"] }
hyper = { version = "1.1.0", features = ["full"] }
hyper = { version = "1.2.0", features = ["full"] }
# Serde / json
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"

View file

@ -5,8 +5,9 @@ use sqlx::postgres::PgSslMode;
use sqlx::ConnectOptions;
use crate::domain::SubscriberEmail;
use crate::email_client::EmailClient;
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize, Clone)]
/// The setting collection.
///
/// * `database`: database settings
@ -17,7 +18,7 @@ pub struct Settings {
pub email_client: EmailClientSettings,
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize, Clone)]
pub struct EmailClientSettings {
pub base_url: String,
pub sender_email: String,
@ -25,16 +26,7 @@ pub struct EmailClientSettings {
pub timeout_milliseconds: u64,
}
impl EmailClientSettings {
pub fn sender(&self) -> Result<SubscriberEmail, String> {
SubscriberEmail::parse(self.sender_email.clone())
}
pub fn timeout(&self) -> std::time::Duration {
std::time::Duration::from_millis(self.timeout_milliseconds)
}
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize, Clone)]
/// The application settings.
///
/// * `port`: The port to listen on
@ -45,7 +37,7 @@ pub struct ApplicationSettings {
pub host: String,
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize, Clone)]
/// The database settings.
///
/// * `username`: the DB username
@ -70,6 +62,26 @@ pub enum Environment {
Production,
}
impl EmailClientSettings {
pub fn client(self) -> EmailClient {
let sender_email = self.sender().expect("Invalid sender email address.");
let timeout = self.timeout();
EmailClient::new(
self.base_url,
sender_email,
self.authorization_token,
timeout,
)
}
pub fn sender(&self) -> Result<SubscriberEmail, String> {
SubscriberEmail::parse(self.sender_email.clone())
}
pub fn timeout(&self) -> std::time::Duration {
std::time::Duration::from_millis(self.timeout_milliseconds)
}
}
impl Environment {
pub fn as_str(&self) -> &'static str {
match self {

View file

@ -1,38 +1,14 @@
use learn_axum::configuration::get_configuration;
use learn_axum::email_client::EmailClient;
use learn_axum::startup;
use learn_axum::startup::Application;
use learn_axum::telemetry::{get_subscriber, init_subscriber};
use sqlx::postgres::PgPoolOptions;
use tokio::net::TcpListener;
#[tokio::main]
/// Entry point for the application.
/// Log level default can be overridden with the RUST_LOG environment variable.
async fn main() {
async fn main() -> std::io::Result<()> {
let subscriber = get_subscriber("learn_axum".into(), "info".into(), std::io::stdout);
init_subscriber(subscriber);
let configuration = get_configuration().expect("Failed to read configuration.");
let addr = format!(
"{}:{}",
configuration.application.host, configuration.application.port
);
let listener = TcpListener::bind(addr).await.unwrap(); //.expect("Unable to bind to port");
let connection_pool = PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(configuration.database.with_db());
let sender_email = configuration
.email_client
.sender()
.expect("Invalid sender email address.");
let timeout = configuration.email_client.timeout();
let email_client = EmailClient::new(
configuration.email_client.base_url,
sender_email,
configuration.email_client.authorization_token,
timeout,
);
startup::run(listener, connection_pool, email_client)
.await
.unwrap();
let application = Application::build(configuration).await?;
application.run().await.unwrap();
Ok(())
}

View file

@ -1,8 +1,10 @@
use crate::configuration::{DatabaseSettings, Settings};
use crate::email_client::EmailClient;
use axum::http::Request;
use axum::routing::IntoMakeService;
use axum::serve::Serve;
use axum::Router;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use tokio::net::TcpListener;
use tower::ServiceBuilder;
@ -14,6 +16,11 @@ use tower_http::{
use tracing::Level;
use uuid::Uuid;
pub struct Application {
app: Router,
listener: TcpListener,
}
#[derive(Clone)]
pub struct AppState {
pub db_pool: PgPool,
@ -32,43 +39,69 @@ impl MakeRequestId for MakeRequestUuid {
}
}
/// API routing
///
/// * `connection`: The postgres connection pool
pub fn app(db_connection: PgPool, email_client: EmailClient) -> Router {
let state = AppState {
db_pool: db_connection.clone(),
email_client: email_client.clone(),
};
Router::new()
.merge(crate::routes::routes_health_check())
.merge(crate::routes::routes_subscriptions(state))
.layer(
// from https://docs.rs/tower-http/0.2.5/tower_http/request_id/index.html#using-trace
ServiceBuilder::new()
.set_x_request_id(MakeRequestUuid)
.layer(
TraceLayer::new_for_http()
.make_span_with(
DefaultMakeSpan::new()
.include_headers(true)
.level(Level::INFO),
)
.on_response(DefaultOnResponse::new().include_headers(true)),
)
.propagate_x_request_id(),
)
impl Application {
pub async fn build(configuration: Settings) -> Result<Self, std::io::Error> {
let connection_pool = get_connection_pool(&configuration.database);
let sender_email = configuration
.email_client
.sender()
.expect("Invalid sender email address.");
let timeout = configuration.email_client.timeout();
let email_client = EmailClient::new(
configuration.email_client.base_url,
sender_email,
configuration.email_client.authorization_token,
timeout,
);
let address = format!(
"{}:{}",
configuration.application.host, configuration.application.port
);
let listener = TcpListener::bind(&address).await?;
let state = AppState {
db_pool: connection_pool.clone(),
email_client: email_client.clone(),
};
let app = Router::new()
.merge(crate::routes::routes_health_check())
.merge(crate::routes::routes_subscriptions(state))
.layer(
// from https://docs.rs/tower-http/0.2.5/tower_http/request_id/index.html#using-trace
ServiceBuilder::new()
.set_x_request_id(MakeRequestUuid)
.layer(
TraceLayer::new_for_http()
.make_span_with(
DefaultMakeSpan::new()
.include_headers(true)
.level(Level::INFO),
)
.on_response(DefaultOnResponse::new().include_headers(true)),
)
.propagate_x_request_id(),
);
Ok(Self { app, listener })
}
/// Start the server
pub fn run(self) -> Serve<IntoMakeService<Router>, Router> {
axum::serve(self.listener, self.app.into_make_service())
}
/// Get the address of the server
pub fn address(&self) -> String {
format!("{}", self.listener.local_addr().unwrap())
}
/// Get the port of the server
pub fn port(&self) -> u16 {
self.listener.local_addr().unwrap().port()
}
}
/// Start the server
///
/// * `listener`: The TCP listener
/// * `connection`: The postgres connection pool
/// * `email_client`: The email client
pub fn run(
listener: TcpListener,
connection: PgPool,
email_client: EmailClient,
) -> Serve<IntoMakeService<Router>, Router> {
axum::serve(listener, app(connection, email_client).into_make_service())
pub fn get_connection_pool(configuration: &DatabaseSettings) -> PgPool {
PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(configuration.with_db())
}

View file

@ -1,9 +1,8 @@
use learn_axum::configuration::{get_configuration, DatabaseSettings};
use learn_axum::email_client::EmailClient;
use learn_axum::startup::{get_connection_pool, Application};
use learn_axum::telemetry::{get_subscriber, init_subscriber};
use once_cell::sync::Lazy;
use sqlx::{Connection, Executor, PgConnection, PgPool};
use tokio::net::TcpListener;
use uuid::Uuid;
/// Ensure that the `tracing` stack is only initialised once using `once_cell`
@ -33,33 +32,41 @@ pub async fn spawn_app() -> TestApp {
// All other invocations will instead skip execution.
Lazy::force(&TRACING);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let address = format!("http://{}", listener.local_addr().unwrap());
// TODO:
// // Launch a mock server to stand in for Postmark's API
// let email_server = MockServer::start().await;
let mut configuration = get_configuration().expect("Failed to read configuration.");
configuration.database.name = Uuid::new_v4().to_string();
let connection_pool = configure_database(&configuration.database).await;
// Randomise configuration to ensure test isolation
let configuration = {
let mut c = get_configuration().expect("Failed to read configuration.");
// Use a different database for each test case
c.database.name = Uuid::new_v4().to_string();
// Use a random OS port
c.application.port = 0;
c
};
// TODO: remove code duplication
let sender_email = configuration
.email_client
.sender()
.expect("Invalid sender email address.");
let timeout = configuration.email_client.timeout();
let email_client = EmailClient::new(
configuration.email_client.base_url,
sender_email,
configuration.email_client.authorization_token,
timeout,
);
// Create and migrate the database
configure_database(&configuration.database).await;
let connection_pool = get_connection_pool(&configuration.database);
let application = Application::build(configuration.clone())
.await
.expect("Failed to build application.");
// Get the port before spawning the application
let address = format!("http://127.0.0.1:{}", application.port());
// Launch the application as a background task
tokio::spawn(async move { application.run().await.expect("Failed to run the server") });
let service = learn_axum::startup::app(connection_pool.clone(), email_client);
tokio::spawn(async move {
axum::serve(listener, service).await.unwrap();
});
TestApp {
address,
// address: format!("http://localhost:{}", application_port),
// port: application_port,
db_pool: connection_pool,
// email_server,
// test_user: TestUser::generate(),
// api_client: client,
// email_client: configuration.email_client.client(),
}
}