feat: add input validation

This commit is contained in:
Sandro Eiler 2024-02-12 10:55:23 +01:00
parent d7d37341ba
commit 419be581b3
10 changed files with 271 additions and 122 deletions

7
src/domain/mod.rs Normal file
View file

@ -0,0 +1,7 @@
mod new_subscriber;
mod subscriber_email;
mod subscriber_name;
pub use new_subscriber::NewSubscriber;
pub use subscriber_email::SubscriberEmail;
pub use subscriber_name::SubscriberName;

View file

@ -0,0 +1,7 @@
use crate::domain::subscriber_email::SubscriberEmail;
use crate::domain::subscriber_name::SubscriberName;
pub struct NewSubscriber {
pub email: SubscriberEmail,
pub name: SubscriberName,
}

View file

@ -0,0 +1,76 @@
use validator::validate_email;
#[derive(Debug)]
pub struct SubscriberEmail(String);
impl SubscriberEmail {
pub fn parse(s: String) -> Result<SubscriberEmail, String> {
if validate_email(&s) {
Ok(Self(s))
} else {
Err(format!("{} is not a valid subscriber email.", s))
}
}
}
impl AsRef<str> for SubscriberEmail {
fn as_ref(&self) -> &str {
&self.0
}
}
#[cfg(test)]
mod tests {
use super::SubscriberEmail;
use claims::assert_err;
use fake::faker::internet::en::SafeEmail;
use fake::Fake;
use rand::{rngs::StdRng, SeedableRng};
#[derive(Debug, Clone)]
struct ValidEmailFixture(pub String);
impl quickcheck::Arbitrary for ValidEmailFixture {
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
let mut rng = StdRng::seed_from_u64(u64::arbitrary(g));
let email = SafeEmail().fake_with_rng(&mut rng);
Self(email)
}
}
#[quickcheck_macros::quickcheck]
fn valid_emails_are_parsed_successfully(valid_email: ValidEmailFixture) -> bool {
dbg!(&valid_email.0);
SubscriberEmail::parse(valid_email.0).is_ok()
}
#[test]
fn empty_string_is_rejected() {
let email = "".to_string();
assert_err!(SubscriberEmail::parse(email));
}
#[test]
fn email_missing_at_symbol_is_rejected() {
let email = "ursuladomain.com".to_string();
assert_err!(SubscriberEmail::parse(email));
}
#[test]
fn email_missing_subject_is_rejected() {
let email = "@domain.com".to_string();
assert_err!(SubscriberEmail::parse(email));
}
#[test]
fn email_sharp_s_prefix_is_rejected() {
let email = "joe.ßchmidt@example.com".to_string();
assert_err!(SubscriberEmail::parse(email));
}
#[test]
fn email_with_invalid_characters_is_rejected() {
let email = "joe.smith@ex ample.com".to_string();
assert_err!(SubscriberEmail::parse(email));
}
}

View file

@ -0,0 +1,70 @@
use unicode_segmentation::UnicodeSegmentation;
#[derive(Debug)]
pub struct SubscriberName(String);
impl SubscriberName {
/// Returns an instance of `SubscriberName` if the input satisfies all
/// our validation constraints on subscriber names.
pub fn parse(s: String) -> Result<SubscriberName, String> {
let is_empty_or_whitespace = s.trim().is_empty();
let is_too_long = s.graphemes(true).count() > 256;
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
if is_empty_or_whitespace || is_too_long || contains_forbidden_characters {
Err(format!("{} is not a valid subscriber name.", s))
} else {
Ok(Self(s))
}
}
}
impl AsRef<str> for SubscriberName {
fn as_ref(&self) -> &str {
&self.0
}
}
#[cfg(test)]
mod tests {
use crate::domain::SubscriberName;
use claims::{assert_err, assert_ok};
#[test]
fn a_256_grapheme_long_name_is_valid() {
let name = "ё".repeat(256);
assert_ok!(SubscriberName::parse(name));
}
#[test]
fn a_name_longer_than_256_graphemes_is_rejected() {
let name = "a".repeat(257);
assert_err!(SubscriberName::parse(name));
}
#[test]
fn whitespace_only_names_are_rejected() {
let name = " ".to_string();
assert_err!(SubscriberName::parse(name));
}
#[test]
fn empty_string_is_rejected() {
let name = "".to_string();
assert_err!(SubscriberName::parse(name));
}
#[test]
fn names_containing_an_invalid_character_are_rejected() {
for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] {
let name = name.to_string();
assert_err!(SubscriberName::parse(name));
}
}
#[test]
fn a_valid_name_is_parsed_successfully() {
let name = "Ursula Le Guin".to_string();
assert_ok!(SubscriberName::parse(name));
}
}

View file

