Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve JWT RSA key initialization and avoid saving public key #4085

Merged
merged 1 commit into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
//
use chrono::{Duration, Utc};
use num_traits::FromPrimitive;
use once_cell::sync::Lazy;
use once_cell::sync::{Lazy, OnceCell};

use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
use openssl::rsa::Rsa;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;

Expand All @@ -26,23 +27,45 @@ static JWT_SEND_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|send", CONFIG.do
static JWT_ORG_API_KEY_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|api.organization", CONFIG.domain_origin()));
static JWT_FILE_DOWNLOAD_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|file_download", CONFIG.domain_origin()));

static PRIVATE_RSA_KEY: Lazy<EncodingKey> = Lazy::new(|| {
let key =
std::fs::read(CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key. \n{e}"));
EncodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{e}"))
});
static PUBLIC_RSA_KEY: Lazy<DecodingKey> = Lazy::new(|| {
let key = std::fs::read(CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key. \n{e}"));
DecodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{e}"))
});
static PRIVATE_RSA_KEY: OnceCell<EncodingKey> = OnceCell::new();
static PUBLIC_RSA_KEY: OnceCell<DecodingKey> = OnceCell::new();

pub fn load_keys() {
Lazy::force(&PRIVATE_RSA_KEY);
Lazy::force(&PUBLIC_RSA_KEY);
pub fn initialize_keys() -> Result<(), crate::error::Error> {
let mut priv_key_buffer = Vec::with_capacity(2048);

let priv_key = {
let mut priv_key_file = File::options().create(true).read(true).write(true).open(CONFIG.private_rsa_key())?;

#[allow(clippy::verbose_file_reads)]
let bytes_read = priv_key_file.read_to_end(&mut priv_key_buffer)?;

if bytes_read > 0 {
Rsa::private_key_from_pem(&priv_key_buffer[..bytes_read])?
} else {
// Only create the key if the file doesn't exist or is empty
let rsa_key = openssl::rsa::Rsa::generate(2048)?;
priv_key_buffer = rsa_key.private_key_to_pem()?;
priv_key_file.write_all(&priv_key_buffer)?;
info!("Private key created correctly.");
rsa_key
}
};

let pub_key_buffer = priv_key.public_key_to_pem()?;

let enc = EncodingKey::from_rsa_pem(&priv_key_buffer)?;
let dec: DecodingKey = DecodingKey::from_rsa_pem(&pub_key_buffer)?;
if PRIVATE_RSA_KEY.set(enc).is_err() {
err!("PRIVATE_RSA_KEY must only be initialized once")
}
if PUBLIC_RSA_KEY.set(dec).is_err() {
err!("PUBLIC_RSA_KEY must only be initialized once")
}
Ok(())
}

pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
match jsonwebtoken::encode(&JWT_HEADER, claims, PRIVATE_RSA_KEY.wait()) {
Ok(token) => token,
Err(e) => panic!("Error encoding jwt {e}"),
}
Expand All @@ -56,7 +79,7 @@ fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Err
validation.set_issuer(&[issuer]);

let token = token.replace(char::is_whitespace, "");
match jsonwebtoken::decode(&token, &PUBLIC_RSA_KEY, &validation) {
match jsonwebtoken::decode(&token, PUBLIC_RSA_KEY.wait(), &validation) {
Ok(d) => Ok(d.claims),
Err(err) => match *err.kind() {
ErrorKind::InvalidToken => err!("Token is invalid"),
Expand Down Expand Up @@ -799,7 +822,11 @@ impl<'r> FromRequest<'r> for OwnerHeaders {
//
// Client IP address detection
//
use std::net::IpAddr;
use std::{
fs::File,
io::{Read, Write},
net::IpAddr,
};

pub struct ClientIp {
pub ip: IpAddr,
Expand Down
5 changes: 1 addition & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ impl Config {
}

pub fn delete_user_config(&self) -> Result<(), Error> {
crate::util::delete_file(&CONFIG_FILE)?;
std::fs::remove_file(&*CONFIG_FILE)?;

// Empty user config
let usr = ConfigBuilder::default();
Expand All @@ -1189,9 +1189,6 @@ impl Config {
pub fn private_rsa_key(&self) -> String {
format!("{}.pem", CONFIG.rsa_key_filename())
}
pub fn public_rsa_key(&self) -> String {
format!("{}.pub.pem", CONFIG.rsa_key_filename())
}
pub fn mail_enabled(&self) -> bool {
let inner = &self.inner.read().unwrap().config;
inner._enable_smtp && (inner.smtp_host.is_some() || inner.use_sendmail)
Expand Down
2 changes: 1 addition & 1 deletion src/db/models/attachment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl Attachment {

let file_path = &self.get_file_path();

match crate::util::delete_file(file_path) {
match std::fs::remove_file(file_path) {
// Ignore "file not found" errors. This can happen when the
// upstream caller has already cleaned up the file as part of
// its own error handling.
Expand Down
27 changes: 1 addition & 26 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async fn main() -> Result<(), Error> {
let extra_debug = matches!(level, LF::Trace | LF::Debug);

check_data_folder().await;
check_rsa_keys().unwrap_or_else(|_| {
auth::initialize_keys().unwrap_or_else(|_| {
error!("Error creating keys, exiting...");
exit(1);
});
Expand Down Expand Up @@ -444,31 +444,6 @@ async fn container_data_folder_is_persistent(data_folder: &str) -> bool {
true
}

fn check_rsa_keys() -> Result<(), crate::error::Error> {
// If the RSA keys don't exist, try to create them
let priv_path = CONFIG.private_rsa_key();
let pub_path = CONFIG.public_rsa_key();

if !util::file_exists(&priv_path) {
let rsa_key = openssl::rsa::Rsa::generate(2048)?;

let priv_key = rsa_key.private_key_to_pem()?;
crate::util::write_file(&priv_path, &priv_key)?;
info!("Private key created correctly.");
}

if !util::file_exists(&pub_path) {
let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&std::fs::read(&priv_path)?)?;

let pub_key = rsa_key.public_key_to_pem()?;
crate::util::write_file(&pub_path, &pub_key)?;
info!("Public key created correctly.");
}

auth::load_keys();
Ok(())
}

fn check_web_vault() {
if !CONFIG.web_vault_enabled() {
return;
Expand Down
42 changes: 2 additions & 40 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
//
// Web Headers and caching
//
use std::{
collections::HashMap,
io::{Cursor, ErrorKind},
ops::Deref,
};
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};

use num_traits::ToPrimitive;
use rocket::{
Expand Down Expand Up @@ -334,40 +330,6 @@ impl Fairing for BetterLogging {
}
}

//
// File handling
//
use std::{
fs::{self, File},
io::Result as IOResult,
path::Path,
};

pub fn file_exists(path: &str) -> bool {
Path::new(path).exists()
}

pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
use std::io::Write;
let mut f = match File::create(path) {
Ok(file) => file,
Err(e) => {
if e.kind() == ErrorKind::PermissionDenied {
error!("Can't create '{}': Permission denied", path);
}
return Err(From::from(e));
}
};

f.write_all(content)?;
f.flush()?;
Ok(())
}

pub fn delete_file(path: &str) -> IOResult<()> {
fs::remove_file(path)
}

pub fn get_display_size(size: i64) -> String {
const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"];

Expand Down Expand Up @@ -444,7 +406,7 @@ pub fn get_env_str_value(key: &str) -> Option<String> {
match (value_from_env, value_file) {
(Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"),
(Ok(v_env), Err(_)) => Some(v_env),
(Err(_), Ok(v_file)) => match fs::read_to_string(v_file) {
(Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) {
Ok(content) => Some(content.trim().to_string()),
Err(e) => panic!("Failed to load {key}: {e:?}"),
},
Expand Down