From 87c4e424bc3638ec1f3a3da409d1ba06a3be5db8 Mon Sep 17 00:00:00 2001 From: Johanna Dorothea Reichmann Date: Sun, 15 Sep 2024 13:14:25 +0200 Subject: [PATCH] feat: add AccessToken and ValidAccessToken axum extractors and use them in /api/tsigkey/ --- src/error/openid.rs | 38 ++++++++++++++++---- src/handlers/api.rs | 24 ++++++++++--- src/handlers/user.rs | 39 ++++++++++---------- src/util/openid.rs | 77 ++++++++++++++++++++++++++++++++++++++++ templates/user_home.html | 2 +- 5 files changed, 150 insertions(+), 30 deletions(-) diff --git a/src/error/openid.rs b/src/error/openid.rs index d6b63b5..e380015 100644 --- a/src/error/openid.rs +++ b/src/error/openid.rs @@ -1,17 +1,25 @@ use std::fmt::Display; -use axum::response::IntoResponse; +use openidconnect::AccessToken; +use axum::response::{ + IntoResponse, + Redirect, +}; use reqwest::StatusCode; #[derive(Debug)] pub enum AuthError { OpenIdConfig(openidconnect::ConfigurationError), + OpenIdRequest(String), + TokenNotActive(AccessToken), } impl Display for AuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { AuthError::OpenIdConfig(e) => f.write_fmt(format_args!("OpenID Connect configuration error: {:?}", e)), + AuthError::OpenIdRequest(s) => f.write_fmt(format_args!("{}", s)), + AuthError::TokenNotActive(t) => f.write_fmt(format_args!("Token {} is not active", t.secret())), } } } @@ -24,11 +32,27 @@ impl From for AuthError { } } -impl IntoResponse for AuthError { - fn into_response(self) -> axum::response::Response { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Authentication error: {}", self), - ).into_response() +impl From for AuthError { + fn from(s: String) -> Self { + Self::OpenIdRequest(s) + } +} + +impl From for AuthError { + fn from(t: AccessToken) -> Self { + Self::TokenNotActive(t) + } +} + + +impl IntoResponse for AuthError { + fn into_response(self) -> axum::response::Response { + match self { + AuthError::TokenNotActive(_) => Redirect::to("/openid/login").into_response(), + _ => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Authentication error: {}", self), + ).into_response(), + } } } diff --git a/src/handlers/api.rs b/src/handlers/api.rs index c1b9b75..2b33962 100644 --- a/src/handlers/api.rs +++ b/src/handlers/api.rs @@ -8,29 +8,45 @@ use axum::response::{ }; use crate::AppState; +use crate::util::openid::ValidAccessToken; use crate::model::tsigkey::{ TsigKey, }; -pub async fn list(State(app_state): State) -> impl IntoResponse { +pub async fn list( + State(app_state): State, + ValidAccessToken(token): ValidAccessToken, +) -> impl IntoResponse { app_state.pdns_client.list_tsig_keys() .await .and_then(|keys| Ok(Json(keys))) } -pub async fn get(State(app_state): State, Path(tsig_key_name): Path) -> impl IntoResponse { +pub async fn get( + State(app_state): State, + ValidAccessToken(token): ValidAccessToken, + Path(tsig_key_name): Path +) -> impl IntoResponse { app_state.pdns_client.get_tsig_key(tsig_key_name) .await .and_then(|key| Ok(Json(key))) } -pub async fn add(State(app_state): State, Json(tsig_key): Json) -> impl IntoResponse { +pub async fn add( + State(app_state): State, + ValidAccessToken(token): ValidAccessToken, + Json(tsig_key): Json, +) -> impl IntoResponse { app_state.pdns_client.add_tsig_key(tsig_key) .await .and_then(|key| Ok(Json(key))) } -pub async fn delete(State(app_state): State, Path(tsig_key_name): Path) -> impl IntoResponse { +pub async fn delete( + State(app_state): State, + ValidAccessToken(token): ValidAccessToken, + Path(tsig_key_name): Path, +) -> impl IntoResponse { app_state.pdns_client.delete_tsig_key(tsig_key_name) .await .and_then(|_| Ok(Json(()))) diff --git a/src/handlers/user.rs b/src/handlers/user.rs index 14afb29..5d03b3a 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -4,12 +4,12 @@ use std::time::{ }; use askama::Template; -use axum::response::IntoResponse; +use axum::response::{ + IntoResponse, +}; use axum::extract::State; -use axum_extra::extract::cookie::CookieJar; use openidconnect::{ - AccessToken, reqwest::async_http_client, TokenIntrospectionResponse, }; @@ -17,32 +17,35 @@ use openidconnect::{ use crate::util::askama::HtmlTemplate; use crate::AppState; use crate::error::openid::AuthError; +use crate::util::openid::AccessToken; #[derive(Template)] #[template(path = "user_home.html")] struct UserHomeTemplate { - is_active: bool, + active: bool, username: String, duration: Duration, } pub async fn home( State(app_state): State, - cookies: CookieJar + AccessToken(token): AccessToken, ) -> Result { let now = Instant::now(); - let token_serialized: Option = cookies.get("access_token") - .map(|cookie| cookie.value().to_owned()); - let (is_active, username); - (is_active, username) = match token_serialized { - Some(token) => { - let introspection_response = app_state.oidc_client - .introspect(&AccessToken::new(token))? - .request_async(async_http_client).await.unwrap(); - println!("Token introspected, answer is {:?}", introspection_response); - (introspection_response.active(), introspection_response.username().unwrap_or("").to_string()) + let response = app_state.oidc_client + .introspect(&token)? + .request_async(async_http_client) + .await + .map_err(|e| e.to_string())?; + match response.active() { + true => { + let username = response.username().unwrap().to_string(); + Ok(HtmlTemplate(UserHomeTemplate { + active: true, + username, + duration: now.elapsed(), + })) }, - None => (false, "".to_string()) - }; - Ok(HtmlTemplate(UserHomeTemplate { is_active, username, duration: now.elapsed() })) + false => Err(AuthError::TokenNotActive(token)), + } } diff --git a/src/util/openid.rs b/src/util/openid.rs index b143bb9..65f9d8d 100644 --- a/src/util/openid.rs +++ b/src/util/openid.rs @@ -9,7 +9,84 @@ use openidconnect::{ IssuerUrl, RedirectUrl, }; +use openidconnect::TokenIntrospectionResponse; use url::Url; +use axum::{ + async_trait, + extract::{ + FromRef, + FromRequestParts, + State, + }, + http::{ + StatusCode, + request::Parts, + }, + response::Redirect, +}; +use axum_extra::extract::cookie::CookieJar; + +use crate::AppState; + +pub(crate) struct AccessToken(pub(crate) openidconnect::AccessToken); +pub(crate) struct ValidAccessToken(pub(crate) openidconnect::AccessToken); +pub(crate) type AccessTokenRejection = (StatusCode, String); + +#[async_trait] +impl

