# sanic_jwt_extended/decorators.py
from functools import wraps
from typing import Callable, List, Optional, Tuple
from sanic.request import Request
from sanic_jwt_extended.exceptions import (
AccessDeniedError,
ConfigurationConflictError,
CSRFError,
FreshTokenRequiredError,
InvalidHeaderError,
NoAuthorizationError,
RevokedTokenError,
WrongTokenError,
)
from sanic_jwt_extended.jwt_manager import JWT
from sanic_jwt_extended.tokens import Token
try:
    from hmac import compare_digest
except ImportError:  # pragma: no cover

    def compare_digest(a, b):
        """Fallback equality check used when ``hmac.compare_digest`` is missing.

        ``str`` arguments are encoded as UTF-8 before comparison; any other
        argument is compared as given. Returns ``True`` only when both
        operands have the same length and every element pair is equal.
        """
        left = a.encode("utf-8") if isinstance(a, str) else a
        right = b.encode("utf-8") if isinstance(b, str) else b

        if len(left) != len(right):
            return False

        # Accumulate the XOR of every pair instead of returning at the first
        # difference, so the loop always walks the full input.
        mismatch = 0
        for lhs, rhs in zip(left, right):
            mismatch |= lhs ^ rhs
        return not mismatch
# Signature shared by the per-location token getters in this module: each is
# called with the request and an "is access token" flag and returns a pair of
# (encoded JWT, CSRF value or None).
jwt_get_function = Callable[[Request, bool], Tuple[str, Optional[str]]]
def _get_request(args) -> Request:
    """Pick the request object out of a handler's positional arguments.

    Returns ``args[0]`` when it is a ``Request`` instance; otherwise returns
    ``args[1]`` without further checks.

    NOTE(review): the second-position fallback presumably covers handlers
    whose first positional argument is a view/blueprint instance -- confirm
    against the callers.
    """
    first = args[0]
    return first if isinstance(first, Request) else args[1]
def _get_raw_jwt_from_request(request, is_access=True):
    """Return the raw token from the first configured location that has one.

    The locations listed in ``JWT.config.token_location`` are tried in order.
    Only the names ``"header"``, ``"query"`` and ``"cookies"`` are recognised;
    any other entry is ignored.

    :param request: incoming request to inspect
    :param is_access: ``True`` to look for an access token, ``False`` for a
        refresh token (forwarded unchanged to each location getter)
    :return: ``(raw_jwt, csrf_value)`` as produced by the successful getter
    :raises NoAuthorizationError: when no location yields a token; the message
        joins the individual "missing" messages with ``", "``

    Only ``NoAuthorizationError`` moves on to the next location; any other
    exception raised by a getter propagates to the caller.
    """
    known_getters = (
        ("header", _get_raw_jwt_from_headers),
        ("query", _get_raw_jwt_from_query_params),
        ("cookies", _get_raw_jwt_from_cookies),
    )
    getters: List[jwt_get_function] = [
        getter
        for location in JWT.config.token_location
        for name, getter in known_getters
        if location == name
    ]

    token = None
    csrf = None
    failures = []

    for getter in getters:
        try:
            token, csrf = getter(request, is_access)
        except NoAuthorizationError as exc:
            # Remember why this location failed and try the next one.
            failures.append(str(exc))
        else:
            break

    if token:
        return token, csrf
    raise NoAuthorizationError(', '.join(failures))
def _get_raw_jwt_from_headers(request, is_access):
    """Extract the encoded token from the configured authorization header.

    The header name is ``JWT.config.jwt_header_key`` for access tokens and
    ``JWT.config.refresh_jwt_header_key`` for refresh tokens. The value must
    consist of exactly two whitespace-separated parts: the configured
    ``JWT.config.jwt_header_prefix`` followed by the token.

    :param request: incoming request whose headers are inspected
    :param is_access: ``True`` for an access token, ``False`` for a refresh token
    :return: ``(encoded_token, None)`` -- the header location carries no CSRF value
    :raises NoAuthorizationError: when the header is absent or empty
    :raises InvalidHeaderError: when the header value is not
        ``'<prefix> <JWT>'``
    """
    header_key = (
        JWT.config.jwt_header_key if is_access else JWT.config.refresh_jwt_header_key
    )
    header_prefix = JWT.config.jwt_header_prefix

    token_header = request.headers.get(header_key)
    if not token_header:
        raise NoAuthorizationError(f'Missing header "{header_key}"')

    parts: List[str] = token_header.split()

    # Test the length before indexing: a whitespace-only value is truthy but
    # splits into an empty list, and parts[0] would then raise IndexError
    # instead of the intended InvalidHeaderError.
    if len(parts) != 2 or parts[0] != header_prefix:
        raise InvalidHeaderError(
            f"Bad {header_key} header. Expected value '{header_prefix} <JWT>'"
        )

    encoded_token: str = parts[1]
    return encoded_token, None
