Skip to content

Commit

Permalink
Add authorization via central OPA instance (#1)
Browse files Browse the repository at this point in the history
Authorisation is still optional when starting the server but can be
configured by providing an address for the OPA instance to use as well
as a query for each of visit access (used to a limit access to
incrementing scan numbers) and admin access (used to limit access to
configurations).

When enabled, authorisation requires users to provide a bearer token via
an `Authorization=Bearer <token>` header in their requests to the
service.
  • Loading branch information
tpoliaw authored Dec 16, 2024
1 parent b1ce675 commit cd9df8c
Show file tree
Hide file tree
Showing 5 changed files with 530 additions and 6 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ unwrap_used = "deny"
async-graphql = { version = "7.0.13", features = ["tracing"] }
async-graphql-axum = "7.0.13"
axum = "0.7.9"
axum-extra = { version = "0.9.3", features = ["typed-header"] }
chrono = "0.4.39"
clap = { version = "4.5.23", features = ["cargo", "derive", "env"] }
futures = "0.3.31"
Expand All @@ -19,6 +20,8 @@ opentelemetry-otlp = "0.27.0"
opentelemetry-semantic-conventions = "0.27.0"
opentelemetry-stdout = "0.27.0"
opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio"] }
reqwest = { version = "0.12.7", features = ["json", "rustls-tls-native-roots"], default-features = false }
serde = { version = "1.0.210", features = ["derive"] }
sqlx = { version = "0.8.2", features = ["runtime-tokio", "sqlite"] }
tokio = { version = "1.42.0", features = ["full"] }
tracing = "0.1.41"
Expand All @@ -27,6 +30,8 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
url = "2.5.4"

[dev-dependencies]
assert_matches = "1.5.0"
async-std = { version = "1.13.0", features = ["attributes"], default-features = false }
httpmock = { version = "0.7.0", default-features = false }
rstest = "0.23.0"
tempfile = "3.14.0"
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
FROM rust:1.82.0-slim AS build

RUN rustup target add x86_64-unknown-linux-musl && \
apt update && \
apt install -y musl-tools musl-dev && \
apt-get update && \
apt-get install -y musl-tools musl-dev && \
update-ca-certificates

COPY ./Cargo.toml ./Cargo.toml
Expand Down
89 changes: 89 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ pub struct ServeOptions {
/// The root directory for external number tracking
#[clap(long, env = "NUMTRACKER_ROOT_DIRECTORY")]
root_directory: Option<PathBuf>,
#[clap(flatten, next_help_heading = "Authorization")]
pub policy: Option<PolicyOptions>,
}

#[derive(Debug, Default, Parser)]
#[group(requires = "policy_host")]
pub struct PolicyOptions {
/// Beamline Policy Endpoint
///
/// eg, https://authz.diamond.ac.uk
#[clap(long = "policy", required = false)]
pub policy_host: String,
/// The Rego rule used to generate visit access data
///
/// eg. v1/data/diamond/policy/session/write_to_beamline_visit
#[clap(long, required = false)]
pub access_query: String,
/// The Rego rule used to generate admin access data
///
/// eg. v1/data/diamond/policy/admin/configure_beamline
#[clap(long, required = false)]
pub admin_query: String,
}

#[derive(Debug, Args)]
Expand Down Expand Up @@ -132,6 +154,7 @@ impl TracingOptions {
mod tests {
use std::path::PathBuf;

use assert_matches::assert_matches;
use clap::error::ErrorKind;
use clap::Parser;
use tracing::Level;
Expand All @@ -154,6 +177,8 @@ mod tests {
};
assert_eq!(cmd.addr(), ("0.0.0.0".parse().unwrap(), 8000));
assert_eq!(cmd.root_directory(), None);

assert_matches!(cmd.policy, None);
}

#[test]
Expand All @@ -174,6 +199,70 @@ mod tests {
};
assert_eq!(cmd.addr(), ("127.0.0.1".parse().unwrap(), 8765));
assert_eq!(cmd.root_directory, Some("/tmp/trackers".into()));
assert_matches!(cmd.policy, None);
}

#[test]
fn policy_arguments() {
let cli = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--admin-query",
"demo/admin_check",
"--access-query",
"demo/access_check",
])
.unwrap();
let cmd = assert_matches!(cli.command, Command::Serve(cmd) => cmd);
let policy = assert_matches!(cmd.policy, Some(plc) => plc);

assert_eq!(policy.policy_host, "opa.example.com");
assert_eq!(policy.admin_query, "demo/admin_check");
assert_eq!(policy.access_query, "demo/access_check");
}

