feat: add sanitization check, refactor error returns
This commit is contained in:
parent
13162f6470
commit
d7d37341ba
6 changed files with 106 additions and 7 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
|
@ -1049,6 +1049,7 @@ dependencies = [
|
||||||
"serde-aux",
|
"serde-aux",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
"strum_macros",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
|
@ -1056,6 +1057,7 @@ dependencies = [
|
||||||
"tracing-bunyan-formatter",
|
"tracing-bunyan-formatter",
|
||||||
"tracing-log 0.2.0",
|
"tracing-log 0.2.0",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
"unicode-segmentation",
|
||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -2253,6 +2255,19 @@ dependencies = [
|
||||||
"unicode-normalization",
|
"unicode-normalization",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strum_macros"
|
||||||
|
version = "0.26.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"rustversion",
|
||||||
|
"syn 2.0.32",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,8 @@ tracing-subscriber = { version = "0.3", features = ["registry", "env-filter"] }
|
||||||
tracing-bunyan-formatter = "0.3"
|
tracing-bunyan-formatter = "0.3"
|
||||||
tracing-log = "0.2"
|
tracing-log = "0.2"
|
||||||
secrecy = { version = "0.8", features = ["serde"] }
|
secrecy = { version = "0.8", features = ["serde"] }
|
||||||
# lazy-regex = "3"
|
unicode-segmentation = "1"
|
||||||
|
strum_macros = "0.26"
|
||||||
# async-trait = "0.1"
|
# async-trait = "0.1"
|
||||||
# strum_macros = "0.25"
|
# strum_macros = "0.25"
|
||||||
|
|
||||||
|
|
|
||||||
13
src/error.rs
13
src/error.rs
|
|
@ -15,7 +15,11 @@ pub enum Error {
|
||||||
AuthFailCtxNotInRequestExt,
|
AuthFailCtxNotInRequestExt,
|
||||||
|
|
||||||
// -- Model errors.
|
// -- Model errors.
|
||||||
|
// FIXME: Delete this:
|
||||||
PropertyDeleteFailIdNotFound { id: u64 },
|
PropertyDeleteFailIdNotFound { id: u64 },
|
||||||
|
|
||||||
|
// -- Service errors.
|
||||||
|
SubscriptionFailInvalidName,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
impl Error {
|
||||||
|
|
@ -30,11 +34,18 @@ impl Error {
|
||||||
| Self::AuthFailCtxNotInRequestExt => (StatusCode::FORBIDDEN, ClientError::NO_AUTH),
|
| Self::AuthFailCtxNotInRequestExt => (StatusCode::FORBIDDEN, ClientError::NO_AUTH),
|
||||||
|
|
||||||
// -- Model.
|
// -- Model.
|
||||||
|
// FIXME: Delete this:
|
||||||
Self::PropertyDeleteFailIdNotFound { .. } => {
|
Self::PropertyDeleteFailIdNotFound { .. } => {
|
||||||
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
|
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -- Subscription.
|
||||||
|
Self::SubscriptionFailInvalidName => {
|
||||||
|
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
|
||||||
|
}
|
||||||
|
|
||||||
// -- Fallback.
|
// -- Fallback.
|
||||||
|
#[allow(unreachable_patterns)]
|
||||||
_ => (
|
_ => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
ClientError::SERVICE_ERROR,
|
ClientError::SERVICE_ERROR,
|
||||||
|
|
@ -45,7 +56,7 @@ impl Error {
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
println!("->> {:<12} - {self:?}", "INTO_RESPONSE");
|
// TODO: trace something here maybe.
|
||||||
|
|
||||||
// Create a placeholder Axum response.
|
// Create a placeholder Axum response.
|
||||||
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,10 @@
|
||||||
//!
|
//!
|
||||||
//! This is an example documentation string for the root of the crate.
|
//! This is an example documentation string for the root of the crate.
|
||||||
|
|
||||||
|
pub use self::error::{Error, Result};
|
||||||
|
|
||||||
pub mod configuration;
|
pub mod configuration;
|
||||||
|
pub mod error;
|
||||||
pub mod routes;
|
pub mod routes;
|
||||||
pub mod startup;
|
pub mod startup;
|
||||||
pub mod telemetry;
|
pub mod telemetry;
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,16 @@
|
||||||
|
use crate::Result;
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::routing::post;
|
use axum::routing::post;
|
||||||
use axum::Form;
|
use axum::Form;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
use unicode_segmentation::UnicodeSegmentation;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -13,6 +19,32 @@ struct FormData {
|
||||||
name: String,
|
name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the input satisfies all our validation constraints
|
||||||
|
/// on subscriber names, `false` otherwise.
|
||||||
|
pub fn is_valid_name(s: &str) -> bool {
|
||||||
|
// `.trim()` returns a view over the input `s` without trailing
|
||||||
|
// whitespace-like characters.
|
||||||
|
// `.is_empty` checks if the view contains any character.
|
||||||
|
let is_empty_or_whitespace = s.trim().is_empty();
|
||||||
|
|
||||||
|
// A grapheme is defined by the Unicode standard as a "user-perceived"
|
||||||
|
// character: `å` is a single grapheme, but it is composed of two characters
|
||||||
|
// (`a` and `̊`).
|
||||||
|
//
|
||||||
|
// `graphemes` returns an iterator over the graphemes in the input `s`.
|
||||||
|
// `true` specifies that we want to use the extended grapheme definition set,
|
||||||
|
// the recommended one.
|
||||||
|
let is_too_long = s.graphemes(true).count() > 256;
|
||||||
|
|
||||||
|
// Iterate over all characters in the input `s` to check if any of them matches
|
||||||
|
// one of the characters in the forbidden array.
|
||||||
|
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
|
||||||
|
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
|
||||||
|
|
||||||
|
// Return `false` if any of our conditions have been violated
|
||||||
|
!(is_empty_or_whitespace || is_too_long || contains_forbidden_characters)
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "Adding a new subscriber",
|
name = "Adding a new subscriber",
|
||||||
skip(form, pool),
|
skip(form, pool),
|
||||||
|
|
@ -22,13 +54,19 @@ struct FormData {
|
||||||
subscriber_name = %form.name
|
subscriber_name = %form.name
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
|
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -> Response {
|
||||||
|
if !is_valid_name(&form.name) {
|
||||||
|
tracing::error!("Failed to add subscriber to the database");
|
||||||
|
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
|
||||||
|
}
|
||||||
match insert_subscriber(&pool, &form).await {
|
match insert_subscriber(&pool, &form).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
tracing::info!("Subscriber added to the database");
|
tracing::info!("Subscriber added to the database");
|
||||||
|
return (StatusCode::OK,).into_response();
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
tracing::error!("Failed to add subscriber to the database");
|
tracing::error!("Failed to add subscriber to the database");
|
||||||
|
return (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -37,8 +75,8 @@ pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
|
||||||
name = "Saving new subscriber details in the database",
|
name = "Saving new subscriber details in the database",
|
||||||
skip(form, pool)
|
skip(form, pool)
|
||||||
)]
|
)]
|
||||||
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sqlx::Error> {
|
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<()> {
|
||||||
sqlx::query!(
|
let _ = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO subscriptions (id, email, name, subscribed_at)
|
INSERT INTO subscriptions (id, email, name, subscribed_at)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
|
|
@ -55,7 +93,7 @@ pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sql
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
tracing::error!("Failed to execute query: {:?}", e);
|
tracing::error!("Failed to execute query: {:?}", e);
|
||||||
e
|
e
|
||||||
})?;
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ async fn subscribe_returns_a_422_when_data_is_missing() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
422,
|
422,
|
||||||
response.status().as_u16(),
|
response.status().as_u16(),
|
||||||
"The API did not fail with 400 Bad Request when the payload was {}.",
|
"The API did not fail with 422 when the payload was {}.",
|
||||||
error_message
|
error_message
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
@ -104,6 +104,37 @@ async fn subscribe_returns_a_200_for_valid_form_data() {
|
||||||
assert_eq!(saved.name, "le guin");
|
assert_eq!(saved.name, "le guin");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn subscribe_returns_a_200_when_fields_are_present_but_empty() {
|
||||||
|
// Arrange
|
||||||
|
let app = spawn_app().await;
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let test_cases = vec![
|
||||||
|
("name=&email=ursula_le_guin%40gmail.com", "empty name"),
|
||||||
|
("name=Ursula&email=", "empty email"),
|
||||||
|
("name=Ursula&email=definitely-not-an-email", "invalid email"),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (body, description) in test_cases {
|
||||||
|
// Act
|
||||||
|
let response = client
|
||||||
|
.post(&format!("{}/subscriptions", &app.address))
|
||||||
|
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Failed to execute request.");
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert_eq!(
|
||||||
|
200,
|
||||||
|
response.status().as_u16(),
|
||||||
|
"The API did not return a 200 OK when the payload was {}.",
|
||||||
|
description
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn spawn_app() -> TestApp {
|
async fn spawn_app() -> TestApp {
|
||||||
// The first time `initialize` is invoked the code in `TRACING` is executed.
|
// The first time `initialize` is invoked the code in `TRACING` is executed.
|
||||||
// All other invocations will instead skip execution.
|
// All other invocations will instead skip execution.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue