feat: add sanitization check, refactor error returns

This commit is contained in:
Sandro Eiler 2024-02-10 21:33:01 +01:00
parent 13162f6470
commit d7d37341ba
6 changed files with 106 additions and 7 deletions

View file

@ -1,10 +1,16 @@
use crate::Result;
use axum::extract::State;
use axum::routing::post;
use axum::Form;
use axum::Router;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use chrono::Utc;
use serde::Deserialize;
use sqlx::PgPool;
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
@ -13,6 +19,32 @@ struct FormData {
name: String,
}
/// Returns `true` if the input satisfies all our validation constraints
/// on subscriber names, `false` otherwise.
pub fn is_valid_name(s: &str) -> bool {
// `.trim()` returns a view over the input `s` without trailing
// whitespace-like characters.
// `.is_empty` checks if the view contains any character.
let is_empty_or_whitespace = s.trim().is_empty();
// A grapheme is defined by the Unicode standard as a "user-perceived"
// character: `å` is a single grapheme, but it is composed of two characters
// (`a` and `̊`).
//
// `graphemes` returns an iterator over the graphemes in the input `s`.
// `true` specifies that we want to use the extended grapheme definition set,
// the recommended one.
let is_too_long = s.graphemes(true).count() > 256;
// Iterate over all characters in the input `s` to check if any of them matches
// one of the characters in the forbidden array.
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
// Return `false` if any of our conditions have been violated
!(is_empty_or_whitespace || is_too_long || contains_forbidden_characters)
}
#[tracing::instrument(
name = "Adding a new subscriber",
skip(form, pool),
@ -22,13 +54,19 @@ struct FormData {
subscriber_name = %form.name
)
)]
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -> Response {
if !is_valid_name(&form.name) {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
match insert_subscriber(&pool, &form).await {
Ok(_) => {
tracing::info!("Subscriber added to the database");
return (StatusCode::OK,).into_response();
}
Err(_) => {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response();
}
}
}
@ -37,8 +75,8 @@ pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) {
name = "Saving new subscriber details in the database",
skip(form, pool)
)]
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sqlx::Error> {
sqlx::query!(
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<()> {
let _ = sqlx::query!(
r#"
INSERT INTO subscriptions (id, email, name, subscribed_at)
VALUES ($1, $2, $3, $4)
@ -55,7 +93,7 @@ pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sql
.map_err(|e| {
tracing::error!("Failed to execute query: {:?}", e);
e
})?;
});
Ok(())
}