FromRequestParts

for AccessToken +where + P: Send + Sync +{ + type Rejection = AccessTokenRejection; + + async fn from_request_parts(parts: &mut Parts, state: &P) -> Result { + let cookie_name = "access_token"; + let cookies = CookieJar::from_request_parts(parts, state).await.unwrap(); + let token: Option = cookies + .get(cookie_name) + .map(|cookie| cookie.value().to_owned()); + match token { + Some(token) => Ok(Self(openidconnect::AccessToken::new(token))), + None => Err(( + StatusCode::BAD_REQUEST, + format!( + "Request is missing the '{}' cookie", + cookie_name.to_string(), + )), + ) + } + } +} + +#[async_trait] +impl

FromRequestParts

for ValidAccessToken +where + AppState: FromRef

, + P: Send + Sync, +{ + type Rejection = Redirect; + + async fn from_request_parts(parts: &mut Parts, state: &P) -> Result { + match AccessToken::from_request_parts(parts, state).await { + Err(_) => Err(Redirect::to("/openid/login")), + Ok(AccessToken(token)) => { + match AppState::from_ref(state) + .oidc_client + .introspect(&token) + .map_err(|_| Redirect::to("/openid/login"))? + .request_async(async_http_client) + .await + .map_err(|_| Redirect::to("/openid/login")) { + Ok(t) => match t.active() { + true => Ok(ValidAccessToken(token)), + false => Err(Redirect::to("/openid/login")), + }, + Err(e) => Err(e), + } + } + } + } +} pub async fn create_client(issuer: Url, id: String, secret: String, redirect_url: Url) -> CoreClient { let issuer_url = IssuerUrl::from_url(issuer); diff --git a/templates/user_home.html b/templates/user_home.html index c3cfdcd..bebff62 100644 --- a/templates/user_home.html +++ b/templates/user_home.html @@ -7,7 +7,7 @@

User Home

{% endif %} -

Your user session is {% if is_active %}active{% else %}inactive{% endif %}

+

Your user session is {% if active %}active{% else %}inactive{% endif %}

Request took {{ duration.as_millis() }}ms

{% endblock content %}