wglass/zoonado

View on GitHub
zoonado/session.py

Summary

Maintainability
B
4 hrs
Test Coverage
from __future__ import unicode_literals

import collections
import logging
import random

from tornado import gen, ioloop, iostream

from zoonado import protocol, exc
from .connection import Connection
from .states import States, SessionStateMachine
from .retry import RetryPolicy


log = logging.getLogger(__name__)

# Port assumed for any server entry that does not carry an explicit ":port".
DEFAULT_ZOOKEEPER_PORT = 2181

# Passed as ``maximum`` to the exponential backoff used while searching for a
# reachable server; expressed in seconds.
MAX_FIND_WAIT = 60

# How many heartbeats are sent per session-timeout interval; the delay between
# pings is the session timeout divided by this value.
HEARTBEAT_FREQUENCY = 3


class Session(object):
    """A ZooKeeper session layered on top of a single ``Connection``.

    Tracks the server list, the negotiated session id/password/timeout, the
    last zxid seen and the registered watch callbacks.  ``start`` launches
    two background tasks on the current IOLoop: a periodic heartbeat and a
    repair loop that reconnects whenever the state machine reports the
    session as SUSPENDED or LOST.
    """

    def __init__(self, servers, timeout, retry_policy, allow_read_only):
        """Parse the server list and initialise session bookkeeping.

        :param servers: comma-separated ``host[:port]`` entries; entries
            without a port use ``DEFAULT_ZOOKEEPER_PORT``.
        :param timeout: requested session timeout in seconds; replaced by
            the server-negotiated value in ``establish_session``.
        :param retry_policy: policy used by ``send``; ``RetryPolicy.forever()``
            is used when this is falsy.
        :param allow_read_only: whether read-only servers are acceptable.
        :raises ValueError: if ``servers`` contains no host entries, or a
            port is not an integer.
        """
        self.hosts = []
        for server in servers.split(","):
            server = server.strip()
            if not server:
                # Tolerate stray/trailing commas such as "a:2181,,b:2181,".
                continue
            if ":" in server:
                host, port = server.split(":")
                # Keep ports as ints, matching DEFAULT_ZOOKEEPER_PORT.
                host, port = host.strip(), int(port)
            else:
                host = server
                port = DEFAULT_ZOOKEEPER_PORT

            self.hosts.append((host, port))

        if not self.hosts:
            # Otherwise find_server() would retry forever over an empty list.
            raise ValueError("No ZooKeeper servers given: %r" % (servers,))

        self.conn = None
        self.state = SessionStateMachine()

        self.retry_policy = retry_policy or RetryPolicy.forever()
        self.allow_read_only = allow_read_only

        self.xid = 0
        self.last_zxid = None

        self.session_id = None
        self.timeout = timeout
        self.password = b'\x00'

        self.heartbeat_handle = None

        # (event_type, path) -> set of callbacks to invoke with the path.
        self.watch_callbacks = collections.defaultdict(set)

        self.closing = False

    @gen.coroutine
    def ensure_safe_state(self, writing=False):
        """Wait until the session is usable for the requested operation.

        CONNECTED is always acceptable; READ_ONLY is also acceptable when
        read-only mode is allowed and ``writing`` is false.
        """
        safe_states = [States.CONNECTED]
        if self.allow_read_only and not writing:
            safe_states.append(States.READ_ONLY)

        if self.state in safe_states:
            return

        yield self.state.wait_for(*safe_states)

    @gen.coroutine
    def start(self):
        """Schedule the heartbeat and repair loop, then wait for a usable state."""
        io_loop = ioloop.IOLoop.current()
        io_loop.add_callback(self.set_heartbeat)
        io_loop.add_callback(self.repair_loop)

        yield self.ensure_safe_state()

    @gen.coroutine
    def find_server(self, allow_read_only):
        """Connect to an acceptable server and install it as ``self.conn``.

        Servers are tried in random order.  Rounds are separated by an
        exponential backoff capped at ``MAX_FIND_WAIT``.  A connection to a
        read-only server is closed and skipped unless ``allow_read_only`` is
        true.  The previous connection, if any, is closed in the background.
        If the chosen server is read-only, a search for a read/write server
        is scheduled.
        """
        conn = None

        retry_policy = RetryPolicy.exponential_backoff(maximum=MAX_FIND_WAIT)
        io_loop = ioloop.IOLoop.current()

        while not conn:
            yield retry_policy.enforce()

            servers = random.sample(self.hosts, len(self.hosts))
            for host, port in servers:
                log.info("Connecting to %s:%s", host, port)
                conn = yield self.make_connection(host, port)
                if not conn:
                    continue
                if conn.start_read_only and not allow_read_only:
                    log.info("Server %s:%s is read-only, skipping.", host, port)
                    # Don't leak the rejected connection.
                    io_loop.add_callback(conn.close, self.timeout)
                    conn = None
                    continue
                # Stop at the first acceptable server.  Previously every
                # remaining server was connected to and only the last kept.
                break

            if not conn:
                log.warning("No servers available, will keep trying.")

        old_conn = self.conn
        self.conn = conn

        if old_conn:
            io_loop.add_callback(old_conn.close, self.timeout)

        if conn.start_read_only:
            # NOTE(review): the connection found by this background search is
            # swapped in without establish_session()/start_read_loop();
            # confirm that this is the intended hand-off.
            io_loop.add_callback(self.find_server, allow_read_only=False)

    @gen.coroutine
    def make_connection(self, host, port):
        """Return a connected ``Connection`` to host:port, or None on failure."""
        conn = Connection(host, port, watch_handler=self.event_dispatch)
        try:
            yield conn.connect()
        except Exception:
            log.exception("Couldn't connect to %s:%s", host, port)
            return

        raise gen.Return(conn)

    @gen.coroutine
    def establish_session(self):
        """Send a ConnectRequest on ``self.conn`` and record the session info.

        :raises exc.SessionLost: when the server answers with session id 0.
            The state is moved to LOST first.
        """
        log.info("Establishing session.")
        zxid, response = yield self.conn.send_connect(
            protocol.ConnectRequest(
                protocol_version=0,
                last_seen_zxid=self.last_zxid or 0,
                timeout=int((self.timeout or 0) * 1000),
                session_id=self.session_id or 0,
                password=self.password,
                read_only=self.allow_read_only,
            )
        )
        self.last_zxid = zxid

        if response.session_id == 0:  # invalid session, probably expired
            self.state.transition_to(States.LOST)
            raise exc.SessionLost()

        # The server reports milliseconds.  Use float division so Python 2
        # doesn't truncate (e.g. 2500 ms -> 2 s, anything < 1000 ms -> 0).
        negotiated_timeout = response.timeout / 1000.0

        log.info("Got session id %s", hex(response.session_id))
        log.info("Negotiated timeout: %s seconds", negotiated_timeout)

        self.session_id = response.session_id
        self.password = response.password
        self.timeout = negotiated_timeout

    @gen.coroutine
    def repair_loop(self):
        """Re-establish the session whenever it becomes SUSPENDED or LOST.

        Runs until ``self.closing`` is set.
        """
        while not self.closing:
            yield self.state.wait_for(States.SUSPENDED, States.LOST)
            if self.closing:
                break

            yield self.find_server(allow_read_only=self.allow_read_only)

            session_was_lost = self.state == States.LOST

            try:
                yield self.establish_session()
            except exc.SessionLost:
                self.conn.abort(exc.SessionLost)
                yield self.conn.close(self.timeout)
                self.session_id = None
                self.password = b'\x00'
                continue
            except (exc.ConnectError, iostream.StreamClosedError):
                # Without this, the error would end the repair loop for good.
                # The state is still SUSPENDED/LOST, so the next iteration
                # retries.  find_server() closes this connection when it
                # replaces it.
                log.exception("Error establishing session, will retry.")
                continue

            if self.conn.start_read_only:
                self.state.transition_to(States.READ_ONLY)
            else:
                self.state.transition_to(States.CONNECTED)

            self.conn.start_read_loop()

            if session_was_lost:
                yield self.set_existing_watches()

    @gen.coroutine
    def send(self, request):
        """Send ``request`` and return its response, retrying on disconnects.

        Each attempt first goes through ``self.retry_policy.enforce`` and
        waits for a safe state (writable if ``request.writes_data``).  On a
        connection error, the state moves to SUSPENDED and the request is
        retried.
        """
        while True:
            yield self.retry_policy.enforce(request)
            yield self.ensure_safe_state(writing=request.writes_data)

            try:
                self.xid += 1
                zxid, response = yield self.conn.send(request, xid=self.xid)
            except (exc.ConnectError, iostream.StreamClosedError):
                self.state.transition_to(States.SUSPENDED)
                continue

            self.last_zxid = zxid
            # Any successful round-trip counts as activity; push the ping out.
            self.set_heartbeat()
            self.retry_policy.clear(request)
            # Return on success even if the response object is falsy; the
            # old "while not response" loop would have re-sent it.
            raise gen.Return(response)

    def set_heartbeat(self):
        """(Re)schedule ``heartbeat`` HEARTBEAT_FREQUENCY times per timeout."""
        # float() avoids Python 2 integer division (10 / 3 == 3).
        timeout = self.timeout / float(HEARTBEAT_FREQUENCY)

        io_loop = ioloop.IOLoop.current()

        if self.heartbeat_handle:
            io_loop.remove_timeout(self.heartbeat_handle)

        self.heartbeat_handle = io_loop.call_later(timeout, self.heartbeat)

    @gen.coroutine
    def heartbeat(self):
        """Ping the server once and always reschedule the next heartbeat.

        Connection errors move the state to SUSPENDED.
        """
        if self.closing:
            return
        yield self.ensure_safe_state()

        try:
            zxid, _ = yield self.conn.send(protocol.PingRequest())
            self.last_zxid = zxid
        except (exc.ConnectError, iostream.StreamClosedError):
            self.state.transition_to(States.SUSPENDED)
        finally:
            self.set_heartbeat()

    def add_watch_callback(self, event_type, path, callback):
        """Register ``callback`` for ``event_type`` events on ``path``."""
        self.watch_callbacks[(event_type, path)].add(callback)

    def remove_watch_callback(self, event_type, path, callback):
        """Unregister ``callback``; drop the key once no callbacks remain."""
        key = (event_type, path)
        callbacks = self.watch_callbacks.get(key)
        if callbacks is None:
            return
        callbacks.discard(callback)
        if not callbacks:
            # Empty entries would otherwise be re-sent by set_existing_watches.
            del self.watch_callbacks[key]

    def event_dispatch(self, event):
        """Handle a watch event delivered by the connection.

        Node events (``event.type`` set) are fanned out to the registered
        callbacks on the IOLoop.  Other events are treated as session-state
        notifications and drive the state machine.
        """
        log.debug("Got watch event: %s", event)

        if event.type:
            key = (event.type, event.path)
            io_loop = ioloop.IOLoop.current()
            # .get() avoids inserting empty sets into the defaultdict for
            # keys nobody watches.  list() snapshots the callbacks.
            for callback in list(self.watch_callbacks.get(key, ())):
                io_loop.add_callback(callback, event.path)
            return

        if event.state == protocol.WatchEvent.DISCONNECTED:
            log.error("Got 'disconnected' watch event.")
            self.state.transition_to(States.LOST)
        elif event.state == protocol.WatchEvent.SESSION_EXPIRED:
            log.error("Got 'session expired' watch event.")
            self.state.transition_to(States.LOST)
        elif event.state == protocol.WatchEvent.AUTH_FAILED:
            log.error("Got 'auth failed' watch event.")
            self.state.transition_to(States.LOST)
        elif event.state == protocol.WatchEvent.CONNECTED_READ_ONLY:
            log.warning("Got 'connected read only' watch event.")
            self.state.transition_to(States.READ_ONLY)
        elif event.state == protocol.WatchEvent.SASL_AUTHENTICATED:
            log.info("Authentication successful.")
        elif event.state == protocol.WatchEvent.CONNECTED:
            log.info("Got 'connected' watch event.")
            self.state.transition_to(States.CONNECTED)

    @gen.coroutine
    def set_existing_watches(self):
        """Re-register every watched path with the server via SetWatches."""
        watched = [
            key for key, callbacks in self.watch_callbacks.items() if callbacks
        ]
        if not watched:
            return

        request = protocol.SetWatchesRequest(
            relative_zxid=self.last_zxid or 0,
            data_watches=[],
            exist_watches=[],
            child_watches=[],
        )

        for event_type, path in watched:
            if event_type == protocol.WatchEvent.CREATED:
                request.exist_watches.append(path)
            elif event_type == protocol.WatchEvent.DATA_CHANGED:
                request.data_watches.append(path)
            elif event_type == protocol.WatchEvent.CHILDREN_CHANGED:
                request.child_watches.append(path)

        yield self.send(request)

    @gen.coroutine
    def close(self):
        """Send CloseRequest, mark the session LOST and close the connection."""
        self.closing = True

        yield self.send(protocol.CloseRequest())
        self.state.transition_to(States.LOST)

        # send() rescheduled the heartbeat; cancel it since we're done.
        if self.heartbeat_handle:
            ioloop.IOLoop.current().remove_timeout(self.heartbeat_handle)
            self.heartbeat_handle = None

        yield self.conn.close(self.timeout)