@ -1,78 +0,0 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
pub type Result<T> = core::result::Result<T, Error>;
#[derive(Clone, Debug, Serialize, strum_macros::AsRefStr)]
#[serde(tag = "type", content = "data")]
pub enum Error {
LoginFail,
// -- Auth errors.
AuthFailNoAuthTokenCookie,
AuthFailTokenWrongFormat,
AuthFailCtxNotInRequestExt,
// -- Model errors.
// FIXME: Delete this:
PropertyDeleteFailIdNotFound { id: u64 },
// -- Service errors.
SubscriptionFailInvalidName,
}
impl Error {
pub fn client_status_and_error(&self) -> (StatusCode, ClientError) {
match self {
// -- Login.
Self::LoginFail => (StatusCode::UNAUTHORIZED, ClientError::LOGIN_FAIL),
// -- Auth.
Self::AuthFailNoAuthTokenCookie
| Self::AuthFailTokenWrongFormat
| Self::AuthFailCtxNotInRequestExt => (StatusCode::FORBIDDEN, ClientError::NO_AUTH),
// -- Model.
// FIXME: Delete this:
Self::PropertyDeleteFailIdNotFound { .. } => {
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
}
// -- Subscription.
Self::SubscriptionFailInvalidName => {
(StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
}
// -- Fallback.
#[allow(unreachable_patterns)]
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
ClientError::SERVICE_ERROR,
),
}
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
// TODO: trace something here maybe.
// Create a placeholder Axum response.
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
// Insert the Error into the response.
response.extensions_mut().insert(self);
response
}
}
#[derive(Debug, strum_macros::AsRefStr)]
#[allow(non_camel_case_types)]
pub enum ClientError {
LOGIN_FAIL,
NO_AUTH,
INVALID_PARAMS,
SERVICE_ERROR,
}

View file

@ -2,10 +2,8 @@
//!
//! This is an example documentation string for the root of the crate.
pub use self::error::{Error, Result};
pub mod configuration;
pub mod error;
pub mod domain;
pub mod routes;
pub mod startup;
pub mod telemetry;

View file

@ -1,4 +1,5 @@
use crate::Result;
use crate::domain::SubscriberEmail;
use crate::domain::{NewSubscriber, SubscriberName};
use axum::extract::State;
use axum::routing::post;
use axum::Form;
@ -10,7 +11,6 @@ use axum::{
use chrono::Utc;
use serde::Deserialize;
use sqlx::PgPool;
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
@ -19,32 +19,6 @@ struct FormData {
name: String,
}
/// Returns `true` if the input satisfies all our validation constraints
/// on subscriber names, `false` otherwise.
pub fn is_valid_name(s: &str) -> bool {
// `.trim()` returns a view over the input `s` without trailing
// whitespace-like characters.
// `.is_empty` checks if the view contains any character.
let is_empty_or_whitespace = s.trim().is_empty();
// A grapheme is defined by the Unicode standard as a "user-perceived"
// character: `å` is a single grapheme, but it is composed of two characters
// (`a` and `̊`).
//
// `graphemes` returns an iterator over the graphemes in the input `s`.
// `true` specifies that we want to use the extended grapheme definition set,
// the recommended one.
let is_too_long = s.graphemes(true).count() > 256;
// Iterate over all characters in the input `s` to check if any of them matches
// one of the characters in the forbidden array.
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
// Return `false` if any of our conditions have been violated
!(is_empty_or_whitespace || is_too_long || contains_forbidden_characters)
}
#[tracing::instrument(
name = "Adding a new subscriber",
skip(form, pool),
@ -55,17 +29,24 @@ pub fn is_valid_name(s: &str) -> bool {
)
)]
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -> Response {
if !is_valid_name(&form.name) {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
match insert_subscriber(&pool, &form).await {
let name = match SubscriberName::parse(form.name) {
Ok(name) => name,
Err(_) => {
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
};
let email = match SubscriberEmail::parse(form.email) {
Ok(email) => email,
Err(_) => {
return (StatusCode::BAD_REQUEST, "Invalid email address").into_response();
}
};
let new_subscriber = NewSubscriber { email, name };
match insert_subscriber(&pool, &new_subscriber).await {
Ok(_) => {
tracing::info!("Subscriber added to the database");
return (StatusCode::OK,).into_response();
}
Err(_) => {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response();
}
}
@ -73,17 +54,20 @@ pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -
#[tracing::instrument(
name = "Saving new subscriber details in the database",
skip(form, pool)
skip(new_subscriber, pool)
)]
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<()> {
pub async fn insert_subscriber(
pool: &PgPool,
new_subscriber: &NewSubscriber,
) -> Result<(), sqlx::Error> {
let _ = sqlx::query!(
r#"
INSERT INTO subscriptions (id, email, name, subscribed_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::new_v4(),
form.email,
form.name,
new_subscriber.email.as_ref(),
new_subscriber.name.as_ref(),
Utc::now()
)
// We use `get_ref` to get an immutable reference to the `PgConnection`