diff --git a/todored/src/config.rs b/todored/src/config.rs index e4615f2..fc9e2b3 100644 --- a/todored/src/config.rs +++ b/todored/src/config.rs @@ -4,3 +4,4 @@ use crate::proxy::ReverseProxyInfoList; required!(REDIS_CONN, String); required!(AXUM_HOST, String); // FIXME: Use SocketAddr when possible optional!(AXUM_XFORWARDED, ReverseProxyInfoList); +required!(TODORED_RATE_LIMIT_CONNECTIONS, usize); diff --git a/todored/src/proxy.rs b/todored/src/proxy.rs index cec4eec..eb8ca30 100644 --- a/todored/src/proxy.rs +++ b/todored/src/proxy.rs @@ -44,12 +44,13 @@ pub struct ExtractReverseProxy { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtractReverseProxyOption ( Option ); +pub struct ExtractReverseProxyOption ( pub Option ); #[async_trait] impl FromRequestParts for ExtractReverseProxyOption where S: Send + Sync { type Rejection = ResponseError; + // TODO: Pending a security audit, as in second thought this doesn't seem so secure... async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let proxy_list = config::AXUM_XFORWARDED.clone(); diff --git a/todored/src/routes/board/axum.rs b/todored/src/routes/board/axum.rs index f155eb7..5d3f1cd 100644 --- a/todored/src/routes/board/axum.rs +++ b/todored/src/routes/board/axum.rs @@ -1,17 +1,53 @@ -use axum::Extension; +use axum::{Extension}; use axum::extract::{Path, WebSocketUpgrade}; +use axum::http::StatusCode; +use axum::response::IntoResponse; use crate::kebab::Skewer; +use crate::proxy::ExtractReverseProxyOption; use super::ws; pub(crate) async fn handler( Path(board): Path, Extension(rclient): Extension, + ExtractReverseProxyOption(proxy_opt): ExtractReverseProxyOption, upgrade_request: WebSocketUpgrade, ) -> axum::response::Response { log::trace!("Kebabifying board name..."); let board = board.to_kebab_lowercase(); log::trace!("Kebabified board name to: {board:?}"); + log::trace!("Creating Redis connection for the handler..."); + let handle_redis = rclient.get_async_connection().await; + if handle_redis.is_err() { + log::error!("Could not open Redis connection for the handler."); + return Err::<(), StatusCode>(StatusCode::INTERNAL_SERVER_ERROR).into_response() + } + let mut handle_redis = handle_redis.unwrap(); + log::trace!("Created Redis connection for the main thread!"); + + let count = *crate::config::TODORED_RATE_LIMIT_CONNECTIONS; + log::trace!("TODORED_RATE_LIMIT_CONNECTIONS is {count}."); + if count > 0 { + if proxy_opt.is_none() { + log::error!("TODORED_RATE_LIMIT_CONNECTIONS is {count}, but a request has been received without the proxy headers!"); + return Err::<(), StatusCode>(StatusCode::BAD_GATEWAY).into_response(); + } + let proxy = proxy_opt.unwrap(); + log::trace!("Checking X-Forwarded-For header..."); + let ip = proxy.r#for.ip(); + log::trace!("User's IP is: {ip}"); + let key = format!("limit:{{{ip}}}:connections"); + log::trace!("Rate-limiting key is: {key:?}"); + + log::trace!("Running rate-limiting function..."); + let result = super::limit::rate_limit_by_key(&mut handle_redis, key, 1, count, 1).await; + + if result.is_err() { + log::warn!("User with IP {ip} hit connection rate limit!"); + return Err::<(), StatusCode>(StatusCode::BAD_REQUEST).into_response() + } + } + log::trace!("Received websocket request, upgrading..."); upgrade_request.on_upgrade(|websocket| ws::handler(board, rclient, websocket)) } diff --git a/todored/src/routes/board/limit.rs b/todored/src/routes/board/limit.rs index a71caa5..f0ce7a3 100644 --- a/todored/src/routes/board/limit.rs +++ b/todored/src/routes/board/limit.rs @@ -1,27 +1,28 @@ //! Rate limiting for board websocket. use axum::extract::ws::CloseCode; -use redis::Connection; use crate::outcome::LoggableOutcome; -pub fn rate_limit_by_key( - mut rconn: Connection, +pub async fn rate_limit_by_key( + rconn: &mut redis::aio::Connection, key: String, increment: usize, limit: usize, expiration_s: usize ) -> Result<(), CloseCode> { log::trace!("Incrementing rate limit counter for {key:?}..."); - let response: usize = rconn.cmd("INCRBY") + let response: usize = redis::cmd("INCRBY") .arg(&key) .arg(increment) + .query_async::(rconn).await .log_err_to_error("Could not increase rate limit counter") .map_err(|_| 1011u16)?; log::trace!("Refreshing rate limit counter expiration for {key:?}..."); - rconn.cmd("EXPIRE") + let _ = redis::cmd("EXPIRE") .arg(&key) .arg(expiration_s) + .query_async::(rconn).await .log_err_to_warn("Could not set expiration for rate limit counter"); log::trace!("Checking rate limit of {limit} / {expiration_s} s for {key:?}...");