#[test]
fn missing_admin_query() {
let err = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--access-query",
"demo/access-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
fn missing_access_query() {
let err = Cli::try_parse_from([
APP,
"serve",
"--policy",
"opa.example.com",
"--admin-query",
"demo/admin-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
fn policy_queries_without_host() {
let err = Cli::try_parse_from([
APP,
"serve",
"--access-query",
"demo/access-query",
"--admin-query",
"demo/admin-query",
])
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::MissingRequiredArgument);
}

#[test]
Expand Down
43 changes: 39 additions & 4 deletions src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::any;
use std::borrow::Cow;
use std::error::Error;
use std::fmt::Display;
use std::future::Future;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;

Expand All @@ -27,9 +28,13 @@ use async_graphql::{
Scalar, ScalarType, Schema, SimpleObject, Value,
};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use auth::{AuthError, PolicyCheck};
use axum::response::{Html, IntoResponse};
use axum::routing::{get, post};
use axum::{Extension, Router};
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use chrono::{Datelike, Local};
use tokio::net::TcpListener;
use tracing::{info, instrument, trace, warn};
Expand All @@ -45,26 +50,30 @@ use crate::paths::{
};
use crate::template::{FieldSource, PathTemplate};

mod auth;

pub async fn serve_graphql(db: &Path, opts: ServeOptions) {
let db = SqliteScanPathService::connect(db)
.await
.expect("Unable to open DB");
let directory_numtracker = NumTracker::for_root_directory(opts.root_directory())
.expect("Could not read external directories");
info!("Serving graphql endpoints on {:?}", opts.addr());
let addr = opts.addr();
let schema = Schema::build(Query, Mutation, EmptySubscription)
.extension(Tracing)
.limit_directives(32)
.data(db)
.data(directory_numtracker)
.data(opts.policy.map(PolicyCheck::new))
.finish();
let app = Router::new()
.route("/graphql", post(graphql_handler))
.route("/graphiql", get(graphiql))
.layer(Extension(schema));
let listener = TcpListener::bind(opts.addr())
let listener = TcpListener::bind(addr)
.await
.unwrap_or_else(|_| panic!("Port {:?} in use", opts.addr()));
.unwrap_or_else(|_| panic!("Port {:?} in use", addr));
axum::serve(listener, app)
.await
.expect("Can't serve graphql endpoint");
Expand All @@ -82,10 +91,13 @@ async fn graphiql() -> impl IntoResponse {
#[instrument(skip_all)]
async fn graphql_handler(
schema: Extension<Schema<Query, Mutation, EmptySubscription>>,
auth_token: Option<TypedHeader<Authorization<Bearer>>>,
req: GraphQLRequest,
) -> GraphQLResponse {
let inner = req.into_inner();
schema.execute(inner).await.into()
schema
.execute(req.into_inner().data(auth_token.map(|header| header.0)))
.await
.into()
}

/// Read-only API for GraphQL
Expand Down Expand Up @@ -263,6 +275,7 @@ impl Query {
ctx: &Context<'_>,
beamline: String,
) -> async_graphql::Result<BeamlineConfiguration> {
check_auth(ctx, |policy, token| policy.check_admin(token, &beamline)).await?;
let db = ctx.data::<SqliteScanPathService>()?;
trace!("Getting config for {beamline:?}");
Ok(db.current_configuration(&beamline).await?)
Expand All @@ -280,6 +293,10 @@ impl Mutation {
visit: String,
sub: Option<Subdirectory>,
) -> async_graphql::Result<ScanPaths> {
check_auth(ctx, |policy, token| {
policy.check_access(token, &beamline, &visit)
})
.await?;
let db = ctx.data::<SqliteScanPathService>()?;
let nt = ctx.data::<NumTracker>()?;
// There is a race condition here if a process increments the file
Expand Down Expand Up @@ -312,6 +329,7 @@ impl Mutation {
beamline: String,
config: ConfigurationUpdates,
) -> async_graphql::Result<BeamlineConfiguration> {
check_auth(ctx, |pc, token| pc.check_admin(token, &beamline)).await?;
let db = ctx.data::<SqliteScanPathService>()?;
trace!("Configuring: {beamline}: {config:?}");
let upd = config.into_update(beamline);
Expand All @@ -322,6 +340,23 @@ impl Mutation {
}
}

async fn check_auth<'ctx, Check, R>(ctx: &Context<'ctx>, check: Check) -> async_graphql::Result<()>
where
Check: Fn(&'ctx PolicyCheck, Option<&'ctx Authorization<Bearer>>) -> R,
R: Future<Output = Result<(), AuthError>>,
{
if let Some(policy) = ctx.data::<Option<PolicyCheck>>()? {
trace!("Auth enabled: checking token");
let token = ctx.data::<Authorization<Bearer>>().ok();
check(policy, token)
.await
.inspect_err(|e| info!("Authorization failed: {e:?}"))
.map_err(async_graphql::Error::from)
} else {
Ok(())
}
}

#[derive(Debug, InputObject)]
struct ConfigurationUpdates {
visit: Option<InputTemplate<VisitTemplate>>,
Expand Down
Loading

0 comments on commit cd9df8c

Please sign in to comment.