# vj4/handler/base.py
#
# Summary
#
# Maintainability
# C
# 1 day
# Test Coverage
import accept
import asyncio
import calendar
import functools
import hmac
import logging
import markupsafe
import pytz
import sockjs
from aiohttp import web
from email import utils

from vj4 import app
from vj4 import error
from vj4 import template
from vj4.model import blacklist
from vj4.model import builtin
from vj4.model import domain
from vj4.model import fs
from vj4.model import opcount
from vj4.model import token
from vj4.model import user
from vj4.model.adaptor import setting
from vj4.service import mailer
from vj4.util import json
from vj4.util import locale
from vj4.util import options
from vj4.util import tools

_logger = logging.getLogger(__name__)


class HandlerBase(setting.SettingMixin):
  """Request-scoped state and helpers shared by Handler and Connection.

  The methods here read ``self.request`` and write ``self.response``; neither is
  created by this class, so the concrete subclass must provide both.
  """

  # Page identifiers; render_html() falls back to them for 'page_name' / 'page_title'.
  NAME = None
  TITLE = None

  # Class-level defaults, replaced per instance by prepare().
  domain = builtin.DOMAIN_SYSTEM
  domain_id = builtin.DOMAIN_ID_SYSTEM
  domain_user = builtin.DOMAIN_USER_GUEST
  locale = locale.get(options.default_locale)
  timezone = None
  user = builtin.USER_GUEST

  async def prepare(self):
    """Load session, user, domain, locale and timezone state for this request.

    Raises:
      error.UserNotFoundError: the session carries a uid with no user document.
      error.BlacklistedError: a blacklist document exists for the remote IP.
      error.PermissionError: raised via check_perm() when PERM_VIEW is missing.
    """
    self.session = await self.update_session()
    if 'domain_id' in self.request.match_info:
      # Popped so that the domain id is not also seen as a route argument.
      self.domain_id = self.request.match_info.pop('domain_id')
    if 'uid' in self.session:
      uid = self.session['uid']
      udoc, self.domain, dudoc, bdoc = await asyncio.gather(
          user.get_by_uid(uid),
          domain.get(self.domain_id),
          domain.get_user(self.domain_id, uid),
          blacklist.get(self.remote_ip))
      self.domain_user = dudoc or {}
      if not udoc:
        # self.user deliberately keeps its guest default here: error reporting
        # still reads self.user['_id'] / ['priv'] and would fail on None.
        raise error.UserNotFoundError(uid)
      self.user = udoc
    else:
      self.domain, bdoc = await asyncio.gather(
          domain.get(self.domain_id), blacklist.get(self.remote_ip))
    self.view_lang = self.get_setting('view_lang')
    try:
      self.timezone = pytz.timezone(self.get_setting('timezone'))
    except pytz.UnknownTimeZoneError:
      # Unknown zone name: keep the current self.timezone value.
      pass
    self.locale = locale.get(self.view_lang)
    self.datetime_stamp = _datetime_stamp
    if bdoc:
      raise error.BlacklistedError(self.remote_ip)
    # GLOBAL is not defined on this class; read it defensively so that a handler
    # which never declares it is treated as non-global (PERM_VIEW is enforced).
    if (not getattr(self, 'GLOBAL', False)
        and not self.has_priv(builtin.PRIV_VIEW_ALL_DOMAIN)):
      self.check_perm(builtin.PERM_VIEW)

  def has_perm(self, perm):
    """Return whether the current user's domain role grants every bit of perm.

    Holding PRIV_MANAGE_ALL_DOMAIN also counts as having the permission.
    """
    role = self.domain_user.get('role', builtin.ROLE_DEFAULT)
    mask = domain.get_all_roles(self.domain).get(role, builtin.PERM_NONE)
    return ((perm & mask) == perm
            or self.has_priv(builtin.PRIV_MANAGE_ALL_DOMAIN))

  def check_perm(self, perm):
    """Raise error.PermissionError(perm) unless has_perm(perm) holds."""
    if not self.has_perm(perm):
      raise error.PermissionError(perm)

  def has_priv(self, priv):
    """Return whether self.user['priv'] contains every bit of priv."""
    return (priv & self.user['priv']) == priv

  def check_priv(self, priv):
    """Raise error.PrivilegeError(priv) unless has_priv(priv) holds."""
    if not self.has_priv(priv):
      raise error.PrivilegeError(priv)

  def dudoc_has_perm(self, udoc, dudoc, perm, ddoc=None):
    """Permission check for an arbitrary user / domain-user document pair.

    Args:
      udoc: user document; a falsy value yields False.
      dudoc: domain-user document; a falsy value yields False.
      perm: permission bits that must all be granted.
      ddoc: domain document to read roles from; self.domain when falsy.

    Returns:
      True if the role mask covers perm or udoc holds PRIV_MANAGE_ALL_DOMAIN.
    """
    if not udoc or not dudoc:
      return False
    # TODO(iceboy): Fix caller when dudoc=None is passed in.
    role = dudoc.get('role', builtin.ROLE_DEFAULT)
    mask = domain.get_all_roles(ddoc if ddoc else self.domain).get(role, builtin.PERM_NONE)
    return ((perm & mask) == perm
            or self.udoc_has_priv(udoc, builtin.PRIV_MANAGE_ALL_DOMAIN))

  def udoc_has_priv(self, udoc, priv):
    """Return whether udoc['priv'] contains every bit of priv (False if udoc is falsy)."""
    if not udoc:
      return False
    # TODO(iceboy): Fix caller when udoc=None is passed in.
    return (priv & udoc['priv']) == priv

  def own(self, doc, perm=builtin.PERM_NONE, field='owner_uid', priv=builtin.PRIV_NONE):
    """Return whether doc[field] is the current user's _id and perm and priv both hold."""
    return (doc[field] == self.user['_id']) and self.has_perm(perm) and self.has_priv(priv)

  async def update_session(self, *, new_saved=False, **kwargs):
    """Update or create session if necessary.

    If 'sid' in cookie, the 'expire_at' field is updated.
    If 'sid' not in cookie, only create when there is extra data.

    Args:
      new_saved: use saved session on creation.
      kwargs: extra data.

    Returns:
      The session document.
    """
    (sid, save), session = map(self.request.cookies.get, ['sid', 'save']), None
    if not sid:
      # No existing session: the 'save' cookie is ignored in favour of new_saved.
      save = new_saved
    if save:
      token_type = token.TYPE_SAVED_SESSION
      session_expire_seconds = options.saved_session_expire_seconds
    else:
      token_type = token.TYPE_UNSAVED_SESSION
      session_expire_seconds = options.unsaved_session_expire_seconds
    if sid:
      session = await token.update(sid, token_type, session_expire_seconds,
                                   **{**kwargs,
                                      'update_ip': self.remote_ip,
                                      'update_ua': self.request.headers.get('User-Agent')})
    # Reached both when there was no sid and when the update returned nothing.
    if kwargs and not session:
      sid, session = await token.add(token_type, session_expire_seconds,
                                     **{**kwargs,
                                        'create_ip': self.remote_ip,
                                        'create_ua': self.request.headers.get('User-Agent')})
    if session:
      cookie_kwargs = {'domain': options.cookie_domain,
                       'secure': options.cookie_secure,
                       'httponly': True}
      if save:
        # Only saved sessions get an explicit expiry on their cookies.
        timestamp = calendar.timegm(session['expire_at'].utctimetuple())
        cookie_kwargs['expires'] = utils.formatdate(timestamp, usegmt=True)
        cookie_kwargs['max_age'] = session_expire_seconds
        self.response.set_cookie('save', '1', **cookie_kwargs)
      self.response.set_cookie('sid', sid, **cookie_kwargs)
    else:
      self.clear_cookies('sid', 'save')
    return session or {}

  async def delete_session(self):
    """Delete the token named by the 'sid' cookie (if any) and clear the session cookies."""
    sid, save = map(self.request.cookies.get, ['sid', 'save'])
    if sid:
      if save:
        token_type = token.TYPE_SAVED_SESSION
      else:
        token_type = token.TYPE_UNSAVED_SESSION
      await token.delete(sid, token_type)
    self.clear_cookies('sid', 'save')

  def clear_cookies(self, *names):
    """Expire each named cookie that the request actually sent."""
    for name in names:
      if name in self.request.cookies:
        # Empty value with an expiry at the Unix epoch.
        self.response.set_cookie(name, '',
                                 expires=utils.formatdate(0, usegmt=True),
                                 domain=options.cookie_domain,
                                 secure=options.cookie_secure,
                                 httponly=True)

  @property
  def remote_ip(self):
    """Client address as resolved by tools.get_remote_ip() for this request."""
    return tools.get_remote_ip(self.request)

  @property
  def csrf_token(self):
    """CSRF token derived from the session's '_id', or '' when there is no session."""
    if self.session:
      return _get_csrf_token(self.session['_id'])
    else:
      return ''

  def build_path(self, *args):
    """Return path components: (domain name, domain main URL) followed by args."""
    return [(self.domain['name'], self.reverse_url('domain_main')), *args]

  def reverse_url(self, name, **kwargs):
    """Build the URL of the named route; domain_id defaults to the current domain."""
    kwargs.setdefault('domain_id', self.domain_id)
    return _reverse_url(name, **kwargs)

  def translate(self, text):
    """Look text up in the current locale; return it unchanged when there is no entry."""
    return self.locale.get(text, text)

  def render_html(self, template_name, **kwargs):
    """Render template_name to a string.

    'handler', 'domain_id', 'domain', 'reverse_url' and 'datetime_span' are
    always set; '_', 'page_name', 'page_title' and 'path_components' are only
    filled in when the caller did not supply them.
    """
    kwargs['handler'] = self
    if '_' not in kwargs:
      kwargs['_'] = self.translate
    kwargs['domain_id'] = self.domain_id
    kwargs['domain'] = self.domain
    if 'page_name' not in kwargs:
      kwargs['page_name'] = self.NAME
    if 'page_title' not in kwargs:
      kwargs['page_title'] = self.translate(self.TITLE)
    if 'path_components' not in kwargs:
      kwargs['path_components'] = self.build_path((self.translate(self.NAME), None))
    kwargs['reverse_url'] = self.reverse_url
    kwargs['datetime_span'] = functools.partial(_datetime_span, timezone=self.timezone)
    return template.Environment().get_template(template_name).render(kwargs)

  def render_title(self, page_title=None):
    """Return '<title>[ - <domain name>] - Vijos'; title defaults to the translated TITLE."""
    if not page_title:
      page_title = self.translate(self.TITLE)
    if self.domain_id != builtin.DOMAIN_ID_SYSTEM:
      page_title += ' - {}'.format(self.domain['name'])
    page_title += ' - Vijos'
    return page_title

  async def send_mail(self, mail, title, template_name, **kwargs):
    """Render template_name and send it to mail with the subject '<title> - Vijos'.

    The title is translated with kwargs['_'] when given, else with self.translate.
    """
    content = self.render_html(template_name, url_prefix=options.url_prefix, **kwargs)
    translate = self.translate
    if '_' in kwargs:
      translate = kwargs['_']
    await mailer.send_mail(mail, '{0} - Vijos'.format(translate(title)), content)


class Handler(web.View, HandlerBase):
  """HTTP handler: runs HandlerBase.prepare(), then the web.View dispatch.

  The response object is created up front in __await__ and filled in by the
  render/json/redirect helpers below.
  """

  def __await__(self):
    """Drive one request and return self.response.

    User-facing errors are turned into a JSON or HTML error response and
    logged as warnings; any other exception is logged and re-raised.
    """
    try:
      self.response = web.Response()
      yield from HandlerBase.prepare(self).__await__()
      yield from super(Handler, self).__await__()
    except asyncio.CancelledError:
      # Cancellation must propagate untouched; never report it as an error.
      raise
    except error.UserFacingError as e:
      self.response.set_status(e.http_status, None)
      if isinstance(e, error.PermissionError):
        e.args = (self.translate(e.args[0]), *e.args[1:])
      if self.prefer_json:
        self.response.content_type = 'application/json'
        message = self.translate(e.message).format(*e.args)
        self.response.text = json.encode({'error': {**e.to_dict(), 'message': message}})
      else:
        self.render(e.template_name, error=e,
                    page_name='error', page_title=self.translate('error'),
                    path_components=self.build_path((self.translate('error'), None)))
      # 'user' always exists as a class attribute, so test the value itself:
      # a missing or empty user document must not break error logging.
      udoc = getattr(self, 'user', None)
      uid = udoc['_id'] if udoc else None
      _logger.warning('User facing error by %s %s %s: %s', self.url, self.remote_ip, uid, repr(e))
    except Exception as e:
      udoc = getattr(self, 'user', None)
      uid = udoc['_id'] if udoc else None
      _logger.error('System error by %s %s %s: %s', self.url, self.remote_ip, uid, repr(e))
      raise
    return self.response

  def render(self, template_name, **kwargs):
    """Render template_name into the response as non-cacheable text/html."""
    self.response.content_type = 'text/html'
    self.response.headers.add('Cache-Control', 'no-store, no-cache, must-revalidate')
    self.response.headers.add('Pragma', 'no-cache')
    self.response.text = self.render_html(template_name, **kwargs)

  def json(self, obj):
    """Encode obj into the response as non-cacheable application/json."""
    self.response.content_type = 'application/json'
    self.response.headers.add('Cache-Control', 'no-store, no-cache, must-revalidate')
    self.response.headers.add('Pragma', 'no-cache')
    self.response.text = json.encode(obj)

  async def binary(self, data, content_type='application/octet-stream', file_name=None):
    """Replace the response with a stream response and write data to it.

    Args:
      data: payload; its len() is sent as the content length.
      content_type: value for the Content-Type header.
      file_name: when truthy, sent (after sanitizing) as an attachment file name.
    """
    self.response = web.StreamResponse()
    self.response.content_length = len(data)
    self.response.content_type = content_type
    if file_name:
      # Drop characters that are unsafe inside the quoted file name. CR and LF
      # are included so the value can never break out of the header line.
      file_name = file_name.translate(str.maketrans('', '', '/<>:"\'\\|?* \r\n'))
      self.response.headers.add('Content-Disposition',
                                'attachment; filename="{}"'.format(file_name))
    await self.response.prepare(self.request)
    await self.response.write(data)

  @property
  def prefer_json(self):
    """Whether the Accept header lists application/json ahead of HTML.

    Walks the parsed entries in the order accept.parse() returns them: True on
    the first 'application/json', False on the first 'text/html' or all-types
    entry, and False when the header cannot be parsed or nothing matches.
    """
    accept_header = self.request.headers.get('Accept')
    try:
      parse_result = accept.parse(accept_header)
    except (ValueError, IndexError):
      # the accept library throws bogus exceptions
      _logger.warning('Unparsable accept header: %s', accept_header)
      return False
    for d in parse_result:
      if d.media_type == 'application/json':
        return True
      elif d.media_type == 'text/html' or d.all_types:
        return False
    return False

  @property
  def url(self):
    """Path of the current request."""
    return self.request.path

  @property
  def referer_or_main(self):
    """The request's 'referer' header, or the domain main page URL when absent/empty."""
    return self.request.headers.get('referer') or self.reverse_url('domain_main')

  def redirect(self, redirect_url):
    """Turn the response into an HTTP 302 redirect to redirect_url."""
    self.response.set_status(web.HTTPFound.status_code, None)
    self.response.headers['Location'] = redirect_url

  def json_or_redirect(self, redirect_url, **kwargs):
    """Answer with kwargs as JSON if the client prefers JSON, else redirect."""
    if self.prefer_json:
      self.json(kwargs)
    else:
      self.redirect(redirect_url)

  def json_or_render(self, template_name, **kwargs):
    """Answer with kwargs as JSON if the client prefers JSON, else render the template."""
    if self.prefer_json:
      self.json(kwargs)
    else:
      self.render(template_name, **kwargs)

  @property
  def ui_context(self):
    """Dict with the CSRF token and the configured CDN / URL prefixes."""
    return {'csrf_token': self.csrf_token,
            'cdn_prefix': options.cdn_prefix,
            'url_prefix': options.url_prefix}

  @property
  def user_context(self):
    """Dict with the current user's id ('uid') and the current domain id ('domain')."""
    return {'uid': self.user['_id'],
            'domain': self.domain_id}


class OperationHandler(Handler):
  """Handler whose POST is dispatched to a ``post_<operation>`` method.

  The operation name comes from the 'operation' form field and falls back to
  DEFAULT_OPERATION; every other form field is passed on as a keyword argument.
  """
  DEFAULT_OPERATION = 'default'

  async def post(self):
    """Dispatch the submitted form; raise InvalidOperationError for unknown operations."""
    form = {**await self.request.post()}
    operation = form.pop('operation', self.DEFAULT_OPERATION)
    try:
      target = getattr(self, 'post_' + operation)
    except AttributeError:
      raise error.InvalidOperationError(operation) from None
    else:
      await target(**form)


class Connection(sockjs.Session, HandlerBase):
  """SockJS session that also carries the HandlerBase helpers.

  The on_* coroutines are no-op hooks for subclasses to override.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # HandlerBase methods write cookies to self.response, so give them a target.
    self.response = web.Response()  # dummy response

  async def on_open(self):
    """Hook; does nothing by default."""

  async def on_message(self, **kwargs):
    """Hook; does nothing by default."""

  async def on_close(self):
    """Hook; does nothing by default."""

  def send(self, **kwargs):
    """JSON-encode the keyword arguments and send them through the session."""
    payload = json.encode(kwargs)
    super().send(payload)


@functools.lru_cache()
def _get_csrf_token(session_id_binary):
  """Return the hex HMAC-SHA256 of session_id_binary under the key b'csrf_token'.

  Results are memoized by lru_cache, so the argument has to be hashable.
  """
  mac = hmac.new(b'csrf_token', msg=session_id_binary, digestmod='sha256')
  return mac.hexdigest()


@functools.lru_cache()
def _reverse_url(name, *, domain_id, **kwargs):
  """DEPRECATED: This function is deprecated. But we don't have a replacement yet.

  Builds the URL of route `name` from the stringified keyword arguments. For
  any domain other than the system one, the '<name>_with_domain_id' route is
  used instead and domain_id is passed to it as an extra parameter.
  """
  params = {key: str(value) for key, value in kwargs.items()}
  if domain_id != builtin.DOMAIN_ID_SYSTEM:
    name += '_with_domain_id'
    params['domain_id'] = domain_id
  route = app.Application().router[name]
  # Unpacking an empty dict is the same as calling url_for() with no arguments.
  return str(route.url_for(**params))


@functools.lru_cache()
def _datetime_span(dt, relative=True, format='%Y-%m-%d %H:%M:%S', timezone=pytz.utc):
  """Return markup for a <span class="time"> element describing dt.

  Args:
    dt: datetime; a naive value is treated as UTC.
    relative: when truthy, the span also gets the 'relative' CSS class.
    format: strftime pattern for the visible text.
    timezone: zone the visible text is converted to.

  Returns:
    markupsafe.Markup whose data-timestamp attribute holds the epoch seconds.
  """
  aware = dt if dt.tzinfo else dt.replace(tzinfo=pytz.utc)
  css_suffix = ' relative' if relative else ''
  epoch_seconds = calendar.timegm(aware.utctimetuple())
  label = aware.astimezone(timezone).strftime(format)
  html = '<span class="time{0}" data-timestamp="{1}">{2}</span>'.format(
      css_suffix, epoch_seconds, label)
  return markupsafe.Markup(html)


@functools.lru_cache()
def _datetime_stamp(dt):
  """Return dt as whole seconds since the Unix epoch; a naive dt is treated as UTC."""
  aware = dt if dt.tzinfo else dt.replace(tzinfo=pytz.utc)
  return calendar.timegm(aware.utctimetuple())


# Decorators
def require_perm(perm):
  """Build a decorator that calls self.check_perm(perm) before the wrapped method.

  Whatever check_perm raises propagates and the wrapped method is not called.
  """
  def decorate(func):
    @functools.wraps(func)
    def guarded(self, *args, **kwargs):
      self.check_perm(perm)
      result = func(self, *args, **kwargs)
      return result

    return guarded

  return decorate


def require_priv(priv):
  """Build a decorator that calls self.check_priv(priv) before the wrapped method.

  Whatever check_priv raises propagates and the wrapped method is not called.
  """
  def decorate(func):
    @functools.wraps(func)
    def guarded(self, *args, **kwargs):
      self.check_priv(priv)
      result = func(self, *args, **kwargs)
      return result

    return guarded

  return decorate


def require_csrf_token(func):
  """Decorator: reject the call unless the 'csrf_token' argument matches the session's.

  When self.csrf_token is empty (no session) nothing is checked and the keyword
  arguments are passed through untouched. Otherwise 'csrf_token' is removed
  from the keyword arguments and compared with the expected token.

  Raises:
    error.CsrfTokenError: the supplied token is missing, not a string, or wrong.
  """
  @functools.wraps(func)
  def wrapper(self, *args, **kwargs):
    expected = self.csrf_token
    if expected:
      supplied = kwargs.pop('csrf_token', '')
      # Compare in constant time so response timing does not reveal how much of
      # a guessed token was correct. Bytes are compared because compare_digest
      # rejects non-ASCII str; 'surrogatepass' keeps the encoding from failing.
      matches = isinstance(supplied, str) and hmac.compare_digest(
          expected.encode('utf-8', 'surrogatepass'),
          supplied.encode('utf-8', 'surrogatepass'))
      if not matches:
        raise error.CsrfTokenError()
    return func(self, *args, **kwargs)

  return wrapper


def route_argument(func):
  """Decorator: add the request's route match_info to the keyword arguments.

  On a name clash the match_info value replaces the caller-supplied one.
  """
  @functools.wraps(func)
  def wrapped(self, **kwargs):
    merged = dict(kwargs)
    merged.update(self.request.match_info)
    return func(self, **merged)

  return wrapped


def get_argument(func):
  """Decorator: add the request's query-string values to the keyword arguments.

  On a name clash the query value replaces the caller-supplied one.
  """
  @functools.wraps(func)
  def wrapped(self, **kwargs):
    merged = dict(kwargs)
    merged.update(self.request.query)
    return func(self, **merged)

  return wrapped


def post_argument(coro):
  """Decorator: add the request's POST form fields to the keyword arguments.

  On a name clash the form value replaces the caller-supplied one.
  """
  @functools.wraps(coro)
  async def wrapped(self, **kwargs):
    form = await self.request.post()
    merged = {**kwargs, **form}
    return await coro(self, **merged)

  return wrapped


def multipart_argument(coro):
  """Decorator: pass the parts of a multipart request body to coro as keyword arguments.

  Parts without a filename are read completely and passed as decoded text.
  Parts with a filename are streamed into a file obtained from fs.add() and
  passed as that file's id. If reading the body or coro itself raises, every
  file id collected so far is unlinked before the exception is re-raised.
  """
  @functools.wraps(coro)
  async def wrapped(self, **kwargs):
    # Ids of the files stored for this request; used for cleanup on failure.
    file_ids = list()
    try:
      async for part in await self.request.multipart():
        if not part.filename:
          kwargs[part.name] = (await part.read()).decode()
        else:
          # NOTE(review): get_content_type is not defined in this module;
          # presumably provided by a base class or mixin - confirm.
          grid_in = await fs.add(self.get_content_type(part.filename))
          try:
            chunk = await part.read_chunk()
            while chunk:
              # Write the current chunk while the next one is being read.
              _, chunk = await asyncio.gather(grid_in.write(chunk), part.read_chunk())
            await grid_in.close()
          except:
            # Bare except on purpose: also covers BaseException (e.g. task
            # cancellation); the partial upload is aborted and the error re-raised.
            await grid_in.abort()
            raise
          # A truthy result replaces the freshly written file, which is then
          # unlinked (presumably an existing file with the same MD5 - confirm
          # against fs.link_by_md5); otherwise the new file's own id is kept.
          file_id = await fs.link_by_md5(grid_in.md5, grid_in._id)
          if file_id:
            await fs.unlink(grid_in._id)
          else:
            file_id = grid_in._id
          file_ids.append(file_id)
          kwargs[part.name] = file_id
      return await coro(self, **kwargs)
    except:
      # Runs for any exception, including one raised by coro: release every
      # file recorded above, then propagate.
      await asyncio.gather(*[fs.unlink(file_id) for file_id in file_ids])
      # TODO(iceboy): call self.response.force_close() after aiohttp supports it.
      raise

  return wrapped

def limit_rate(op, period_secs, max_operations):
  """Build a decorator that records one `op` for the caller's IP before each call.

  opcount.inc(op, remote_ip, period_secs, max_operations) is awaited first;
  anything it raises propagates and the wrapped coroutine is not run.
  """
  def decorate(coro):
    @functools.wraps(coro)
    async def limited(self, **kwargs):
      await opcount.inc(op, self.remote_ip, period_secs, max_operations)
      result = await coro(self, **kwargs)
      return result

    return limited
  return decorate


def sanitize(func):
  """Decorator: convert keyword arguments with the wrapped function's annotations.

  Each keyword argument whose name has an annotation is replaced by
  ``annotation(value)``; arguments without an annotation are dropped.

  Raises:
    error.InvalidArgumentError: a converter raised for the named argument.
  """
  @functools.wraps(func)
  def wrapped(self, **kwargs):
    converters = func.__annotations__
    new_kwargs = {}
    for key, value in kwargs.items():
      # Look the converter up outside the try block: a KeyError raised *inside*
      # a converter means a bad value and must not be mistaken for "no annotation".
      if key not in converters:
        continue
      try:
        new_kwargs[key] = converters[key](value)
      except Exception:
        raise error.InvalidArgumentError(key)
    return func(self, **new_kwargs)

  return wrapped