def _get_raw_jwt_from_query_params(request, _):
    """Extract the encoded token from the query string.

    The parameter name comes from ``JWT.config.jwt_query_param_name``. The
    second positional argument (the access/refresh flag) is ignored, so the
    same parameter is used for both token kinds.

    :return: ``(encoded_token, None)`` -- the query location carries no CSRF value
    :raises NoAuthorizationError: when the parameter is absent or empty
    """
    param_name = JWT.config.jwt_query_param_name
    token = request.args.get(param_name)

    if token:
        return token, None
    raise NoAuthorizationError(f'Missing query parameter "{param_name}"')
def _get_raw_jwt_from_cookies(request, is_access):
    """Extract the encoded token, and the CSRF header when required, from cookies.

    Cookie and CSRF-header names are taken from ``JWT.config``: the
    ``jwt_cookie`` / ``jwt_csrf_header`` pair for access tokens and the
    ``refresh_jwt_cookie`` / ``refresh_jwt_csrf_header`` pair for refresh
    tokens.

    The CSRF header is read only when ``JWT.config.csrf_protect`` is truthy
    and the request method is listed in ``JWT.config.csrf_request_methods``.

    :return: ``(encoded_token, csrf_value)``; ``csrf_value`` is ``None`` when
        no CSRF header was required
    :raises NoAuthorizationError: when the cookie is absent or empty
    :raises CSRFError: when a CSRF header is required but absent or empty
    """
    if is_access:
        cookie_name = JWT.config.jwt_cookie
        csrf_header_name = JWT.config.jwt_csrf_header
    else:
        cookie_name = JWT.config.refresh_jwt_cookie
        csrf_header_name = JWT.config.refresh_jwt_csrf_header

    token = request.cookies.get(cookie_name)
    if not token:
        raise NoAuthorizationError(f'Missing cookie "{cookie_name}"')

    csrf_required = (
        JWT.config.csrf_protect
        and request.method in JWT.config.csrf_request_methods
    )
    if not csrf_required:
        return token, None

    csrf_token = request.headers.get(csrf_header_name)
    if not csrf_token:
        raise CSRFError("Missing CSRF token")
    return token, csrf_token
def _csrf_check(csrf_from_request, csrf_from_jwt):
if not csrf_from_jwt or not isinstance(csrf_from_jwt, str):
raise CSRFError('Can not find valid CSRF data from token')
if not compare_digest(csrf_from_request, csrf_from_jwt):
raise CSRFError('CSRF double submit tokens do not match')
def jwt_required(
    function=None, *, allow=None, deny=None, fresh_required=False,
):
    """Decorator that requires a valid access token before running a handler.

    Usable bare (``@jwt_required``) or with keyword options
    (``@jwt_required(allow=[...])``). On success the decoded token object is
    passed to the handler as the ``token`` keyword argument.

    :param function: the handler, when used as a bare decorator
    :param allow: if truthy, only tokens whose ``role`` is in this collection pass
    :param deny: if truthy, tokens whose ``role`` is in this collection are rejected
    :param fresh_required: if truthy, the token's ``fresh`` flag must be truthy
    :raises ConfigurationConflictError: at decoration time, when the decorator
        is called with options and both ``allow`` and ``deny`` are truthy

    Per request, the wrapper checks in this order and raises on the first
    failure: CSRF match (``CSRFError``), token type (``WrongTokenError``),
    freshness (``FreshTokenRequiredError``), role (``AccessDeniedError``),
    blacklist (``RevokedTokenError``). Errors from locating the token in the
    request propagate unchanged.
    """

    def decorate(handler):
        @wraps(handler)
        async def guarded(*args, **kwargs):
            request = _get_request(args)
            raw_jwt, csrf_value = _get_raw_jwt_from_request(request)
            token = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token.csrf)

            if token.type != "access":
                raise WrongTokenError("Only access tokens are allowed")

            if fresh_required and not token.fresh:
                raise FreshTokenRequiredError("Only fresh access tokens are allowed")

            # Allow-list and deny-list failures share one error message.
            if (allow and token.role not in allow) or (
                deny and token.role in deny
            ):
                raise AccessDeniedError("You are not allowed to access here")

            if JWT.config.use_blacklist:
                revoked = await JWT.blacklist.is_blacklisted(token)
                if revoked:
                    raise RevokedTokenError("Token has been revoked")

            kwargs["token"] = token
            return await handler(*args, **kwargs)

        return guarded

    # Bare usage: the handler itself was passed in.
    if function:
        return decorate(function)

    if allow and deny:
        raise ConfigurationConflictError(
            "Can not use 'deny' and 'allow' option together."
        )
    return decorate
