airbnb/caravel

View on GitHub
superset/utils/oauth2.py

Summary

Maintainability
A
35 mins
Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any, TYPE_CHECKING

import backoff
import jwt
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema

from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.superset_typing import OAuth2ClientConfig, OAuth2State

if TYPE_CHECKING:
    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.models.core import DatabaseUserOAuth2Tokens

JWT_EXPIRATION = timedelta(minutes=5)


@backoff.on_exception(
    backoff.expo,
    CreateKeyValueDistributedLockFailedException,
    factor=10,
    base=2,
    max_tries=5,
)
def get_oauth2_access_token(
    config: OAuth2ClientConfig,
    database_id: int,
    user_id: int,
    db_engine_spec: type[BaseEngineSpec],
) -> str | None:
    """
    Return a valid OAuth2 access token.

    If the token exists but is expired and a refresh token is available the function will
    return a fresh token and store it in the database for further requests. The function
    has a retry decorator, in case a dashboard with multiple charts triggers
    simultaneous requests for refreshing a stale token; in that case only the first
    process to acquire the lock will perform the refresh, and othe process should find a
    a valid token when they retry.
    """
    # pylint: disable=import-outside-toplevel
    from superset.models.core import DatabaseUserOAuth2Tokens

    token = (
        db.session.query(DatabaseUserOAuth2Tokens)
        .filter_by(user_id=user_id, database_id=database_id)
        .one_or_none()
    )
    if token is None:
        return None

    if token.access_token and datetime.now() < token.access_token_expiration:
        return token.access_token

    if token.refresh_token:
        return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)

    # since the access token is expired and there's no refresh token, delete the entry
    db.session.delete(token)

    return None


def refresh_oauth2_token(
    config: OAuth2ClientConfig,
    database_id: int,
    user_id: int,
    db_engine_spec: type[BaseEngineSpec],
    token: DatabaseUserOAuth2Tokens,
) -> str | None:
    with KeyValueDistributedLock(
        namespace="refresh_oauth2_token",
        user_id=user_id,
        database_id=database_id,
    ):
        token_response = db_engine_spec.get_oauth2_fresh_token(
            config,
            token.refresh_token,
        )

        # store new access token; note that the refresh token might be revoked, in which
        # case there would be no access token in the response
        if "access_token" not in token_response:
            return None

        token.access_token = token_response["access_token"]
        token.access_token_expiration = datetime.now() + timedelta(
            seconds=token_response["expires_in"]
        )
        db.session.add(token)

    return token.access_token


def encode_oauth2_state(state: OAuth2State) -> str:
    """
    Encode the OAuth2 state.
    """
    payload = {
        "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
        "database_id": state["database_id"],
        "user_id": state["user_id"],
        "default_redirect_uri": state["default_redirect_uri"],
        "tab_id": state["tab_id"],
    }
    encoded_state = jwt.encode(
        payload=payload,
        key=current_app.config["SECRET_KEY"],
        algorithm=current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"],
    )

    # Google OAuth2 needs periods to be escaped.
    encoded_state = encoded_state.replace(".", "%2E")

    return encoded_state


class OAuth2StateSchema(Schema):
    database_id = fields.Int(required=True)
    user_id = fields.Int(required=True)
    default_redirect_uri = fields.Str(required=True)
    tab_id = fields.Str(required=True)

    # pylint: disable=unused-argument
    @post_load
    def make_oauth2_state(
        self,
        data: dict[str, Any],
        **kwargs: Any,
    ) -> OAuth2State:
        return OAuth2State(
            database_id=data["database_id"],
            user_id=data["user_id"],
            default_redirect_uri=data["default_redirect_uri"],
            tab_id=data["tab_id"],
        )

    class Meta:  # pylint: disable=too-few-public-methods
        # ignore `exp`
        unknown = EXCLUDE


oauth2_state_schema = OAuth2StateSchema()


def decode_oauth2_state(encoded_state: str) -> OAuth2State:
    """
    Decode the OAuth2 state.
    """
    # Google OAuth2 needs periods to be escaped.
    encoded_state = encoded_state.replace("%2E", ".")

    payload = jwt.decode(
        jwt=encoded_state,
        key=current_app.config["SECRET_KEY"],
        algorithms=[current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"]],
    )
    state = oauth2_state_schema.load(payload)

    return state


class OAuth2ClientConfigSchema(Schema):
    id = fields.String(required=True)
    secret = fields.String(required=True)
    scope = fields.String(required=True)
    redirect_uri = fields.String(
        required=False,
        load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
    )
    authorization_request_uri = fields.String(required=True)
    token_request_uri = fields.String(required=True)