mas_handlers/oauth2/authorization/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{
8    extract::{Form, State},
9    response::{IntoResponse, Response},
10};
11use hyper::StatusCode;
12use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, sentry::SentryEventID};
13use mas_data_model::{AuthorizationCode, Pkce};
14use mas_router::{PostAuthAction, UrlBuilder};
15use mas_storage::{
16    BoxClock, BoxRepository, BoxRng,
17    oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
18};
19use mas_templates::Templates;
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    pkce,
23    requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
24    response_type::ResponseType,
25};
26use rand::{Rng, distributions::Alphanumeric};
27use serde::Deserialize;
28use thiserror::Error;
29
30use self::callback::CallbackDestination;
31use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
32
33mod callback;
34pub(crate) mod consent;
35
36#[derive(Debug, Error)]
37pub enum RouteError {
38    #[error(transparent)]
39    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
40
41    #[error("could not find client")]
42    ClientNotFound,
43
44    #[error("invalid response mode")]
45    InvalidResponseMode,
46
47    #[error("invalid parameters")]
48    IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
49
50    #[error("invalid redirect uri")]
51    UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
52}
53
54impl IntoResponse for RouteError {
55    fn into_response(self) -> axum::response::Response {
56        let event_id = sentry::capture_error(&self);
57        // TODO: better error pages
58        let response = match self {
59            RouteError::Internal(e) => {
60                (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
61            }
62            RouteError::ClientNotFound => {
63                (StatusCode::BAD_REQUEST, "could not find client").into_response()
64            }
65            RouteError::InvalidResponseMode => {
66                (StatusCode::BAD_REQUEST, "invalid response mode").into_response()
67            }
68            RouteError::IntoCallbackDestination(e) => {
69                (StatusCode::BAD_REQUEST, e.to_string()).into_response()
70            }
71            RouteError::UnknownRedirectUri(e) => (
72                StatusCode::BAD_REQUEST,
73                format!("Invalid redirect URI ({e})"),
74            )
75                .into_response(),
76        };
77
78        (SentryEventID::from(event_id), response).into_response()
79    }
80}
81
82impl_from_error_for_route!(mas_storage::RepositoryError);
83impl_from_error_for_route!(mas_templates::TemplateError);
84impl_from_error_for_route!(self::callback::CallbackDestinationError);
85impl_from_error_for_route!(mas_policy::LoadError);
86impl_from_error_for_route!(mas_policy::EvaluationError);
87
/// Parameters of an authorization request, as extracted by the `get` handler.
///
/// Both fields are flattened. The core authorization request and the optional
/// PKCE parameters are therefore read from the same flat set of parameters.
// NOTE(review): keep the field order; with `#[serde(flatten)]` the order in
// which fields consume entries may matter.
#[derive(Deserialize)]
pub(crate) struct Params {
    /// Core authorization request (`client_id`, `redirect_uri`,
    /// `response_type`, `response_mode`, `scope`, `state`, `prompt`, …).
    #[serde(flatten)]
    auth: AuthorizationRequest,

    /// PKCE parameters (`code_challenge`, `code_challenge_method`), if the
    /// client sent them.
    #[serde(flatten)]
    pkce: Option<pkce::AuthorizationRequest>,
}
96
97/// Given a list of response types and an optional user-defined response mode,
98/// figure out what response mode must be used, and emit an error if the
99/// suggested response mode isn't allowed for the given response types.
100fn resolve_response_mode(
101    response_type: &ResponseType,
102    suggested_response_mode: Option<ResponseMode>,
103) -> Result<ResponseMode, RouteError> {
104    use ResponseMode as M;
105
106    // If the response type includes either "token" or "id_token", the default
107    // response mode is "fragment" and the response mode "query" must not be
108    // used
109    if response_type.has_token() || response_type.has_id_token() {
110        match suggested_response_mode {
111            None => Ok(M::Fragment),
112            Some(M::Query) => Err(RouteError::InvalidResponseMode),
113            Some(mode) => Ok(mode),
114        }
115    } else {
116        // In other cases, all response modes are allowed, defaulting to "query"
117        Ok(suggested_response_mode.unwrap_or(M::Query))
118    }
119}
120
121#[tracing::instrument(
122    name = "handlers.oauth2.authorization.get",
123    fields(client.id = %params.auth.client_id),
124    skip_all,
125    err,
126)]
127#[allow(clippy::too_many_lines)]
128pub(crate) async fn get(
129    mut rng: BoxRng,
130    clock: BoxClock,
131    PreferredLanguage(locale): PreferredLanguage,
132    State(templates): State<Templates>,
133    State(url_builder): State<UrlBuilder>,
134    activity_tracker: BoundActivityTracker,
135    mut repo: BoxRepository,
136    cookie_jar: CookieJar,
137    Form(params): Form<Params>,
138) -> Result<Response, RouteError> {
139    // First, figure out what client it is
140    let client = repo
141        .oauth2_client()
142        .find_by_client_id(&params.auth.client_id)
143        .await?
144        .ok_or(RouteError::ClientNotFound)?;
145
146    // And resolve the redirect_uri and response_mode
147    let redirect_uri = client
148        .resolve_redirect_uri(&params.auth.redirect_uri)?
149        .clone();
150    let response_type = params.auth.response_type;
151    let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?;
152
153    // Now we have a proper callback destination to go to on error
154    let callback_destination = CallbackDestination::try_new(
155        &response_mode,
156        redirect_uri.clone(),
157        params.auth.state.clone(),
158    )?;
159
160    // Get the session info from the cookie
161    let (session_info, cookie_jar) = cookie_jar.session_info();
162
163    // One day, we will have try blocks
164    let res: Result<Response, RouteError> = ({
165        let templates = templates.clone();
166        let callback_destination = callback_destination.clone();
167        let locale = locale.clone();
168        async move {
169            let maybe_session = session_info.load_active_session(&mut repo).await?;
170            let prompt = params.auth.prompt.as_deref().unwrap_or_default();
171
172            // Check if the request/request_uri/registration params are used. If so, reply
173            // with the right error since we don't support them.
174            if params.auth.request.is_some() {
175                return Ok(callback_destination.go(
176                    &templates,
177                    &locale,
178                    ClientError::from(ClientErrorCode::RequestNotSupported),
179                )?);
180            }
181
182            if params.auth.request_uri.is_some() {
183                return Ok(callback_destination.go(
184                    &templates,
185                    &locale,
186                    ClientError::from(ClientErrorCode::RequestUriNotSupported),
187                )?);
188            }
189
190            // Check if the client asked for a `token` response type, and bail out if it's
191            // the case, since we don't support them
192            if response_type.has_token() {
193                return Ok(callback_destination.go(
194                    &templates,
195                    &locale,
196                    ClientError::from(ClientErrorCode::UnsupportedResponseType),
197                )?);
198            }
199
200            // If the client asked for a `id_token` response type, we must check if it can
201            // use the `implicit` grant type
202            if response_type.has_id_token() && !client.grant_types.contains(&GrantType::Implicit) {
203                return Ok(callback_destination.go(
204                    &templates,
205                    &locale,
206                    ClientError::from(ClientErrorCode::UnauthorizedClient),
207                )?);
208            }
209
210            if params.auth.registration.is_some() {
211                return Ok(callback_destination.go(
212                    &templates,
213                    &locale,
214                    ClientError::from(ClientErrorCode::RegistrationNotSupported),
215                )?);
216            }
217
218            // Fail early if prompt=none; we never let it go through
219            if prompt.contains(&Prompt::None) {
220                return Ok(callback_destination.go(
221                    &templates,
222                    &locale,
223                    ClientError::from(ClientErrorCode::LoginRequired),
224                )?);
225            }
226
227            let code: Option<AuthorizationCode> = if response_type.has_code() {
228                // Check if it is allowed to use this grant type
229                if !client.grant_types.contains(&GrantType::AuthorizationCode) {
230                    return Ok(callback_destination.go(
231                        &templates,
232                        &locale,
233                        ClientError::from(ClientErrorCode::UnauthorizedClient),
234                    )?);
235                }
236
237                // 32 random alphanumeric characters, about 190bit of entropy
238                let code: String = (&mut rng)
239                    .sample_iter(&Alphanumeric)
240                    .take(32)
241                    .map(char::from)
242                    .collect();
243
244                let pkce = params.pkce.map(|p| Pkce {
245                    challenge: p.code_challenge,
246                    challenge_method: p.code_challenge_method,
247                });
248
249                Some(AuthorizationCode { code, pkce })
250            } else {
251                // If the request had PKCE params but no code asked, it should get back with an
252                // error
253                if params.pkce.is_some() {
254                    return Ok(callback_destination.go(
255                        &templates,
256                        &locale,
257                        ClientError::from(ClientErrorCode::InvalidRequest),
258                    )?);
259                }
260
261                None
262            };
263
264            let grant = repo
265                .oauth2_authorization_grant()
266                .add(
267                    &mut rng,
268                    &clock,
269                    &client,
270                    redirect_uri.clone(),
271                    params.auth.scope,
272                    code,
273                    params.auth.state.clone(),
274                    params.auth.nonce,
275                    response_mode,
276                    response_type.has_id_token(),
277                    params.auth.login_hint,
278                )
279                .await?;
280            let continue_grant = PostAuthAction::continue_grant(grant.id);
281
282            let res = match maybe_session {
283                None if prompt.contains(&Prompt::Create) => {
284                    // Client asked for a registration, show the registration prompt
285                    repo.save().await?;
286
287                    url_builder
288                        .redirect(&mas_router::Register::and_then(continue_grant))
289                        .into_response()
290                }
291
292                None => {
293                    // Other cases where we don't have a session, ask for a login
294                    repo.save().await?;
295
296                    url_builder
297                        .redirect(&mas_router::Login::and_then(continue_grant))
298                        .into_response()
299                }
300
301                Some(user_session) => {
302                    // TODO: better support for prompt=create when we have a session
303                    repo.save().await?;
304
305                    activity_tracker
306                        .record_browser_session(&clock, &user_session)
307                        .await;
308                    url_builder
309                        .redirect(&mas_router::Consent(grant.id))
310                        .into_response()
311                }
312            };
313
314            Ok(res)
315        }
316    })
317    .await;
318
319    let response = match res {
320        Ok(r) => r,
321        Err(err) => {
322            tracing::error!(%err);
323            callback_destination.go(
324                &templates,
325                &locale,
326                ClientError::from(ClientErrorCode::ServerError),
327            )?
328        }
329    };
330
331    Ok((cookie_jar, response).into_response())
332}