def jwt_optional(function):
    """Decorator that accepts, but does not require, an access token.

    The handler always runs with a ``token`` keyword argument: the decoded
    token object when a usable access token was presented, otherwise ``None``.

    ``NoAuthorizationError`` and ``InvalidHeaderError`` are swallowed and
    treated as "no token". A token that *is* presented is still validated:

    :raises CSRFError: when a CSRF value was required and does not match
    :raises WrongTokenError: when the token is not an access token
    :raises RevokedTokenError: when blacklisting is enabled and the token is
        blacklisted
    """

    @wraps(function)
    async def wrapper(*args, **kwargs):
        request = _get_request(args)
        token_obj: Optional[Token] = None

        try:
            raw_jwt, csrf_value = _get_raw_jwt_from_request(request)
            token_obj = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token_obj.csrf)

            if token_obj.type != "access":
                raise WrongTokenError("Only access tokens are allowed")
        except (NoAuthorizationError, InvalidHeaderError):
            # Missing or malformed credentials just mean "anonymous" here.
            pass

        # Apply the same blacklist check as jwt_required, so a revoked token
        # is not handed to the handler as if it were still valid.
        if (
            token_obj is not None
            and JWT.config.use_blacklist
            and await JWT.blacklist.is_blacklisted(token_obj)
        ):
            raise RevokedTokenError("Token has been revoked")

        kwargs["token"] = token_obj
        return await function(*args, **kwargs)

    return wrapper
def refresh_jwt_required(function=None, *, allow=None, deny=None):
    """Decorator that requires a valid refresh token before running a handler.

    Usable bare (``@refresh_jwt_required``) or with keyword options. The token
    is looked up with ``is_access=False``. On success the decoded token object
    is passed to the handler as the ``token`` keyword argument.

    :param function: the handler, when used as a bare decorator
    :param allow: if truthy, only tokens whose ``role`` is in this collection pass
    :param deny: if truthy, tokens whose ``role`` is in this collection are rejected
    :raises ConfigurationConflictError: at decoration time, when the decorator
        is called with options and both ``allow`` and ``deny`` are truthy

    Per request, the wrapper checks in this order and raises on the first
    failure: CSRF match (``CSRFError``), token type (``WrongTokenError``),
    role (``AccessDeniedError``), blacklist (``RevokedTokenError``). Errors
    from locating the token in the request propagate unchanged.
    """

    def decorate(handler):
        @wraps(handler)
        async def guarded(*args, **kwargs):
            request = _get_request(args)
            raw_jwt, csrf_value = _get_raw_jwt_from_request(
                request, is_access=False
            )
            token = Token(raw_jwt)

            if csrf_value:
                _csrf_check(csrf_value, token.csrf)

            if token.type != "refresh":
                raise WrongTokenError("Only refresh tokens are allowed")

            # Allow-list and deny-list failures share one error message.
            if (allow and token.role not in allow) or (
                deny and token.role in deny
            ):
                raise AccessDeniedError("You are not allowed to refresh in here")

            if JWT.config.use_blacklist:
                revoked = await JWT.blacklist.is_blacklisted(token)
                if revoked:
                    raise RevokedTokenError("Token has been revoked")

            kwargs["token"] = token
            return await handler(*args, **kwargs)

        return guarded

    # Bare usage: the handler itself was passed in.
    if function:
        return decorate(function)

    if allow and deny:
        raise ConfigurationConflictError(
            "Can not use 'deny' and 'allow' option together."
        )
    return decorate