From cd9df8c209984b5c010ad51e03a750769a0090c9 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 16 Dec 2024 10:58:32 +0000 Subject: [PATCH] Add authorization via central OPA instance (#1) Authorisation is still optional when starting the server but can be configured by providing an address for the OPA instance to use as well as a query for each of visit access (used to a limit access to incrementing scan numbers) and admin access (used to limit access to configurations). When enabled, authorisation requires users to provide a bearer token via an `Authorization=Bearer ` header in their requests to the service. --- Cargo.toml | 5 + Dockerfile | 4 +- src/cli.rs | 89 ++++++++++ src/graphql.rs | 43 ++++- src/graphql/auth.rs | 395 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 530 insertions(+), 6 deletions(-) create mode 100644 src/graphql/auth.rs diff --git a/Cargo.toml b/Cargo.toml index 8dde146..5eac625 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ unwrap_used = "deny" async-graphql = { version = "7.0.13", features = ["tracing"] } async-graphql-axum = "7.0.13" axum = "0.7.9" +axum-extra = { version = "0.9.3", features = ["typed-header"] } chrono = "0.4.39" clap = { version = "4.5.23", features = ["cargo", "derive", "env"] } futures = "0.3.31" @@ -19,6 +20,8 @@ opentelemetry-otlp = "0.27.0" opentelemetry-semantic-conventions = "0.27.0" opentelemetry-stdout = "0.27.0" opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio"] } +reqwest = { version = "0.12.7", features = ["json", "rustls-tls-native-roots"], default-features = false } +serde = { version = "1.0.210", features = ["derive"] } sqlx = { version = "0.8.2", features = ["runtime-tokio", "sqlite"] } tokio = { version = "1.42.0", features = ["full"] } tracing = "0.1.41" @@ -27,6 +30,8 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } url = "2.5.4" [dev-dependencies] +assert_matches = "1.5.0" async-std = { version = "1.13.0", features = ["attributes"], default-features = false } +httpmock = { version = "0.7.0", default-features = false } rstest = "0.23.0" tempfile = "3.14.0" diff --git a/Dockerfile b/Dockerfile index 7111b28..c940537 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ FROM rust:1.82.0-slim AS build RUN rustup target add x86_64-unknown-linux-musl && \ - apt update && \ - apt install -y musl-tools musl-dev && \ + apt-get update && \ + apt-get install -y musl-tools musl-dev && \ update-ca-certificates COPY ./Cargo.toml ./Cargo.toml diff --git a/src/cli.rs b/src/cli.rs index c862194..8809a20 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -61,6 +61,28 @@ pub struct ServeOptions { /// The root directory for external number tracking #[clap(long, env = "NUMTRACKER_ROOT_DIRECTORY")] root_directory: Option, + #[clap(flatten, next_help_heading = "Authorization")] + pub policy: Option, +} + +#[derive(Debug, Default, Parser)] +#[group(requires = "policy_host")] +pub struct PolicyOptions { + /// Beamline Policy Endpoint + /// + /// eg, https://authz.diamond.ac.uk + #[clap(long = "policy", required = false)] + pub policy_host: String, + /// The Rego rule used to generate visit access data + /// + /// eg. v1/data/diamond/policy/session/write_to_beamline_visit + #[clap(long, required = false)] + pub access_query: String, + /// The Rego rule used to generate admin access data + /// + /// eg. v1/data/diamond/policy/admin/configure_beamline + #[clap(long, required = false)] + pub admin_query: String, } #[derive(Debug, Args)] @@ -132,6 +154,7 @@ impl TracingOptions { mod tests { use std::path::PathBuf; + use assert_matches::assert_matches; use clap::error::ErrorKind; use clap::Parser; use tracing::Level; @@ -154,6 +177,8 @@ mod tests { }; assert_eq!(cmd.addr(), ("0.0.0.0".parse().unwrap(), 8000)); assert_eq!(cmd.root_directory(), None); + + assert_matches!(cmd.policy, None); } #[test] @@ -174,6 +199,70 @@ mod tests { }; assert_eq!(cmd.addr(), ("127.0.0.1".parse().unwrap(), 8765)); assert_eq!(cmd.root_directory, Some("/tmp/trackers".into())); + assert_matches!(cmd.policy, None); + } + + #[test] + fn policy_arguments() { + let cli = Cli::try_parse_from([ + APP, + "serve", + "--policy", + "opa.example.com", + "--admin-query", + "demo/admin_check", + "--access-query", + "demo/access_check", + ]) + .unwrap(); + let cmd = assert_matches!(cli.command, Command::Serve(cmd) => cmd); + let policy = assert_matches!(cmd.policy, Some(plc) => plc); + + assert_eq!(policy.policy_host, "opa.example.com"); + assert_eq!(policy.admin_query, "demo/admin_check"); + assert_eq!(policy.access_query, "demo/access_check"); + } + + #[test] + fn missing_admin_query() { + let err = Cli::try_parse_from([ + APP, + "serve", + "--policy", + "opa.example.com", + "--access-query", + "demo/access-query", + ]) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument); + } + + #[test] + fn missing_access_query() { + let err = Cli::try_parse_from([ + APP, + "serve", + "--policy", + "opa.example.com", + "--admin-query", + "demo/admin-query", + ]) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument); + } + + #[test] + fn policy_queries_without_host() { + let err = Cli::try_parse_from([ + APP, + "serve", + "--access-query", + "demo/access-query", + "--admin-query", + "demo/admin-query", + ]) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument); } #[test] diff --git a/src/graphql.rs b/src/graphql.rs index 01bdda1..0a8e5f0 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -16,6 +16,7 @@ use std::any; use std::borrow::Cow; use std::error::Error; use std::fmt::Display; +use std::future::Future; use std::path::{Component, Path, PathBuf}; use std::sync::Arc; @@ -27,9 +28,13 @@ use async_graphql::{ Scalar, ScalarType, Schema, SimpleObject, Value, }; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; +use auth::{AuthError, PolicyCheck}; use axum::response::{Html, IntoResponse}; use axum::routing::{get, post}; use axum::{Extension, Router}; +use axum_extra::headers::authorization::Bearer; +use axum_extra::headers::Authorization; +use axum_extra::TypedHeader; use chrono::{Datelike, Local}; use tokio::net::TcpListener; use tracing::{info, instrument, trace, warn}; @@ -45,6 +50,8 @@ use crate::paths::{ }; use crate::template::{FieldSource, PathTemplate}; +mod auth; + pub async fn serve_graphql(db: &Path, opts: ServeOptions) { let db = SqliteScanPathService::connect(db) .await @@ -52,19 +59,21 @@ pub async fn serve_graphql(db: &Path, opts: ServeOptions) { let directory_numtracker = NumTracker::for_root_directory(opts.root_directory()) .expect("Could not read external directories"); info!("Serving graphql endpoints on {:?}", opts.addr()); + let addr = opts.addr(); let schema = Schema::build(Query, Mutation, EmptySubscription) .extension(Tracing) .limit_directives(32) .data(db) .data(directory_numtracker) + .data(opts.policy.map(PolicyCheck::new)) .finish(); let app = Router::new() .route("/graphql", post(graphql_handler)) .route("/graphiql", get(graphiql)) .layer(Extension(schema)); - let listener = TcpListener::bind(opts.addr()) + let listener = TcpListener::bind(addr) .await - .unwrap_or_else(|_| panic!("Port {:?} in use", opts.addr())); + .unwrap_or_else(|_| panic!("Port {:?} in use", addr)); axum::serve(listener, app) .await .expect("Can't serve graphql endpoint"); @@ -82,10 +91,13 @@ async fn graphiql() -> impl IntoResponse { #[instrument(skip_all)] async fn graphql_handler( schema: Extension>, + auth_token: Option>>, req: GraphQLRequest, ) -> GraphQLResponse { - let inner = req.into_inner(); - schema.execute(inner).await.into() + schema + .execute(req.into_inner().data(auth_token.map(|header| header.0))) + .await + .into() } /// Read-only API for GraphQL @@ -263,6 +275,7 @@ impl Query { ctx: &Context<'_>, beamline: String, ) -> async_graphql::Result { + check_auth(ctx, |policy, token| policy.check_admin(token, &beamline)).await?; let db = ctx.data::()?; trace!("Getting config for {beamline:?}"); Ok(db.current_configuration(&beamline).await?) @@ -280,6 +293,10 @@ impl Mutation { visit: String, sub: Option, ) -> async_graphql::Result { + check_auth(ctx, |policy, token| { + policy.check_access(token, &beamline, &visit) + }) + .await?; let db = ctx.data::()?; let nt = ctx.data::()?; // There is a race condition here if a process increments the file @@ -312,6 +329,7 @@ impl Mutation { beamline: String, config: ConfigurationUpdates, ) -> async_graphql::Result { + check_auth(ctx, |pc, token| pc.check_admin(token, &beamline)).await?; let db = ctx.data::()?; trace!("Configuring: {beamline}: {config:?}"); let upd = config.into_update(beamline); @@ -322,6 +340,23 @@ impl Mutation { } } +async fn check_auth<'ctx, Check, R>(ctx: &Context<'ctx>, check: Check) -> async_graphql::Result<()> +where + Check: Fn(&'ctx PolicyCheck, Option<&'ctx Authorization>) -> R, + R: Future>, +{ + if let Some(policy) = ctx.data::>()? { + trace!("Auth enabled: checking token"); + let token = ctx.data::>().ok(); + check(policy, token) + .await + .inspect_err(|e| info!("Authorization failed: {e:?}")) + .map_err(async_graphql::Error::from) + } else { + Ok(()) + } +} + #[derive(Debug, InputObject)] struct ConfigurationUpdates { visit: Option>, diff --git a/src/graphql/auth.rs b/src/graphql/auth.rs new file mode 100644 index 0000000..cc837e7 --- /dev/null +++ b/src/graphql/auth.rs @@ -0,0 +1,395 @@ +// Copyright 2024 Diamond Light Source +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; +use std::str::FromStr; + +use axum_extra::headers::authorization::Bearer; +use axum_extra::headers::Authorization; +use serde::{Deserialize, Serialize}; +use tracing::info; + +use crate::cli::PolicyOptions; + +const AUDIENCE: &str = "account"; + +type Token = Authorization; + +#[derive(Debug, Deserialize)] +#[cfg_attr(test, derive(Serialize))] +struct Response { + result: bool, +} + +#[derive(Debug, Serialize)] +#[cfg_attr(test, derive(Deserialize))] +pub struct AccessRequest<'a> { + token: &'a str, + audience: &'a str, + proposal: u32, + visit: u16, + beamline: &'a str, +} + +impl<'a> AccessRequest<'a> { + fn new(token: Option<&'a Token>, visit: Visit, beamline: &'a str) -> Result { + Ok(Self { + token: token.ok_or(AuthError::Missing)?.token(), + audience: AUDIENCE, + proposal: visit.proposal, + visit: visit.session, + beamline, + }) + } +} + +#[derive(Debug, Serialize)] +#[cfg_attr(test, derive(Deserialize))] +pub struct AdminRequest<'a> { + token: &'a str, + audience: &'a str, + beamline: &'a str, +} + +impl<'r> AdminRequest<'r> { + fn new(token: Option<&'r Token>, beamline: &'r str) -> Result { + Ok(Self { + token: token.ok_or(AuthError::Missing)?.token(), + audience: AUDIENCE, + beamline, + }) + } +} + +#[derive(Debug)] +struct InvalidVisit; + +#[cfg_attr(test, derive(Debug))] +struct Visit { + proposal: u32, + session: u16, +} +impl FromStr for Visit { + type Err = InvalidVisit; + + fn from_str(s: &str) -> Result { + let (code_prop, vis) = s.split_once('-').ok_or(InvalidVisit)?; + let prop = code_prop + .chars() + .skip_while(|p| !p.is_ascii_digit()) + .collect::(); + let proposal = prop.parse().map_err(|_| InvalidVisit)?; + let session = vis.parse().map_err(|_| InvalidVisit)?; + Ok(Self { proposal, session }) + } +} + +pub(crate) struct PolicyCheck { + client: reqwest::Client, + /// Rego query for getting admin rights + admin: String, + /// Rego query for getting access rights + access: String, +} + +impl PolicyCheck { + pub fn new(endpoint: PolicyOptions) -> Self { + info!( + "Checking authorization against {:?} using {:?} for admin and {:?} for access", + endpoint.policy_host, endpoint.admin_query, endpoint.access_query + ); + Self { + client: reqwest::Client::new(), + admin: format!("{}/{}", endpoint.policy_host, endpoint.admin_query), + access: format!("{}/{}", endpoint.policy_host, &endpoint.access_query), + } + } + pub async fn check_access( + &self, + token: Option<&Authorization>, + beamline: &str, + visit: &str, + ) -> Result<(), AuthError> { + let visit: Visit = visit.parse().map_err(|_| AuthError::Failed)?; + self.authorise(&self.access, AccessRequest::new(token, visit, beamline)?) + .await + } + + pub async fn check_admin( + &self, + token: Option<&Authorization>, + beamline: &str, + ) -> Result<(), AuthError> { + self.authorise(&self.admin, AdminRequest::new(token, beamline)?) + .await + } + + async fn authorise(&self, query: &str, input: impl Serialize) -> Result<(), AuthError> { + let response = self.client.post(query).json(&input).send().await?; + if response.json::().await?.result { + Ok(()) + } else { + Err(AuthError::Failed) + } + } +} + +#[derive(Debug)] +pub enum AuthError { + ServerError(reqwest::Error), + Failed, + Missing, +} + +impl Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthError::ServerError(_) => write!(f, "Invalid authorization configuration"), + AuthError::Failed => write!(f, "Authentication failed"), + AuthError::Missing => f.write_str("No authentication token was provided"), + } + } +} + +impl std::error::Error for AuthError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + AuthError::ServerError(e) => Some(e), + _ => None, + } + } +} + +impl From for AuthError { + fn from(value: reqwest::Error) -> Self { + Self::ServerError(value) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr as _; + + use assert_matches::assert_matches; + use axum::http::HeaderValue; + use axum_extra::headers::authorization::{Bearer, Credentials}; + use axum_extra::headers::Authorization; + use httpmock::MockServer; + use rstest::rstest; + + use super::{ + AccessRequest, AdminRequest, AuthError, InvalidVisit, PolicyCheck, Response, Visit, + AUDIENCE, + }; + use crate::cli::PolicyOptions; + + fn token(name: &'static str) -> Option> { + Some(Authorization( + Bearer::decode(&HeaderValue::from_str(&format!("Bearer {name}")).unwrap()).unwrap(), + )) + } + + #[test] + fn valid_visit() { + let visit = Visit::from_str("cm12345-1").unwrap(); + assert_eq!(visit.session, 1); + assert_eq!(visit.proposal, 12345); + } + + #[rstest] + #[case::no_proposal("cm-3")] + #[case::no_session("cm12345")] + #[case::invalid_session("cm12345-abc")] + #[case::invalid_proposal("cm123abc-12")] + #[case::negative_session("cm1234--12")] + fn invalid_visit(#[case] visit: &str) { + assert_matches!(Visit::from_str(visit), Err(InvalidVisit)) + } + + #[tokio::test] + async fn successful_access_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|when, then| { + when.method("POST") + .path("/demo/access") + .json_body_obj(&AccessRequest { + token: "token", + beamline: "i22", + visit: 4, + proposal: 1234, + audience: AUDIENCE, + }); + then.status(200).json_body_obj(&Response { result: true }); + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + check + .check_access(token("token").as_ref(), "i22", "cm1234-4") + .await + .unwrap(); + mock.assert(); + } + + #[tokio::test] + async fn successful_admin_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|when, then| { + when.method("POST") + .path("/demo/admin") + .json_body_obj(&AdminRequest { + token: "token", + beamline: "i22", + audience: AUDIENCE, + }); + then.status(200).json_body_obj(&Response { result: true }); + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + check + .check_admin(token("token").as_ref(), "i22") + .await + .unwrap(); + mock.assert(); + } + + #[tokio::test] + async fn denied_access_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|when, then| { + when.method("POST") + .path("/demo/access") + .json_body_obj(&AccessRequest { + token: "token", + beamline: "i22", + proposal: 1234, + visit: 4, + audience: AUDIENCE, + }); + then.status(200).json_body_obj(&Response { result: false }); + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + + let result = check + .check_access(token("token").as_ref(), "i22", "cm1234-4") + .await; + let Err(AuthError::Failed) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; + mock.assert(); + } + + #[tokio::test] + async fn denied_admin_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|when, then| { + when.method("POST") + .path("/demo/admin") + .json_body_obj(&AdminRequest { + token: "token", + beamline: "i22", + audience: AUDIENCE, + }); + then.status(200).json_body_obj(&Response { result: false }); + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + let result = check.check_admin(token("token").as_ref(), "i22").await; + let Err(AuthError::Failed) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; + mock.assert(); + } + + #[tokio::test] + async fn unauthorised_access_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|_, _| { + // mock that rejects every request + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + let result = check.check_access(None, "i22", "cm1234-4").await; + let Err(AuthError::Missing) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; + mock.assert_hits(0); + } + + #[tokio::test] + async fn unauthorised_admin_check() { + let server = MockServer::start(); + let mock = server + .mock_async(|_, _| { + // mock that rejects every request + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + let result = check.check_admin(None, "i22").await; + let Err(AuthError::Missing) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; + mock.assert_hits(0); + } + + #[tokio::test] + async fn server_error() { + let server = MockServer::start(); + let mock = server + .mock_async(|when, then| { + when.method("POST"); + then.status(503); + }) + .await; + let check = PolicyCheck::new(PolicyOptions { + policy_host: server.url(""), + access_query: "demo/access".into(), + admin_query: "demo/admin".into(), + }); + let result = check.check_admin(token("token").as_ref(), "i22").await; + let Err(AuthError::ServerError(_)) = result else { + panic!("Unexpected result from unauthorised check: {result:?}"); + }; + mock.assert(); + } +}