Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit websocket joins. #2165

Merged
merged 4 commits into from
Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ doku = "0.10.2"
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
opentelemetry-otlp = "0.9"
tracing-opentelemetry = "0.16"
parking_lot = "0.12"
1 change: 1 addition & 0 deletions crates/apub/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ background-jobs = "0.11.0"
reqwest = { version = "0.11.7", features = ["json"] }
html2md = "0.2.13"
once_cell = "1.8.0"
parking_lot = "0.12"

[dev-dependencies]
serial_test = "0.5.1"
Expand Down
2 changes: 1 addition & 1 deletion crates/apub/src/objects/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ pub(crate) mod tests {
LemmyError,
};
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
use parking_lot::Mutex;
use reqwest::Client;
use reqwest_middleware::ClientBuilder;
use std::sync::Arc;
use tokio::sync::Mutex;

// TODO: would be nice if we didnt have to use a full context for tests.
// or at least write a helper function so this code is shared with main.rs
Expand Down
1 change: 1 addition & 0 deletions crates/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ uuid = { version = "0.8.2", features = ["serde", "v4"] }
encoding = "0.2.33"
html2text = "0.2.1"
rosetta-i18n = "0.1"
parking_lot = "0.12"

[build-dependencies]
rosetta-build = "0.1"
10 changes: 5 additions & 5 deletions crates/utils/src/rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use actix_web::{
HttpResponse,
};
use futures::future::{ok, Ready};
use parking_lot::Mutex;
use rate_limiter::{RateLimitType, RateLimiter};
use std::{
future::Future,
Expand All @@ -12,7 +13,6 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::Mutex;

pub mod rate_limiter;

Expand Down Expand Up @@ -68,20 +68,20 @@ impl RateLimit {

impl RateLimited {
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
pub async fn check(self, ip_addr: IpAddr) -> bool {
pub fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings never held across await points,
// and the operation here locks only long enough to clone
let rate_limit = self.rate_limit_config;

let mut limiter = self.rate_limiter.lock().await;

let (kind, interval) = match self.type_ {
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
};
let mut limiter = self.rate_limiter.lock();
asonix marked this conversation as resolved.
Show resolved Hide resolved

limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
}
}
Expand Down Expand Up @@ -127,7 +127,7 @@ where
let service = self.service.clone();

Box::pin(async move {
if rate_limited.check(ip_addr).await {
if rate_limited.check(ip_addr) {
service.call(req).await
} else {
let (http_req, _) = req.into_parts();
Expand Down
1 change: 1 addition & 0 deletions crates/websocket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ actix-web = { version = "4.0.0", default-features = false, features = ["rustls"]
actix-web-actors = { version = "4.1.0", default-features = false }
opentelemetry = "0.16"
tracing-opentelemetry = "0.16"
parking_lot = "0.12"
14 changes: 7 additions & 7 deletions crates/websocket/src/chat_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,19 +481,19 @@ impl ChatServer {
// check if api call passes the rate limit, and generate future for later execution
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
let passed = match user_operation_crud {
UserOperationCrud::Register => rate_limiter.register().check(ip).await,
UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
_ => rate_limiter.message().check(ip).await,
UserOperationCrud::Register => rate_limiter.register().check(ip),
UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
_ => rate_limiter.message().check(ip),
};
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
(passed, fut)
} else {
let user_operation = UserOperation::from_str(op)?;
let passed = match user_operation {
UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
_ => rate_limiter.message().check(ip).await,
UserOperation::GetCaptcha => rate_limiter.post().check(ip),
_ => rate_limiter.message().check(ip),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that we are checking the message rate limit twice in a row. Should be removed (also in line 488), so that only specific rate limits (register, post etc) are checked here, in addition to message rate limit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a new PR to address this.

};
let fut = (message_handler)(context, msg.id, user_operation, data);
(passed, fut)
Expand Down
25 changes: 24 additions & 1 deletion crates/websocket/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
use actix::prelude::*;
use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws;
use lemmy_utils::{utils::get_ip, ConnectionId, IpAddr};
use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr};
use std::time::{Duration, Instant};
use tracing::{debug, error, info};

Expand All @@ -20,13 +20,15 @@ pub async fn chat_route(
req: HttpRequest,
stream: web::Payload,
context: web::Data<LemmyContext>,
rate_limiter: web::Data<RateLimit>,
) -> Result<HttpResponse, Error> {
ws::start(
WsSession {
cs_addr: context.chat_server().to_owned(),
id: 0,
hb: Instant::now(),
ip: get_ip(&req.connection_info()),
rate_limiter: rate_limiter.as_ref().to_owned(),
},
&req,
stream,
Expand All @@ -41,6 +43,8 @@ struct WsSession {
/// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
/// otherwise we drop connection.
hb: Instant,
/// A rate limiter for websocket joins
rate_limiter: RateLimit,
}

impl Actor for WsSession {
Expand All @@ -57,6 +61,11 @@ impl Actor for WsSession {
// before processing any other events.
// across all routes within application
let addr = ctx.address();

if !self.rate_limit_check(ctx) {
return;
}

self
.cs_addr
.send(Connect {
Expand Down Expand Up @@ -98,6 +107,10 @@ impl Handler<WsMessage> for WsSession {
/// WebSocket message handler
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
if !self.rate_limit_check(ctx) {
return;
}

let message = match result {
Ok(m) => m,
Err(e) => {
Expand Down Expand Up @@ -169,4 +182,14 @@ impl WsSession {
ctx.ping(b"");
});
}

/// Check the rate limit, and stop the ctx if it fails
fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
let check = self.rate_limiter.message().check(self.ip.to_owned());
if !check {
debug!("Websocket join with IP: {} has been rate limited.", self.ip);
ctx.stop()
}
check
}
}
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ use lemmy_utils::{
REQWEST_TIMEOUT,
};
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
use parking_lot::Mutex;
use reqwest::Client;
use reqwest_middleware::ClientBuilder;
use reqwest_tracing::TracingMiddleware;
use std::{env, sync::Arc, thread};
use tokio::sync::Mutex;
use tracing_actix_web::TracingLogger;

embed_migrations!();
Expand Down Expand Up @@ -136,6 +136,7 @@ async fn main() -> Result<(), LemmyError> {
.wrap(actix_web::middleware::Logger::default())
.wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
.app_data(Data::new(context))
.app_data(Data::new(rate_limiter.clone()))
// The routes
.configure(|cfg| api_routes::config(cfg, &rate_limiter))
.configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))
Expand Down