feat: add input validation

This commit is contained in:
Sandro Eiler 2024-02-12 10:55:23 +01:00
parent d7d37341ba
commit 419be581b3
10 changed files with 271 additions and 122 deletions

78
Cargo.lock generated
View file

@ -245,6 +245,15 @@ dependencies = [
"windows-targets 0.48.1",
]
[[package]]
name = "claims"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6995bbe186456c36307f8ea36be3eefe42f49d106896414e18efc4fb2f846b5"
dependencies = [
"autocfg",
]
[[package]]
name = "config"
version = "0.14.0"
@ -395,6 +404,12 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "deunicode"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae2a35373c5c74340b79ae6780b498b2b183915ec5dacf263aac5a099bf485a"
[[package]]
name = "digest"
version = "0.10.7"
@ -440,6 +455,16 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "env_logger"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
dependencies = [
"log",
"regex",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -484,6 +509,16 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
[[package]]
name = "fake"
version = "2.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c25829bde82205da46e1823b2259db6273379f626fc211f126f65654a2669be"
dependencies = [
"deunicode",
"rand",
]
[[package]]
name = "fastrand"
version = "1.9.0"
@ -1040,9 +1075,14 @@ version = "0.2.0"
dependencies = [
"axum",
"chrono",
"claims",
"config",
"fake",
"hyper 1.1.0",
"once_cell",
"quickcheck",
"quickcheck_macros",
"rand",
"reqwest",
"secrecy",
"serde",
@ -1059,6 +1099,7 @@ dependencies = [
"tracing-subscriber",
"unicode-segmentation",
"uuid",
"validator",
]
[[package]]
@ -1538,6 +1579,28 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quickcheck"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
dependencies = [
"env_logger",
"log",
"rand",
]
[[package]]
name = "quickcheck_macros"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b22a693222d716a9587786f37ac3f6b4faedb5b80c23914e7303ff5a1d8016e9"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "quote"
version = "1.0.29"
@ -2723,6 +2786,21 @@ dependencies = [
"rand",
]
[[package]]
name = "validator"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd"
dependencies = [
"idna",
"lazy_static",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
]
[[package]]
name = "valuable"
version = "0.1.0"

View file

@ -36,6 +36,8 @@ tracing-log = "0.2"
secrecy = { version = "0.8", features = ["serde"] }
unicode-segmentation = "1"
strum_macros = "0.26"
validator = "0.16"
# async-trait = "0.1"
# strum_macros = "0.25"
@ -54,3 +56,8 @@ features = [
[dev-dependencies]
reqwest = "0.11"
once_cell = "1"
claims = "0.7"
fake = "2.9.2"
quickcheck = "1.0.3"
quickcheck_macros = "1.0.0"
rand = "0.8.5"

7
src/domain/mod.rs Normal file
View file

@ -0,0 +1,7 @@
mod new_subscriber;
mod subscriber_email;
mod subscriber_name;
pub use new_subscriber::NewSubscriber;
pub use subscriber_email::SubscriberEmail;
pub use subscriber_name::SubscriberName;

View file

@ -0,0 +1,7 @@
use crate::domain::subscriber_email::SubscriberEmail;
use crate::domain::subscriber_name::SubscriberName;
/// The validated details of a subscriber that is about to be stored.
///
/// Both fields are domain types, so a `NewSubscriber` can only be assembled
/// from values of those types rather than from raw strings.
///
/// `Debug` is derived to match `SubscriberEmail` and `SubscriberName`, which
/// both derive it.
#[derive(Debug)]
pub struct NewSubscriber {
    /// The subscriber's email address.
    pub email: SubscriberEmail,
    /// The subscriber's name.
    pub name: SubscriberName,
}

View file

@ -0,0 +1,76 @@
use validator::validate_email;
/// An email address that has been accepted by `validate_email`.
///
/// The wrapped `String` is a private field; `parse` is the constructor
/// defined for this type.
#[derive(Debug)]
pub struct SubscriberEmail(String);

impl SubscriberEmail {
    /// Wraps `s` in a `SubscriberEmail` when `validate_email` accepts it.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input when `validate_email`
    /// returns `false` for `s`.
    pub fn parse(s: String) -> Result<SubscriberEmail, String> {
        // Reject first so the success path is the plain tail expression.
        if !validate_email(&s) {
            return Err(format!("{} is not a valid subscriber email.", s));
        }
        Ok(Self(s))
    }
}
/// Read-only access to the wrapped address as a string slice.
impl AsRef<str> for SubscriberEmail {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::SubscriberEmail;
    use claims::assert_err;
    use fake::faker::internet::en::SafeEmail;
    use fake::Fake;
    use rand::{rngs::StdRng, SeedableRng};

    /// Wrapper that lets quickcheck produce email strings generated by `fake`.
    #[derive(Debug, Clone)]
    struct ValidEmailFixture(pub String);

    impl quickcheck::Arbitrary for ValidEmailFixture {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            // Seed a `StdRng` from a quickcheck-generated `u64` and let `fake`
            // draw the address from that generator.
            let mut rng = StdRng::seed_from_u64(u64::arbitrary(g));
            let email = SafeEmail().fake_with_rng(&mut rng);
            Self(email)
        }
    }

    /// Property test: every address produced by the fixture must parse.
    // No `dbg!` here: it printed every generated address on each run.
    #[quickcheck_macros::quickcheck]
    fn valid_emails_are_parsed_successfully(valid_email: ValidEmailFixture) -> bool {
        SubscriberEmail::parse(valid_email.0).is_ok()
    }

    #[test]
    fn empty_string_is_rejected() {
        let email = "".to_string();
        assert_err!(SubscriberEmail::parse(email));
    }

    #[test]
    fn email_missing_at_symbol_is_rejected() {
        let email = "ursuladomain.com".to_string();
        assert_err!(SubscriberEmail::parse(email));
    }

    #[test]
    fn email_missing_subject_is_rejected() {
        let email = "@domain.com".to_string();
        assert_err!(SubscriberEmail::parse(email));
    }

    #[test]
    fn email_sharp_s_prefix_is_rejected() {
        let email = "joe.ßchmidt@example.com".to_string();
        assert_err!(SubscriberEmail::parse(email));
    }

    #[test]
    fn email_with_invalid_characters_is_rejected() {
        let email = "joe.smith@ex ample.com".to_string();
        assert_err!(SubscriberEmail::parse(email));
    }
}

View file

@ -0,0 +1,70 @@
use unicode_segmentation::UnicodeSegmentation;
/// A subscriber name that has passed the checks in `parse`.
///
/// The wrapped `String` is a private field; `parse` is the constructor
/// defined for this type.
#[derive(Debug)]
pub struct SubscriberName(String);

impl SubscriberName {
    /// Validates `s` and, on success, wraps it in a `SubscriberName`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input when `s` is empty once
    /// trimmed, is longer than 256 graphemes, or contains any of
    /// `/ ( ) " < > \ { }`.
    pub fn parse(s: String) -> Result<SubscriberName, String> {
        // Nothing left after trimming means the name is blank.
        let blank = s.trim().is_empty();
        // Length is measured in graphemes, not bytes or `char`s.
        let over_limit = s.graphemes(true).count() > 256;
        let has_forbidden = s
            .chars()
            .any(|c| matches!(c, '/' | '(' | ')' | '"' | '<' | '>' | '\\' | '{' | '}'));

        match (blank, over_limit, has_forbidden) {
            (false, false, false) => Ok(Self(s)),
            _ => Err(format!("{} is not a valid subscriber name.", s)),
        }
    }
}
/// Read-only access to the wrapped name as a string slice.
impl AsRef<str> for SubscriberName {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
#[cfg(test)]
mod tests {
    use crate::domain::SubscriberName;
    use claims::{assert_err, assert_ok};

    // Boundary: exactly 256 graphemes is still accepted.
    #[test]
    fn a_256_grapheme_long_name_is_valid() {
        assert_ok!(SubscriberName::parse("ё".repeat(256)));
    }

    // Boundary: one grapheme over the limit is refused.
    #[test]
    fn a_name_longer_than_256_graphemes_is_rejected() {
        assert_err!(SubscriberName::parse("a".repeat(257)));
    }

    #[test]
    fn whitespace_only_names_are_rejected() {
        assert_err!(SubscriberName::parse(String::from(" ")));
    }

    #[test]
    fn empty_string_is_rejected() {
        assert_err!(SubscriberName::parse(String::new()));
    }

    // Each forbidden character is tried on its own as the whole name.
    #[test]
    fn names_containing_an_invalid_character_are_rejected() {
        let forbidden = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
        for ch in forbidden.iter() {
            assert_err!(SubscriberName::parse(ch.to_string()));
        }
    }

    #[test]
    fn a_valid_name_is_parsed_successfully() {
        assert_ok!(SubscriberName::parse(String::from("Ursula Le Guin")));
    }
}

View file

@ -1,78 +0,0 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
/// Result alias whose error type is this module's `Error`.
pub type Result<T> = core::result::Result<T, Error>;
/// Application error type.
///
/// The `serde` attribute below selects the adjacently tagged representation:
/// the variant name is serialized under `"type"` and any variant payload
/// under `"data"`. `AsRefStr` is derived via `strum_macros`.
///
/// `client_status_and_error` maps each variant to an HTTP status and a
/// `ClientError`.
#[derive(Clone, Debug, Serialize, strum_macros::AsRefStr)]
#[serde(tag = "type", content = "data")]
pub enum Error {
// -- Login errors (mapped to UNAUTHORIZED / LOGIN_FAIL).
LoginFail,
// -- Auth errors.
// All three are mapped to FORBIDDEN / NO_AUTH.
AuthFailNoAuthTokenCookie,
AuthFailTokenWrongFormat,
AuthFailCtxNotInRequestExt,
// -- Model errors.
// FIXME: Delete this:
// Carries the `id` involved; mapped to BAD_REQUEST / INVALID_PARAMS.
PropertyDeleteFailIdNotFound { id: u64 },
// -- Service errors.
// Mapped to BAD_REQUEST / INVALID_PARAMS.
SubscriptionFailInvalidName,
}
impl Error {
    /// Maps this error to the HTTP status code and the `ClientError`
    /// reported for it.
    pub fn client_status_and_error(&self) -> (StatusCode, ClientError) {
        match self {
            // -- Login.
            Self::LoginFail => (StatusCode::UNAUTHORIZED, ClientError::LOGIN_FAIL),

            // -- Auth.
            Self::AuthFailNoAuthTokenCookie
            | Self::AuthFailTokenWrongFormat
            | Self::AuthFailCtxNotInRequestExt => (StatusCode::FORBIDDEN, ClientError::NO_AUTH),

            // -- Model (FIXME: Delete this variant) and subscription: both are
            // reported as a bad request with invalid parameters.
            Self::PropertyDeleteFailIdNotFound { .. } | Self::SubscriptionFailInvalidName => {
                (StatusCode::BAD_REQUEST, ClientError::INVALID_PARAMS)
            }

            // -- Fallback.
            // Every variant is matched above, so this arm is currently
            // unreachable; the lint is silenced to keep it as a safety net.
            #[allow(unreachable_patterns)]
            _ => (
                StatusCode::INTERNAL_SERVER_ERROR,
                ClientError::SERVICE_ERROR,
            ),
        }
    }
}
impl IntoResponse for Error {
    /// Builds a 500 placeholder response and stores this `Error` in the
    /// response extensions.
    fn into_response(self) -> Response {
        // TODO: trace something here maybe.

        // Start from a bare INTERNAL_SERVER_ERROR response...
        let mut placeholder = StatusCode::INTERNAL_SERVER_ERROR.into_response();
        // ...and attach the error itself to its extensions.
        placeholder.extensions_mut().insert(self);
        placeholder
    }
}
/// Error categories returned by `Error::client_status_and_error` alongside
/// the HTTP status code.
///
/// `AsRefStr` is derived via `strum_macros`.
#[derive(Debug, strum_macros::AsRefStr)]
// The variants are written in SCREAMING_SNAKE_CASE, hence the lint opt-out.
#[allow(non_camel_case_types)]
pub enum ClientError {
// Paired with UNAUTHORIZED.
LOGIN_FAIL,
// Paired with FORBIDDEN.
NO_AUTH,
// Paired with BAD_REQUEST.
INVALID_PARAMS,
// Paired with INTERNAL_SERVER_ERROR (fallback arm).
SERVICE_ERROR,
}

View file

@ -2,10 +2,8 @@
//!
//! This is an example documentation string for the root of the crate.
pub use self::error::{Error, Result};
pub mod configuration;
pub mod error;
pub mod domain;
pub mod routes;
pub mod startup;
pub mod telemetry;

View file

@ -1,4 +1,5 @@
use crate::Result;
use crate::domain::SubscriberEmail;
use crate::domain::{NewSubscriber, SubscriberName};
use axum::extract::State;
use axum::routing::post;
use axum::Form;
@ -10,7 +11,6 @@ use axum::{
use chrono::Utc;
use serde::Deserialize;
use sqlx::PgPool;
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
@ -19,32 +19,6 @@ struct FormData {
name: String,
}
/// Checks a subscriber name against our validation constraints.
///
/// Returns `true` only when `s` is not blank, is at most 256 graphemes long
/// and contains none of the forbidden characters; `false` otherwise.
pub fn is_valid_name(s: &str) -> bool {
    // A name is blank when nothing is left once the surrounding whitespace
    // has been trimmed away.
    if s.trim().is_empty() {
        return false;
    }

    // Length is counted in graphemes, the Unicode notion of a
    // "user-perceived" character: `å` is one grapheme even when it is made of
    // two characters (`a` and `̊`). Passing `true` selects the extended
    // grapheme definition set, the recommended one.
    if s.graphemes(true).count() > 256 {
        return false;
    }

    // Valid only if no character of `s` is one of the forbidden ones.
    !s.chars()
        .any(|c| matches!(c, '/' | '(' | ')' | '"' | '<' | '>' | '\\' | '{' | '}'))
}
#[tracing::instrument(
name = "Adding a new subscriber",
skip(form, pool),
@ -55,17 +29,24 @@ pub fn is_valid_name(s: &str) -> bool {
)
)]
pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -> Response {
if !is_valid_name(&form.name) {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
match insert_subscriber(&pool, &form).await {
let name = match SubscriberName::parse(form.name) {
Ok(name) => name,
Err(_) => {
return (StatusCode::BAD_REQUEST, "Invalid name").into_response();
}
};
let email = match SubscriberEmail::parse(form.email) {
Ok(email) => email,
Err(_) => {
return (StatusCode::BAD_REQUEST, "Invalid email address").into_response();
}
};
let new_subscriber = NewSubscriber { email, name };
match insert_subscriber(&pool, &new_subscriber).await {
Ok(_) => {
tracing::info!("Subscriber added to the database");
return (StatusCode::OK,).into_response();
}
Err(_) => {
tracing::error!("Failed to add subscriber to the database");
return (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response();
}
}
@ -73,17 +54,20 @@ pub async fn subscribe(State(pool): State<PgPool>, Form(form): Form<FormData>) -
#[tracing::instrument(
name = "Saving new subscriber details in the database",
skip(form, pool)
skip(new_subscriber, pool)
)]
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<()> {
pub async fn insert_subscriber(
pool: &PgPool,
new_subscriber: &NewSubscriber,
) -> Result<(), sqlx::Error> {
let _ = sqlx::query!(
r#"
INSERT INTO subscriptions (id, email, name, subscribed_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::new_v4(),
form.email,
form.name,
new_subscriber.email.as_ref(),
new_subscriber.name.as_ref(),
Utc::now()
)
// We use `get_ref` to get an immutable reference to the `PgConnection`

View file

@ -105,7 +105,7 @@ async fn subscribe_returns_a_200_for_valid_form_data() {
}
#[tokio::test]
async fn subscribe_returns_a_200_when_fields_are_present_but_empty() {
async fn subscribe_returns_a_400_when_fields_are_present_but_invalid() {
// Arrange
let app = spawn_app().await;
let client = reqwest::Client::new();
@ -127,7 +127,7 @@ async fn subscribe_returns_a_200_when_fields_are_present_but_empty() {
// Assert
assert_eq!(
200,
400,
response.status().as_u16(),
"The API did not return a 200 OK when the payload was {}.",
description