// mas_handlers/oauth2/token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::typed_header::TypedHeader;
11use chrono::Duration;
12use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
13use hyper::StatusCode;
14use mas_axum_utils::{
15    client_authorization::{ClientAuthorization, CredentialsVerificationError},
16    sentry::SentryEventID,
17};
18use mas_data_model::{
19    AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent,
20};
21use mas_keystore::{Encrypter, Keystore};
22use mas_matrix::HomeserverConnection;
23use mas_oidc_client::types::scope::ScopeToken;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27    BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
28    oauth2::{
29        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
30        OAuth2RefreshTokenRepository, OAuth2SessionRepository,
31    },
32    user::BrowserSessionRepository,
33};
34use oauth2_types::{
35    errors::{ClientError, ClientErrorCode},
36    pkce::CodeChallengeError,
37    requests::{
38        AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
39        DeviceCodeGrant, GrantType, RefreshTokenGrant,
40    },
41    scope,
42};
43use opentelemetry::{Key, KeyValue, metrics::Counter};
44use thiserror::Error;
45use tracing::{debug, info};
46use ulid::Ulid;
47
48use super::{generate_id_token, generate_token_pair};
49use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
50
51static TOKEN_REQUEST_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
52    METER
53        .u64_counter("mas.oauth2.token_request")
54        .with_description("How many OAuth 2.0 token requests have gone through")
55        .with_unit("{request}")
56        .build()
57});
58const GRANT_TYPE: Key = Key::from_static_str("grant_type");
59const RESULT: Key = Key::from_static_str("successful");
60
61#[derive(Debug, Error)]
62pub(crate) enum RouteError {
63    #[error(transparent)]
64    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
65
66    #[error("bad request")]
67    BadRequest,
68
69    #[error("pkce verification failed")]
70    PkceVerification(#[from] CodeChallengeError),
71
72    #[error("client not found")]
73    ClientNotFound,
74
75    #[error("client not allowed")]
76    ClientNotAllowed,
77
78    #[error("could not verify client credentials")]
79    ClientCredentialsVerification(#[from] CredentialsVerificationError),
80
81    #[error("grant not found")]
82    GrantNotFound,
83
84    #[error("invalid grant")]
85    InvalidGrant,
86
87    #[error("refresh token not found")]
88    RefreshTokenNotFound,
89
90    #[error("refresh token {0} is invalid")]
91    RefreshTokenInvalid(Ulid),
92
93    #[error("session {0} is invalid")]
94    SessionInvalid(Ulid),
95
96    #[error("client id mismatch: expected {expected}, got {actual}")]
97    ClientIDMismatch { expected: Ulid, actual: Ulid },
98
99    #[error("policy denied the request")]
100    DeniedByPolicy(Vec<mas_policy::Violation>),
101
102    #[error("unsupported grant type")]
103    UnsupportedGrantType,
104
105    #[error("unauthorized client")]
106    UnauthorizedClient,
107
108    #[error("failed to load browser session")]
109    NoSuchBrowserSession,
110
111    #[error("failed to load oauth session")]
112    NoSuchOAuthSession,
113
114    #[error(
115        "failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
116    )]
117    NoSuchNextRefreshToken { next: Ulid, previous: Ulid },
118
119    #[error(
120        "failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})"
121    )]
122    NoSuchNextAccessToken {
123        access_token: Ulid,
124        refresh_token: Ulid,
125    },
126
127    #[error("no access token associated with the refresh token {refresh_token:?}")]
128    NoAccessTokenOnRefreshToken { refresh_token: Ulid },
129
130    #[error("device code grant expired")]
131    DeviceCodeExpired,
132
133    #[error("device code grant is still pending")]
134    DeviceCodePending,
135
136    #[error("device code grant was rejected")]
137    DeviceCodeRejected,
138
139    #[error("device code grant was already exchanged")]
140    DeviceCodeExchanged,
141
142    #[error("failed to provision device")]
143    ProvisionDeviceFailed(#[source] anyhow::Error),
144}
145
146impl IntoResponse for RouteError {
147    fn into_response(self) -> axum::response::Response {
148        let event_id = sentry::capture_error(&self);
149
150        TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
151
152        let response = match self {
153            Self::Internal(_)
154            | Self::NoSuchBrowserSession
155            | Self::NoSuchOAuthSession
156            | Self::ProvisionDeviceFailed(_)
157            | Self::NoSuchNextRefreshToken { .. }
158            | Self::NoSuchNextAccessToken { .. }
159            | Self::NoAccessTokenOnRefreshToken { .. } => (
160                StatusCode::INTERNAL_SERVER_ERROR,
161                Json(ClientError::from(ClientErrorCode::ServerError)),
162            ),
163            Self::BadRequest => (
164                StatusCode::BAD_REQUEST,
165                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
166            ),
167            Self::PkceVerification(err) => (
168                StatusCode::BAD_REQUEST,
169                Json(
170                    ClientError::from(ClientErrorCode::InvalidGrant)
171                        .with_description(format!("PKCE verification failed: {err}")),
172                ),
173            ),
174            Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
175                StatusCode::UNAUTHORIZED,
176                Json(ClientError::from(ClientErrorCode::InvalidClient)),
177            ),
178            Self::ClientNotAllowed | Self::UnauthorizedClient => (
179                StatusCode::UNAUTHORIZED,
180                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
181            ),
182            Self::DeniedByPolicy(violations) => (
183                StatusCode::FORBIDDEN,
184                Json(
185                    ClientError::from(ClientErrorCode::InvalidScope).with_description(
186                        violations
187                            .into_iter()
188                            .map(|violation| violation.msg)
189                            .collect::<Vec<_>>()
190                            .join(", "),
191                    ),
192                ),
193            ),
194            Self::DeviceCodeRejected => (
195                StatusCode::FORBIDDEN,
196                Json(ClientError::from(ClientErrorCode::AccessDenied)),
197            ),
198            Self::DeviceCodeExpired => (
199                StatusCode::FORBIDDEN,
200                Json(ClientError::from(ClientErrorCode::ExpiredToken)),
201            ),
202            Self::DeviceCodePending => (
203                StatusCode::FORBIDDEN,
204                Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
205            ),
206            Self::InvalidGrant
207            | Self::DeviceCodeExchanged
208            | Self::RefreshTokenNotFound
209            | Self::RefreshTokenInvalid(_)
210            | Self::SessionInvalid(_)
211            | Self::ClientIDMismatch { .. }
212            | Self::GrantNotFound => (
213                StatusCode::BAD_REQUEST,
214                Json(ClientError::from(ClientErrorCode::InvalidGrant)),
215            ),
216            Self::UnsupportedGrantType => (
217                StatusCode::BAD_REQUEST,
218                Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
219            ),
220        };
221
222        (SentryEventID::from(event_id), response).into_response()
223    }
224}
225
226impl_from_error_for_route!(mas_storage::RepositoryError);
227impl_from_error_for_route!(mas_policy::EvaluationError);
228impl_from_error_for_route!(super::IdTokenSignatureError);
229
230#[tracing::instrument(
231    name = "handlers.oauth2.token.post",
232    fields(client.id = client_authorization.client_id()),
233    skip_all,
234    err,
235)]
236pub(crate) async fn post(
237    mut rng: BoxRng,
238    clock: BoxClock,
239    State(http_client): State<reqwest::Client>,
240    State(key_store): State<Keystore>,
241    State(url_builder): State<UrlBuilder>,
242    activity_tracker: BoundActivityTracker,
243    mut repo: BoxRepository,
244    State(homeserver): State<Arc<dyn HomeserverConnection>>,
245    State(site_config): State<SiteConfig>,
246    State(encrypter): State<Encrypter>,
247    policy: Policy,
248    user_agent: Option<TypedHeader<headers::UserAgent>>,
249    client_authorization: ClientAuthorization<AccessTokenRequest>,
250) -> Result<impl IntoResponse, RouteError> {
251    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
252    let client = client_authorization
253        .credentials
254        .fetch(&mut repo)
255        .await?
256        .ok_or(RouteError::ClientNotFound)?;
257
258    let method = client
259        .token_endpoint_auth_method
260        .as_ref()
261        .ok_or(RouteError::ClientNotAllowed)?;
262
263    client_authorization
264        .credentials
265        .verify(&http_client, &encrypter, method, &client)
266        .await?;
267
268    let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
269
270    let grant_type = form.grant_type();
271
272    let (reply, repo) = match form {
273        AccessTokenRequest::AuthorizationCode(grant) => {
274            authorization_code_grant(
275                &mut rng,
276                &clock,
277                &activity_tracker,
278                &grant,
279                &client,
280                &key_store,
281                &url_builder,
282                &site_config,
283                repo,
284                &homeserver,
285                user_agent,
286            )
287            .await?
288        }
289        AccessTokenRequest::RefreshToken(grant) => {
290            refresh_token_grant(
291                &mut rng,
292                &clock,
293                &activity_tracker,
294                &grant,
295                &client,
296                &site_config,
297                repo,
298                user_agent,
299            )
300            .await?
301        }
302        AccessTokenRequest::ClientCredentials(grant) => {
303            client_credentials_grant(
304                &mut rng,
305                &clock,
306                &activity_tracker,
307                &grant,
308                &client,
309                &site_config,
310                repo,
311                policy,
312                user_agent,
313            )
314            .await?
315        }
316        AccessTokenRequest::DeviceCode(grant) => {
317            device_code_grant(
318                &mut rng,
319                &clock,
320                &activity_tracker,
321                &grant,
322                &client,
323                &key_store,
324                &url_builder,
325                &site_config,
326                repo,
327                &homeserver,
328                user_agent,
329            )
330            .await?
331        }
332        _ => {
333            return Err(RouteError::UnsupportedGrantType);
334        }
335    };
336
337    repo.save().await?;
338
339    TOKEN_REQUEST_COUNTER.add(
340        1,
341        &[
342            KeyValue::new(GRANT_TYPE, grant_type),
343            KeyValue::new(RESULT, "success"),
344        ],
345    );
346
347    let mut headers = HeaderMap::new();
348    headers.typed_insert(CacheControl::new().with_no_store());
349    headers.typed_insert(Pragma::no_cache());
350
351    Ok((headers, Json(reply)))
352}
353
354#[allow(clippy::too_many_lines)] // TODO: refactor some parts out
355async fn authorization_code_grant(
356    mut rng: &mut BoxRng,
357    clock: &impl Clock,
358    activity_tracker: &BoundActivityTracker,
359    grant: &AuthorizationCodeGrant,
360    client: &Client,
361    key_store: &Keystore,
362    url_builder: &UrlBuilder,
363    site_config: &SiteConfig,
364    mut repo: BoxRepository,
365    homeserver: &Arc<dyn HomeserverConnection>,
366    user_agent: Option<UserAgent>,
367) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
368    // Check that the client is allowed to use this grant type
369    if !client.grant_types.contains(&GrantType::AuthorizationCode) {
370        return Err(RouteError::UnauthorizedClient);
371    }
372
373    let authz_grant = repo
374        .oauth2_authorization_grant()
375        .find_by_code(&grant.code)
376        .await?
377        .ok_or(RouteError::GrantNotFound)?;
378
379    let now = clock.now();
380
381    let session_id = match authz_grant.stage {
382        AuthorizationGrantStage::Cancelled { cancelled_at } => {
383            debug!(%cancelled_at, "Authorization grant was cancelled");
384            return Err(RouteError::InvalidGrant);
385        }
386        AuthorizationGrantStage::Exchanged {
387            exchanged_at,
388            fulfilled_at,
389            session_id,
390        } => {
391            debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
392
393            // Ending the session if the token was already exchanged more than 20s ago
394            if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
395                debug!("Ending potentially compromised session");
396                let session = repo
397                    .oauth2_session()
398                    .lookup(session_id)
399                    .await?
400                    .ok_or(RouteError::NoSuchOAuthSession)?;
401                repo.oauth2_session().finish(clock, session).await?;
402                repo.save().await?;
403            }
404
405            return Err(RouteError::InvalidGrant);
406        }
407        AuthorizationGrantStage::Pending => {
408            debug!("Authorization grant has not been fulfilled yet");
409            return Err(RouteError::InvalidGrant);
410        }
411        AuthorizationGrantStage::Fulfilled {
412            session_id,
413            fulfilled_at,
414        } => {
415            if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
416                debug!("Code exchange took more than 10 minutes");
417                return Err(RouteError::InvalidGrant);
418            }
419
420            session_id
421        }
422    };
423
424    let mut session = repo
425        .oauth2_session()
426        .lookup(session_id)
427        .await?
428        .ok_or(RouteError::NoSuchOAuthSession)?;
429
430    if let Some(user_agent) = user_agent {
431        session = repo
432            .oauth2_session()
433            .record_user_agent(session, user_agent)
434            .await?;
435    }
436
437    // This should never happen, since we looked up in the database using the code
438    let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
439
440    if client.id != session.client_id {
441        return Err(RouteError::UnauthorizedClient);
442    }
443
444    match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
445        (None, None) => {}
446        // We have a challenge but no verifier (or vice-versa)? Bad request.
447        (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest),
448        // If we have both, we need to check the code validity
449        (Some(pkce), Some(verifier)) => {
450            pkce.verify(verifier)?;
451        }
452    }
453
454    let Some(user_session_id) = session.user_session_id else {
455        tracing::warn!("No user session associated with this OAuth2 session");
456        return Err(RouteError::InvalidGrant);
457    };
458
459    let browser_session = repo
460        .browser_session()
461        .lookup(user_session_id)
462        .await?
463        .ok_or(RouteError::NoSuchBrowserSession)?;
464
465    let last_authentication = repo
466        .browser_session()
467        .get_last_authentication(&browser_session)
468        .await?;
469
470    let ttl = site_config.access_token_ttl;
471    let (access_token, refresh_token) =
472        generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;
473
474    let id_token = if session.scope.contains(&scope::OPENID) {
475        Some(generate_id_token(
476            &mut rng,
477            clock,
478            url_builder,
479            key_store,
480            client,
481            Some(&authz_grant),
482            &browser_session,
483            Some(&access_token),
484            last_authentication.as_ref(),
485        )?)
486    } else {
487        None
488    };
489
490    let mut params = AccessTokenResponse::new(access_token.access_token)
491        .with_expires_in(ttl)
492        .with_refresh_token(refresh_token.refresh_token)
493        .with_scope(session.scope.clone());
494
495    if let Some(id_token) = id_token {
496        params = params.with_id_token(id_token);
497    }
498
499    // Lock the user sync to make sure we don't get into a race condition
500    repo.user()
501        .acquire_lock_for_sync(&browser_session.user)
502        .await?;
503
504    // Look for device to provision
505    let mxid = homeserver.mxid(&browser_session.user.username);
506    for scope in &*session.scope {
507        if let Some(device) = Device::from_scope_token(scope) {
508            homeserver
509                .create_device(&mxid, device.as_str())
510                .await
511                .map_err(RouteError::ProvisionDeviceFailed)?;
512        }
513    }
514
515    repo.oauth2_authorization_grant()
516        .exchange(clock, authz_grant)
517        .await?;
518
519    // XXX: there is a potential (but unlikely) race here, where the activity for
520    // the session is recorded before the transaction is committed. We would have to
521    // save the repository here to fix that.
522    activity_tracker
523        .record_oauth2_session(clock, &session)
524        .await;
525
526    Ok((params, repo))
527}
528
529#[allow(clippy::too_many_lines)]
530async fn refresh_token_grant(
531    rng: &mut BoxRng,
532    clock: &impl Clock,
533    activity_tracker: &BoundActivityTracker,
534    grant: &RefreshTokenGrant,
535    client: &Client,
536    site_config: &SiteConfig,
537    mut repo: BoxRepository,
538    user_agent: Option<UserAgent>,
539) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
540    // Check that the client is allowed to use this grant type
541    if !client.grant_types.contains(&GrantType::RefreshToken) {
542        return Err(RouteError::UnauthorizedClient);
543    }
544
545    let refresh_token = repo
546        .oauth2_refresh_token()
547        .find_by_token(&grant.refresh_token)
548        .await?
549        .ok_or(RouteError::RefreshTokenNotFound)?;
550
551    let mut session = repo
552        .oauth2_session()
553        .lookup(refresh_token.session_id)
554        .await?
555        .ok_or(RouteError::NoSuchOAuthSession)?;
556
557    // Let's for now record the user agent on each refresh, that should be
558    // responsive enough and not too much of a burden on the database.
559    if let Some(user_agent) = user_agent {
560        session = repo
561            .oauth2_session()
562            .record_user_agent(session, user_agent)
563            .await?;
564    }
565
566    if !session.is_valid() {
567        return Err(RouteError::SessionInvalid(session.id));
568    }
569
570    if client.id != session.client_id {
571        // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
572        return Err(RouteError::ClientIDMismatch {
573            expected: session.client_id,
574            actual: client.id,
575        });
576    }
577
578    if !refresh_token.is_valid() {
579        // We're seing a refresh token that already has been consumed, this might be a
580        // double-refresh or a replay attack
581
582        // First, get the next refresh token
583        let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
584            // If we don't have a 'next' refresh token, it may just be because this was
585            // before we were recording those. Let's just treat it as a replay.
586            return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
587        };
588
589        let Some(next_refresh_token) = repo
590            .oauth2_refresh_token()
591            .lookup(next_refresh_token_id)
592            .await?
593        else {
594            return Err(RouteError::NoSuchNextRefreshToken {
595                next: next_refresh_token_id,
596                previous: refresh_token.id,
597            });
598        };
599
600        // Check if the next refresh token was already consumed or not
601        if !next_refresh_token.is_valid() {
602            // XXX: This is a replay, we *may* want to invalidate the session
603            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
604        }
605
606        // Check if the associated access token was already used
607        let Some(access_token_id) = next_refresh_token.access_token_id else {
608            // This should in theory not happen: this means an access token got cleaned up,
609            // but the refresh token was still valid.
610            return Err(RouteError::NoAccessTokenOnRefreshToken {
611                refresh_token: next_refresh_token.id,
612            });
613        };
614
615        // Load it
616        let next_access_token = repo
617            .oauth2_access_token()
618            .lookup(access_token_id)
619            .await?
620            .ok_or(RouteError::NoSuchNextAccessToken {
621                access_token: access_token_id,
622                refresh_token: next_refresh_token_id,
623            })?;
624
625        if next_access_token.is_used() {
626            // XXX: This is a replay, we *may* want to invalidate the session
627            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
628        }
629
630        // Looks like it's a double-refresh, client lost their refresh token on
631        // the way back. Let's revoke the unused access and refresh tokens, and
632        // issue new ones
633        info!(
634            oauth2_session.id = %session.id,
635            oauth2_client.id = %client.id,
636            %refresh_token.id,
637            "Refresh token already used, but issued refresh and access tokens are unused. Assuming those were lost; revoking those and reissuing new ones."
638        );
639
640        repo.oauth2_access_token()
641            .revoke(clock, next_access_token)
642            .await?;
643
644        repo.oauth2_refresh_token()
645            .revoke(clock, next_refresh_token)
646            .await?;
647    }
648
649    activity_tracker
650        .record_oauth2_session(clock, &session)
651        .await;
652
653    let ttl = site_config.access_token_ttl;
654    let (new_access_token, new_refresh_token) =
655        generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;
656
657    let refresh_token = repo
658        .oauth2_refresh_token()
659        .consume(clock, refresh_token, &new_refresh_token)
660        .await?;
661
662    if let Some(access_token_id) = refresh_token.access_token_id {
663        let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
664        if let Some(access_token) = access_token {
665            // If it is a double-refresh, it might already be revoked
666            if !access_token.state.is_revoked() {
667                repo.oauth2_access_token()
668                    .revoke(clock, access_token)
669                    .await?;
670            }
671        }
672    }
673
674    let params = AccessTokenResponse::new(new_access_token.access_token)
675        .with_expires_in(ttl)
676        .with_refresh_token(new_refresh_token.refresh_token)
677        .with_scope(session.scope);
678
679    Ok((params, repo))
680}
681
682async fn client_credentials_grant(
683    rng: &mut BoxRng,
684    clock: &impl Clock,
685    activity_tracker: &BoundActivityTracker,
686    grant: &ClientCredentialsGrant,
687    client: &Client,
688    site_config: &SiteConfig,
689    mut repo: BoxRepository,
690    mut policy: Policy,
691    user_agent: Option<UserAgent>,
692) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
693    // Check that the client is allowed to use this grant type
694    if !client.grant_types.contains(&GrantType::ClientCredentials) {
695        return Err(RouteError::UnauthorizedClient);
696    }
697
698    // Default to an empty scope if none is provided
699    let scope = grant
700        .scope
701        .clone()
702        .unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
703
704    // Make the request go through the policy engine
705    let res = policy
706        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
707            user: None,
708            client,
709            scope: &scope,
710            grant_type: mas_policy::GrantType::ClientCredentials,
711            requester: mas_policy::Requester {
712                ip_address: activity_tracker.ip(),
713                user_agent: user_agent.clone().map(|ua| ua.raw),
714            },
715        })
716        .await?;
717    if !res.valid() {
718        return Err(RouteError::DeniedByPolicy(res.violations));
719    }
720
721    // Start the session
722    let mut session = repo
723        .oauth2_session()
724        .add_from_client_credentials(rng, clock, client, scope)
725        .await?;
726
727    if let Some(user_agent) = user_agent {
728        session = repo
729            .oauth2_session()
730            .record_user_agent(session, user_agent)
731            .await?;
732    }
733
734    let ttl = site_config.access_token_ttl;
735    let access_token_str = TokenType::AccessToken.generate(rng);
736
737    let access_token = repo
738        .oauth2_access_token()
739        .add(rng, clock, &session, access_token_str, Some(ttl))
740        .await?;
741
742    let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
743
744    // XXX: there is a potential (but unlikely) race here, where the activity for
745    // the session is recorded before the transaction is committed. We would have to
746    // save the repository here to fix that.
747    activity_tracker
748        .record_oauth2_session(clock, &session)
749        .await;
750
751    if !session.scope.is_empty() {
752        // We only return the scope if it's not empty
753        params = params.with_scope(session.scope);
754    }
755
756    Ok((params, repo))
757}
758
759async fn device_code_grant(
760    rng: &mut BoxRng,
761    clock: &impl Clock,
762    activity_tracker: &BoundActivityTracker,
763    grant: &DeviceCodeGrant,
764    client: &Client,
765    key_store: &Keystore,
766    url_builder: &UrlBuilder,
767    site_config: &SiteConfig,
768    mut repo: BoxRepository,
769    homeserver: &Arc<dyn HomeserverConnection>,
770    user_agent: Option<UserAgent>,
771) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
772    // Check that the client is allowed to use this grant type
773    if !client.grant_types.contains(&GrantType::DeviceCode) {
774        return Err(RouteError::UnauthorizedClient);
775    }
776
777    let grant = repo
778        .oauth2_device_code_grant()
779        .find_by_device_code(&grant.device_code)
780        .await?
781        .ok_or(RouteError::GrantNotFound)?;
782
783    // Check that the client match
784    if client.id != grant.client_id {
785        return Err(RouteError::ClientIDMismatch {
786            expected: grant.client_id,
787            actual: client.id,
788        });
789    }
790
791    if grant.expires_at < clock.now() {
792        return Err(RouteError::DeviceCodeExpired);
793    }
794
795    let browser_session_id = match &grant.state {
796        DeviceCodeGrantState::Pending => {
797            return Err(RouteError::DeviceCodePending);
798        }
799        DeviceCodeGrantState::Rejected { .. } => {
800            return Err(RouteError::DeviceCodeRejected);
801        }
802        DeviceCodeGrantState::Exchanged { .. } => {
803            return Err(RouteError::DeviceCodeExchanged);
804        }
805        DeviceCodeGrantState::Fulfilled {
806            browser_session_id, ..
807        } => browser_session_id,
808    };
809
810    let browser_session = repo
811        .browser_session()
812        .lookup(*browser_session_id)
813        .await?
814        .ok_or(RouteError::NoSuchBrowserSession)?;
815
816    // Start the session
817    let mut session = repo
818        .oauth2_session()
819        .add_from_browser_session(rng, clock, client, &browser_session, grant.scope.clone())
820        .await?;
821
822    repo.oauth2_device_code_grant()
823        .exchange(clock, grant, &session)
824        .await?;
825
826    // XXX: should we get the user agent from the device code grant instead?
827    if let Some(user_agent) = user_agent {
828        session = repo
829            .oauth2_session()
830            .record_user_agent(session, user_agent)
831            .await?;
832    }
833
834    let ttl = site_config.access_token_ttl;
835    let access_token_str = TokenType::AccessToken.generate(rng);
836
837    let access_token = repo
838        .oauth2_access_token()
839        .add(rng, clock, &session, access_token_str, Some(ttl))
840        .await?;
841
842    let mut params =
843        AccessTokenResponse::new(access_token.access_token.clone()).with_expires_in(ttl);
844
845    // If the client uses the refresh token grant type, we also generate a refresh
846    // token
847    if client.grant_types.contains(&GrantType::RefreshToken) {
848        let refresh_token_str = TokenType::RefreshToken.generate(rng);
849
850        let refresh_token = repo
851            .oauth2_refresh_token()
852            .add(rng, clock, &session, &access_token, refresh_token_str)
853            .await?;
854
855        params = params.with_refresh_token(refresh_token.refresh_token);
856    }
857
858    // If the client asked for an ID token, we generate one
859    if session.scope.contains(&scope::OPENID) {
860        let id_token = generate_id_token(
861            rng,
862            clock,
863            url_builder,
864            key_store,
865            client,
866            None,
867            &browser_session,
868            Some(&access_token),
869            None,
870        )?;
871
872        params = params.with_id_token(id_token);
873    }
874
875    // Lock the user sync to make sure we don't get into a race condition
876    repo.user()
877        .acquire_lock_for_sync(&browser_session.user)
878        .await?;
879
880    // Look for device to provision
881    let mxid = homeserver.mxid(&browser_session.user.username);
882    for scope in &*session.scope {
883        if let Some(device) = Device::from_scope_token(scope) {
884            homeserver
885                .create_device(&mxid, device.as_str())
886                .await
887                .map_err(RouteError::ProvisionDeviceFailed)?;
888        }
889    }
890
891    // XXX: there is a potential (but unlikely) race here, where the activity for
892    // the session is recorded before the transaction is committed. We would have to
893    // save the repository here to fix that.
894    activity_tracker
895        .record_oauth2_session(clock, &session)
896        .await;
897
898    if !session.scope.is_empty() {
899        // We only return the scope if it's not empty
900        params = params.with_scope(session.scope);
901    }
902
903    Ok((params, repo))
904}
905
906#[cfg(test)]
907mod tests {
908    use hyper::Request;
909    use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
910    use mas_router::SimpleRoute;
911    use oauth2_types::{
912        registration::ClientRegistrationResponse,
913        requests::{DeviceAuthorizationResponse, ResponseMode},
914        scope::{OPENID, Scope},
915    };
916    use sqlx::PgPool;
917
918    use super::*;
919    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
920
921    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
922    async fn test_auth_code_grant(pool: PgPool) {
923        setup();
924        let state = TestState::from_pool(pool).await.unwrap();
925
926        // Provision a client
927        let request =
928            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
929                "client_uri": "https://example.com/",
930                "redirect_uris": ["https://example.com/callback"],
931                "token_endpoint_auth_method": "none",
932                "response_types": ["code"],
933                "grant_types": ["authorization_code"],
934            }));
935
936        let response = state.request(request).await;
937        response.assert_status(StatusCode::CREATED);
938
939        let ClientRegistrationResponse { client_id, .. } = response.json();
940
941        // Let's provision a user and create a session for them. This part is hard to
942        // test with just HTTP requests, so we'll use the repository directly.
943        let mut repo = state.repository().await.unwrap();
944
945        let user = repo
946            .user()
947            .add(&mut state.rng(), &state.clock, "alice".to_owned())
948            .await
949            .unwrap();
950
951        let browser_session = repo
952            .browser_session()
953            .add(&mut state.rng(), &state.clock, &user, None)
954            .await
955            .unwrap();
956
957        // Lookup the client in the database.
958        let client = repo
959            .oauth2_client()
960            .find_by_client_id(&client_id)
961            .await
962            .unwrap()
963            .unwrap();
964
965        // Start a grant
966        let code = "thisisaverysecurecode";
967        let grant = repo
968            .oauth2_authorization_grant()
969            .add(
970                &mut state.rng(),
971                &state.clock,
972                &client,
973                "https://example.com/redirect".parse().unwrap(),
974                Scope::from_iter([OPENID]),
975                Some(AuthorizationCode {
976                    code: code.to_owned(),
977                    pkce: None,
978                }),
979                Some("state".to_owned()),
980                Some("nonce".to_owned()),
981                ResponseMode::Query,
982                false,
983                None,
984            )
985            .await
986            .unwrap();
987
988        let session = repo
989            .oauth2_session()
990            .add_from_browser_session(
991                &mut state.rng(),
992                &state.clock,
993                &client,
994                &browser_session,
995                grant.scope.clone(),
996            )
997            .await
998            .unwrap();
999
1000        // And fulfill it
1001        let grant = repo
1002            .oauth2_authorization_grant()
1003            .fulfill(&state.clock, &session, grant)
1004            .await
1005            .unwrap();
1006
1007        repo.save().await.unwrap();
1008
1009        // Now call the token endpoint to get an access token.
1010        let request =
1011            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1012                "grant_type": "authorization_code",
1013                "code": code,
1014                "redirect_uri": grant.redirect_uri,
1015                "client_id": client.client_id,
1016            }));
1017
1018        let response = state.request(request).await;
1019        response.assert_status(StatusCode::OK);
1020
1021        let AccessTokenResponse { access_token, .. } = response.json();
1022
1023        // Check that the token is valid
1024        assert!(state.is_access_token_valid(&access_token).await);
1025
1026        // Exchange it again, this it should fail
1027        let request =
1028            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1029                "grant_type": "authorization_code",
1030                "code": code,
1031                "redirect_uri": grant.redirect_uri,
1032                "client_id": client.client_id,
1033            }));
1034
1035        let response = state.request(request).await;
1036        response.assert_status(StatusCode::BAD_REQUEST);
1037        let error: ClientError = response.json();
1038        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1039
1040        // The token should still be valid
1041        assert!(state.is_access_token_valid(&access_token).await);
1042
1043        // Now wait a bit
1044        state.clock.advance(Duration::try_minutes(1).unwrap());
1045
1046        // Exchange it again, this it should fail
1047        let request =
1048            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1049                "grant_type": "authorization_code",
1050                "code": code,
1051                "redirect_uri": grant.redirect_uri,
1052                "client_id": client.client_id,
1053            }));
1054
1055        let response = state.request(request).await;
1056        response.assert_status(StatusCode::BAD_REQUEST);
1057        let error: ClientError = response.json();
1058        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1059
1060        // And it should have revoked the token we got
1061        assert!(!state.is_access_token_valid(&access_token).await);
1062
1063        // Try another one and wait for too long before exchanging it
1064        let mut repo = state.repository().await.unwrap();
1065        let code = "thisisanothercode";
1066        let grant = repo
1067            .oauth2_authorization_grant()
1068            .add(
1069                &mut state.rng(),
1070                &state.clock,
1071                &client,
1072                "https://example.com/redirect".parse().unwrap(),
1073                Scope::from_iter([OPENID]),
1074                Some(AuthorizationCode {
1075                    code: code.to_owned(),
1076                    pkce: None,
1077                }),
1078                Some("state".to_owned()),
1079                Some("nonce".to_owned()),
1080                ResponseMode::Query,
1081                false,
1082                None,
1083            )
1084            .await
1085            .unwrap();
1086
1087        let session = repo
1088            .oauth2_session()
1089            .add_from_browser_session(
1090                &mut state.rng(),
1091                &state.clock,
1092                &client,
1093                &browser_session,
1094                grant.scope.clone(),
1095            )
1096            .await
1097            .unwrap();
1098
1099        // And fulfill it
1100        let grant = repo
1101            .oauth2_authorization_grant()
1102            .fulfill(&state.clock, &session, grant)
1103            .await
1104            .unwrap();
1105
1106        repo.save().await.unwrap();
1107
1108        // Now wait a bit
1109        state
1110            .clock
1111            .advance(Duration::microseconds(15 * 60 * 1000 * 1000));
1112
1113        // Exchange it, it should fail
1114        let request =
1115            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1116                "grant_type": "authorization_code",
1117                "code": code,
1118                "redirect_uri": grant.redirect_uri,
1119                "client_id": client.client_id,
1120            }));
1121
1122        let response = state.request(request).await;
1123        response.assert_status(StatusCode::BAD_REQUEST);
1124        let ClientError { error, .. } = response.json();
1125        assert_eq!(error, ClientErrorCode::InvalidGrant);
1126    }
1127
1128    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1129    async fn test_refresh_token_grant(pool: PgPool) {
1130        setup();
1131        let state = TestState::from_pool(pool).await.unwrap();
1132
1133        // Provision a client
1134        let request =
1135            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1136                "client_uri": "https://example.com/",
1137                "redirect_uris": ["https://example.com/callback"],
1138                "token_endpoint_auth_method": "none",
1139                "response_types": ["code"],
1140                "grant_types": ["authorization_code", "refresh_token"],
1141            }));
1142
1143        let response = state.request(request).await;
1144        response.assert_status(StatusCode::CREATED);
1145
1146        let ClientRegistrationResponse { client_id, .. } = response.json();
1147
1148        // Let's provision a user and create a session for them. This part is hard to
1149        // test with just HTTP requests, so we'll use the repository directly.
1150        let mut repo = state.repository().await.unwrap();
1151
1152        let user = repo
1153            .user()
1154            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1155            .await
1156            .unwrap();
1157
1158        let browser_session = repo
1159            .browser_session()
1160            .add(&mut state.rng(), &state.clock, &user, None)
1161            .await
1162            .unwrap();
1163
1164        // Lookup the client in the database.
1165        let client = repo
1166            .oauth2_client()
1167            .find_by_client_id(&client_id)
1168            .await
1169            .unwrap()
1170            .unwrap();
1171
1172        // Get a token pair
1173        let session = repo
1174            .oauth2_session()
1175            .add_from_browser_session(
1176                &mut state.rng(),
1177                &state.clock,
1178                &client,
1179                &browser_session,
1180                Scope::from_iter([OPENID]),
1181            )
1182            .await
1183            .unwrap();
1184
1185        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1186            generate_token_pair(
1187                &mut state.rng(),
1188                &state.clock,
1189                &mut repo,
1190                &session,
1191                Duration::microseconds(5 * 60 * 1000 * 1000),
1192            )
1193            .await
1194            .unwrap();
1195
1196        repo.save().await.unwrap();
1197
1198        // First check that the token is valid
1199        assert!(state.is_access_token_valid(&access_token).await);
1200
1201        // Now call the token endpoint to get an access token.
1202        let request =
1203            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1204                "grant_type": "refresh_token",
1205                "refresh_token": refresh_token,
1206                "client_id": client.client_id,
1207            }));
1208
1209        let response = state.request(request).await;
1210        response.assert_status(StatusCode::OK);
1211
1212        let old_access_token = access_token;
1213        let old_refresh_token = refresh_token;
1214        let response: AccessTokenResponse = response.json();
1215        let access_token = response.access_token;
1216        let refresh_token = response.refresh_token.expect("to have a refresh token");
1217
1218        // Check that the new token is valid
1219        assert!(state.is_access_token_valid(&access_token).await);
1220
1221        // Check that the old token is no longer valid
1222        assert!(!state.is_access_token_valid(&old_access_token).await);
1223
1224        // Call it again with the old token, it should fail
1225        let request =
1226            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1227                "grant_type": "refresh_token",
1228                "refresh_token": old_refresh_token,
1229                "client_id": client.client_id,
1230            }));
1231
1232        let response = state.request(request).await;
1233        response.assert_status(StatusCode::BAD_REQUEST);
1234        let ClientError { error, .. } = response.json();
1235        assert_eq!(error, ClientErrorCode::InvalidGrant);
1236
1237        // Call it again with the new token, it should work
1238        let request =
1239            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1240                "grant_type": "refresh_token",
1241                "refresh_token": refresh_token,
1242                "client_id": client.client_id,
1243            }));
1244
1245        let response = state.request(request).await;
1246        response.assert_status(StatusCode::OK);
1247        let _: AccessTokenResponse = response.json();
1248    }
1249
1250    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1251    async fn test_double_refresh(pool: PgPool) {
1252        setup();
1253        let state = TestState::from_pool(pool).await.unwrap();
1254
1255        // Provision a client
1256        let request =
1257            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1258                "client_uri": "https://example.com/",
1259                "redirect_uris": ["https://example.com/callback"],
1260                "token_endpoint_auth_method": "none",
1261                "response_types": ["code"],
1262                "grant_types": ["authorization_code", "refresh_token"],
1263            }));
1264
1265        let response = state.request(request).await;
1266        response.assert_status(StatusCode::CREATED);
1267
1268        let ClientRegistrationResponse { client_id, .. } = response.json();
1269
1270        // Let's provision a user and create a session for them. This part is hard to
1271        // test with just HTTP requests, so we'll use the repository directly.
1272        let mut repo = state.repository().await.unwrap();
1273
1274        let user = repo
1275            .user()
1276            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1277            .await
1278            .unwrap();
1279
1280        let browser_session = repo
1281            .browser_session()
1282            .add(&mut state.rng(), &state.clock, &user, None)
1283            .await
1284            .unwrap();
1285
1286        // Lookup the client in the database.
1287        let client = repo
1288            .oauth2_client()
1289            .find_by_client_id(&client_id)
1290            .await
1291            .unwrap()
1292            .unwrap();
1293
1294        // Get a token pair
1295        let session = repo
1296            .oauth2_session()
1297            .add_from_browser_session(
1298                &mut state.rng(),
1299                &state.clock,
1300                &client,
1301                &browser_session,
1302                Scope::from_iter([OPENID]),
1303            )
1304            .await
1305            .unwrap();
1306
1307        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1308            generate_token_pair(
1309                &mut state.rng(),
1310                &state.clock,
1311                &mut repo,
1312                &session,
1313                Duration::microseconds(5 * 60 * 1000 * 1000),
1314            )
1315            .await
1316            .unwrap();
1317
1318        repo.save().await.unwrap();
1319
1320        // First check that the token is valid
1321        assert!(state.is_access_token_valid(&access_token).await);
1322
1323        // Now call the token endpoint to get an access token.
1324        let request =
1325            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1326                "grant_type": "refresh_token",
1327                "refresh_token": refresh_token,
1328                "client_id": client.client_id,
1329            }));
1330
1331        let first_response = state.request(request).await;
1332        first_response.assert_status(StatusCode::OK);
1333        let first_response: AccessTokenResponse = first_response.json();
1334
1335        // Call a second time, it should work, as we haven't done anything yet with the
1336        // token
1337        let request =
1338            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1339                "grant_type": "refresh_token",
1340                "refresh_token": refresh_token,
1341                "client_id": client.client_id,
1342            }));
1343
1344        let second_response = state.request(request).await;
1345        second_response.assert_status(StatusCode::OK);
1346        let second_response: AccessTokenResponse = second_response.json();
1347
1348        // Check that we got new tokens
1349        assert_ne!(first_response.access_token, second_response.access_token);
1350        assert_ne!(first_response.refresh_token, second_response.refresh_token);
1351
1352        // Check that the old-new token is invalid
1353        assert!(
1354            !state
1355                .is_access_token_valid(&first_response.access_token)
1356                .await
1357        );
1358
1359        // Check that the new-new token is valid
1360        assert!(
1361            state
1362                .is_access_token_valid(&second_response.access_token)
1363                .await
1364        );
1365
1366        // Do a third refresh, this one should not work, as we've used the new
1367        // access token
1368        let request =
1369            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1370                "grant_type": "refresh_token",
1371                "refresh_token": refresh_token,
1372                "client_id": client.client_id,
1373            }));
1374
1375        let third_response = state.request(request).await;
1376        third_response.assert_status(StatusCode::BAD_REQUEST);
1377
1378        // The other reason we consider a new refresh token to be 'used' is if
1379        // it was already used in a refresh
1380        // So, if we do a refresh with the second_response.refresh_token, then
1381        // another refresh with the result, redoing one with
1382        // second_response.refresh_token again should fail
1383        let request =
1384            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1385                "grant_type": "refresh_token",
1386                "refresh_token": second_response.refresh_token,
1387                "client_id": client.client_id,
1388            }));
1389
1390        // This one is fine
1391        let fourth_response = state.request(request).await;
1392        fourth_response.assert_status(StatusCode::OK);
1393        let fourth_response: AccessTokenResponse = fourth_response.json();
1394
1395        // Do another one, it should be fine as well
1396        let request =
1397            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1398                "grant_type": "refresh_token",
1399                "refresh_token": fourth_response.refresh_token,
1400                "client_id": client.client_id,
1401            }));
1402
1403        let fifth_response = state.request(request).await;
1404        fifth_response.assert_status(StatusCode::OK);
1405
1406        // But now, if we re-do with the second_response.refresh_token, it should
1407        // fail
1408        let request =
1409            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1410                "grant_type": "refresh_token",
1411                "refresh_token": second_response.refresh_token,
1412                "client_id": client.client_id,
1413            }));
1414
1415        let sixth_response = state.request(request).await;
1416        sixth_response.assert_status(StatusCode::BAD_REQUEST);
1417    }
1418
1419    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1420    async fn test_client_credentials(pool: PgPool) {
1421        setup();
1422        let state = TestState::from_pool(pool).await.unwrap();
1423
1424        // Provision a client
1425        let request =
1426            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1427                "client_uri": "https://example.com/",
1428                "token_endpoint_auth_method": "client_secret_post",
1429                "grant_types": ["client_credentials"],
1430            }));
1431
1432        let response = state.request(request).await;
1433        response.assert_status(StatusCode::CREATED);
1434
1435        let response: ClientRegistrationResponse = response.json();
1436        let client_id = response.client_id;
1437        let client_secret = response.client_secret.expect("to have a client secret");
1438
1439        // Call the token endpoint with an empty scope
1440        let request =
1441            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1442                "grant_type": "client_credentials",
1443                "client_id": client_id,
1444                "client_secret": client_secret,
1445            }));
1446
1447        let response = state.request(request).await;
1448        response.assert_status(StatusCode::OK);
1449
1450        let response: AccessTokenResponse = response.json();
1451        assert!(response.refresh_token.is_none());
1452        assert!(response.expires_in.is_some());
1453        assert!(response.scope.is_none());
1454
1455        // Revoke the token
1456        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1457            "token": response.access_token,
1458            "client_id": client_id,
1459            "client_secret": client_secret,
1460        }));
1461
1462        let response = state.request(request).await;
1463        response.assert_status(StatusCode::OK);
1464
1465        // We should be allowed to ask for the GraphQL API scope
1466        let request =
1467            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1468                "grant_type": "client_credentials",
1469                "client_id": client_id,
1470                "client_secret": client_secret,
1471                "scope": "urn:mas:graphql:*"
1472            }));
1473
1474        let response = state.request(request).await;
1475        response.assert_status(StatusCode::OK);
1476
1477        let response: AccessTokenResponse = response.json();
1478        assert!(response.refresh_token.is_none());
1479        assert!(response.expires_in.is_some());
1480        assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
1481
1482        // Revoke the token
1483        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1484            "token": response.access_token,
1485            "client_id": client_id,
1486            "client_secret": client_secret,
1487        }));
1488
1489        let response = state.request(request).await;
1490        response.assert_status(StatusCode::OK);
1491
1492        // We should be NOT allowed to ask for the MAS admin scope
1493        let request =
1494            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1495                "grant_type": "client_credentials",
1496                "client_id": client_id,
1497                "client_secret": client_secret,
1498                "scope": "urn:mas:admin"
1499            }));
1500
1501        let response = state.request(request).await;
1502        response.assert_status(StatusCode::FORBIDDEN);
1503
1504        let ClientError { error, .. } = response.json();
1505        assert_eq!(error, ClientErrorCode::InvalidScope);
1506
1507        // Now, if we add the client to the admin list in the policy, it should work
1508        let state = {
1509            let mut state = state;
1510            state.policy_factory = crate::test_utils::policy_factory(
1511                "example.com",
1512                serde_json::json!({
1513                    "admin_clients": [client_id]
1514                }),
1515            )
1516            .await
1517            .unwrap();
1518            state
1519        };
1520
1521        let request =
1522            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1523                "grant_type": "client_credentials",
1524                "client_id": client_id,
1525                "client_secret": client_secret,
1526                "scope": "urn:mas:admin"
1527            }));
1528
1529        let response = state.request(request).await;
1530        response.assert_status(StatusCode::OK);
1531
1532        let response: AccessTokenResponse = response.json();
1533        assert!(response.refresh_token.is_none());
1534        assert!(response.expires_in.is_some());
1535        assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
1536
1537        // Revoke the token
1538        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1539            "token": response.access_token,
1540            "client_id": client_id,
1541            "client_secret": client_secret,
1542        }));
1543
1544        let response = state.request(request).await;
1545        response.assert_status(StatusCode::OK);
1546    }
1547
1548    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1549    async fn test_device_code_grant(pool: PgPool) {
1550        setup();
1551        let state = TestState::from_pool(pool).await.unwrap();
1552
1553        // Provision a client
1554        let request =
1555            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1556                "client_uri": "https://example.com/",
1557                "token_endpoint_auth_method": "none",
1558                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
1559                "response_types": [],
1560            }));
1561
1562        let response = state.request(request).await;
1563        response.assert_status(StatusCode::CREATED);
1564
1565        let response: ClientRegistrationResponse = response.json();
1566        let client_id = response.client_id;
1567
1568        // Start a device code grant
1569        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1570            serde_json::json!({
1571                "client_id": client_id,
1572                "scope": "openid",
1573            }),
1574        );
1575        let response = state.request(request).await;
1576        response.assert_status(StatusCode::OK);
1577
1578        let device_grant: DeviceAuthorizationResponse = response.json();
1579
1580        // Poll the token endpoint, it should be pending
1581        let request =
1582            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1583                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1584                "device_code": device_grant.device_code,
1585                "client_id": client_id,
1586            }));
1587        let response = state.request(request).await;
1588        response.assert_status(StatusCode::FORBIDDEN);
1589
1590        let ClientError { error, .. } = response.json();
1591        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1592
1593        // Let's provision a user and create a browser session for them. This part is
1594        // hard to test with just HTTP requests, so we'll use the repository
1595        // directly.
1596        let mut repo = state.repository().await.unwrap();
1597
1598        let user = repo
1599            .user()
1600            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1601            .await
1602            .unwrap();
1603
1604        let browser_session = repo
1605            .browser_session()
1606            .add(&mut state.rng(), &state.clock, &user, None)
1607            .await
1608            .unwrap();
1609
1610        // Find the grant
1611        let grant = repo
1612            .oauth2_device_code_grant()
1613            .find_by_user_code(&device_grant.user_code)
1614            .await
1615            .unwrap()
1616            .unwrap();
1617
1618        // And fulfill it
1619        let grant = repo
1620            .oauth2_device_code_grant()
1621            .fulfill(&state.clock, grant, &browser_session)
1622            .await
1623            .unwrap();
1624
1625        repo.save().await.unwrap();
1626
1627        // Now call the token endpoint to get an access token.
1628        let request =
1629            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1630                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1631                "device_code": grant.device_code,
1632                "client_id": client_id,
1633            }));
1634
1635        let response = state.request(request).await;
1636        response.assert_status(StatusCode::OK);
1637
1638        let response: AccessTokenResponse = response.json();
1639
1640        // Check that the token is valid
1641        assert!(state.is_access_token_valid(&response.access_token).await);
1642        // We advertised the refresh token grant type, so we should have a refresh token
1643        assert!(response.refresh_token.is_some());
1644        // We asked for the openid scope, so we should have an ID token
1645        assert!(response.id_token.is_some());
1646
1647        // Calling it again should fail
1648        let request =
1649            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1650                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1651                "device_code": grant.device_code,
1652                "client_id": client_id,
1653            }));
1654        let response = state.request(request).await;
1655        response.assert_status(StatusCode::BAD_REQUEST);
1656
1657        let ClientError { error, .. } = response.json();
1658        assert_eq!(error, ClientErrorCode::InvalidGrant);
1659
1660        // Do another grant and make it expire
1661        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1662            serde_json::json!({
1663                "client_id": client_id,
1664                "scope": "openid",
1665            }),
1666        );
1667        let response = state.request(request).await;
1668        response.assert_status(StatusCode::OK);
1669
1670        let device_grant: DeviceAuthorizationResponse = response.json();
1671
1672        // Poll the token endpoint, it should be pending
1673        let request =
1674            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1675                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1676                "device_code": device_grant.device_code,
1677                "client_id": client_id,
1678            }));
1679        let response = state.request(request).await;
1680        response.assert_status(StatusCode::FORBIDDEN);
1681
1682        let ClientError { error, .. } = response.json();
1683        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1684
1685        state.clock.advance(Duration::try_hours(1).unwrap());
1686
1687        // Poll again, it should be expired
1688        let request =
1689            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1690                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1691                "device_code": device_grant.device_code,
1692                "client_id": client_id,
1693            }));
1694        let response = state.request(request).await;
1695        response.assert_status(StatusCode::FORBIDDEN);
1696
1697        let ClientError { error, .. } = response.json();
1698        assert_eq!(error, ClientErrorCode::ExpiredToken);
1699
1700        // Do another grant and reject it
1701        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1702            serde_json::json!({
1703                "client_id": client_id,
1704                "scope": "openid",
1705            }),
1706        );
1707        let response = state.request(request).await;
1708        response.assert_status(StatusCode::OK);
1709
1710        let device_grant: DeviceAuthorizationResponse = response.json();
1711
1712        // Find the grant and reject it
1713        let mut repo = state.repository().await.unwrap();
1714
1715        // Find the grant
1716        let grant = repo
1717            .oauth2_device_code_grant()
1718            .find_by_user_code(&device_grant.user_code)
1719            .await
1720            .unwrap()
1721            .unwrap();
1722
1723        // And reject it
1724        let grant = repo
1725            .oauth2_device_code_grant()
1726            .reject(&state.clock, grant, &browser_session)
1727            .await
1728            .unwrap();
1729
1730        repo.save().await.unwrap();
1731
1732        // Poll the token endpoint, it should be rejected
1733        let request =
1734            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1735                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1736                "device_code": grant.device_code,
1737                "client_id": client_id,
1738            }));
1739        let response = state.request(request).await;
1740        response.assert_status(StatusCode::FORBIDDEN);
1741
1742        let ClientError { error, .. } = response.json();
1743        assert_eq!(error, ClientErrorCode::AccessDenied);
1744    }
1745
1746    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1747    async fn test_unsupported_grant(pool: PgPool) {
1748        setup();
1749        let state = TestState::from_pool(pool).await.unwrap();
1750
1751        // Provision a client
1752        let request =
1753            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1754                "client_uri": "https://example.com/",
1755                "redirect_uris": ["https://example.com/callback"],
1756                "token_endpoint_auth_method": "client_secret_post",
1757                "grant_types": ["password"],
1758                "response_types": [],
1759            }));
1760
1761        let response = state.request(request).await;
1762        response.assert_status(StatusCode::CREATED);
1763
1764        let response: ClientRegistrationResponse = response.json();
1765        let client_id = response.client_id;
1766        let client_secret = response.client_secret.expect("to have a client secret");
1767
1768        // Call the token endpoint with an unsupported grant type
1769        let request =
1770            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1771                "grant_type": "password",
1772                "client_id": client_id,
1773                "client_secret": client_secret,
1774                "username": "john",
1775                "password": "hunter2",
1776            }));
1777
1778        let response = state.request(request).await;
1779        response.assert_status(StatusCode::BAD_REQUEST);
1780        let ClientError { error, .. } = response.json();
1781        assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
1782    }
1783}