feat: add sanitization check, refactor error returns

This commit is contained in:
Sandro Eiler 2024-02-10 21:33:01 +01:00
parent 13162f6470
commit d7d37341ba
6 changed files with 106 additions and 7 deletions

15
Cargo.lock generated
View file

@ -1049,6 +1049,7 @@ dependencies = [
"serde-aux",
"serde_json",
"sqlx",
"strum_macros",
"tokio",
"tower",
"tower-http",
@ -1056,6 +1057,7 @@ dependencies = [
"tracing-bunyan-formatter",
"tracing-log 0.2.0",
"tracing-subscriber",
"unicode-segmentation",
"uuid",
]
@ -2253,6 +2255,19 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strum_macros"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.32",
]
[[package]]
name = "subtle"
version = "2.5.0"

View file

@ -34,7 +34,8 @@ tracing-subscriber = { version = "0.3", features = ["registry", "env-filter"] }
tracing-bunyan-formatter = "0.3"
tracing-log = "0.2"
secrecy = { version = "0.8", features = ["serde"] }
# lazy-regex = "3"
unicode-segmentation = "1"
strum_macros = "0.26"
# async-trait = "0.1"
# strum_macros = "0.25"

View file

@ -15,7 +15,11 @@ pub enum Error {
AuthFailCtxNotInRequestExt,
// -- Model errors.
// FIXME: Delete this:
PropertyDeleteFailIdNotFound { id: u64 },
// -- Service errors.
SubscriptionFailInvalidName,
}
impl Error {
@ -30,11 +34,18 @@ impl Error {
| Self::AuthFailCtxNotInRequestExt => (StatusCode::FORBIDDEN, ClientError::NO_AUTH),
// -- Model.
// FIXME: Delete this:
Self::PropertyDeleteFailIdNotFound { .. } => {
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
}
// -- Subscription.
Self::SubscriptionFailInvalidName => {
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
}
// -- Fallback.
#[allow(unreachable_patterns)]
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
ClientError::SERVICE_ERROR,
@ -45,7 +56,7 @@ impl Error {
impl IntoResponse for Error {
fn into_response(self) -> Response {
println!("->> {:<12} - {self:?}", "INTO_RESPONSE");
// TODO: trace something here maybe.
// Create a placeholder Axum response.
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();

View file

@ -2,7 +2,10 @@
//!
//! This is an example documentation string for the root of the crate.
pub use self::error::{Error, Result};
pub mod configuration;
pub mod error;
pub mod routes;
pub mod startup;
pub mod telemetry;

View file

@ -1,10 +1,16 @@
use crate::Result;
use axum::extract::State;
use axum::routing::post;
use axum::Form;
use axum::Router;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use chrono::Utc;
use serde::Deserialize;
use sqlx::PgPool;
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
@ -13,6 +19,32 @@ struct FormData {
name: String,
}
/// Returns `true` if the input satisfies all our validation constraints
/// on subscriber names, `false` otherwise.
pub fn is_valid_name(s: &str) -> bool {
// `.trim()` returns a view over the input `s` without trailing
// whitespace-like characters.
// `.is_empty` checks if the view contains any character.
let is_empty_or_whitespace = s.trim().is_empty();
// A grapheme is defined by the Unicode standard as a "user-perceived"
// character: `å` is a single grapheme, but it is composed of two characters
// (`a` and `̊`).
//
// `graphemes` returns an iterator over the graphemes in the input `s`.
// `true` specifies that we want to use the extended grapheme definition set,
// the recommended one.
let is_too_long = s.graphemes(true).count() > 256;
// Iterate over all characters in the input `s` to check if any of them matches
// one of the characters in the forbidden array.
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
// Return `false` if any of our conditions have been violated
!(is_empty_or_whitespace || is_too_long || contains_forbidden_characters)
}
#[tracing::instrument(
name = "Adding a new subscriber",
skip(form, pool),
@ -22,13 +54,19 @@ struct FormData {
subscriber_name = %form.name
)
)]
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -> Response {
if !is_valid_name(&form.name) {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
match insert_subscriber(&pool, &form).await {
Ok(_) => {
tracing::info!("Subscriber added to the database");
return (StatusCode::OK,).into_response();
}
Err(_) => {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response();
}
}
}
@ -37,8 +75,8 @@ pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
name = "Saving new subscriber details in the database",
skip(form, pool)
)]
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sqlx::Error> {
sqlx::query!(
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<()> {
let _ = sqlx::query!(
r#"
INSERT INTO subscriptions (id, email, name, subscribed_at)
VALUES ($1, $2, $3, $4)
@ -55,7 +93,7 @@ pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sql
.map_err(|e| {
tracing::error!("Failed to execute query: {:?}", e);
e
})?;
});
Ok(())
}

View file

@ -70,7 +70,7 @@ async fn subscribe_returns_a_422_when_data_is_missing() {
assert_eq!(
422,
response.status().as_u16(),
"The API did not fail with 400 Bad Request when the payload was {}.",
"The API did not fail with 422 when the payload was {}.",
error_message
);
}
@ -104,6 +104,37 @@ async fn subscribe_returns_a_200_for_valid_form_data() {
assert_eq!(saved.name, "le guin");
}
#[tokio::test]
async fn subscribe_returns_a_200_when_fields_are_present_but_empty() {
// Arrange
let app = spawn_app().await;
let client = reqwest::Client::new();
let test_cases = vec![
("name=&email=ursula_le_guin%40gmail.com", "empty name"),
("name=Ursula&email=", "empty email"),
("name=Ursula&email=definitely-not-an-email", "invalid email"),
];
for (body, description) in test_cases {
// Act
let response = client
.post(&format!("{}/subscriptions", &app.address))
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body)
.send()
.await
.expect("Failed to execute request.");
// Assert
assert_eq!(
200,
response.status().as_u16(),
"The API did not return a 200 OK when the payload was {}.",
description
);
}
}
async fn spawn_app() -> TestApp {
// The first time `initialize` is invoked the code in `TRACING` is executed.
// All other invocations will instead skip execution.