timdaman/check_docker

View on GitHub
check_docker/check_swarm.py

Summary

Maintainability
B
4 hrs
Test Coverage
A
94%
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import re
import socket
import stat
import traceback
from functools import lru_cache
from http.client import HTTPConnection
from sys import argv
from urllib.request import AbstractHTTPHandler, HTTPHandler, HTTPSHandler, OpenerDirector

logger = logging.getLogger()
__author__ = 'Tim Laurence'
__copyright__ = "Copyright 2019"
__credits__ = ['Tim Laurence']
__license__ = "GPL"
__version__ = "2.2.2"

'''
nrpe compatible check for docker swarm

Requires Python 3

Note: I really would have preferred to have used requests for all the network connections but that would have added a
dependency.
'''

DEFAULT_SOCKET = '/var/run/docker.sock'
DEFAULT_TIMEOUT = 10.0
DEFAULT_PORT = 2375
DEFAULT_HEADERS = [('Accept', 'application/vnd.docker.distribution.manifest.v2+json')]

# Nagios/NRPE return codes
OK_RC = 0
WARNING_RC = 1
CRITICAL_RC = 2
UNKNOWN_RC = 3

# BUG FIX: range() excludes its upper bound, so range(200, 299) omitted 299.
# Every 2xx code (200-299 inclusive) is a success response.
HTTP_GOOD_CODES = range(200, 300)

# These hold the final results
rc = -1
messages = []


# Hacked up urllib to handle sockets
#############################################################################################
# Docker runs an HTTP connection over a socket. http.client knows how to deal with these
# but lacks some niceties. Urllib wraps that and makes up for some of the deficiencies, but
# cannot fix the fact that http.client can't read from socket files. In order to take
# advantage of urllib's and http.client's capabilities, the class below tweaks HTTPConnection
# and passes it to urllib, registering it for socket:// connections.

# This is all side effect so excluding coverage
class SocketFileHandler(AbstractHTTPHandler):  # pragma: no cover
    """urllib handler that services socket:// URLs over a Unix domain socket.

    OpenerDirector dispatches the "socket" scheme to socket_open() purely by
    method name, so that name must not change.
    """

    class SocketFileToHttpConnectionAdaptor(HTTPConnection):
        """HTTPConnection that connects to a socket file instead of a TCP host."""

        def __init__(self, socket_file, timeout=DEFAULT_TIMEOUT):
            super().__init__(host='', port=0, timeout=timeout)
            self.socket_file = socket_file

        def connect(self):
            # Open an AF_UNIX stream socket pointed at the daemon's socket file.
            unix_sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, fileno=None)
            unix_sock.settimeout(self.timeout)
            unix_sock.connect(self.socket_file)
            self.sock = unix_sock

    def socket_open(self, req):
        # The selector looks like "<socket path>:<request path>"; split them apart
        # so the adaptor gets the file path and HTTP gets the request path.
        socket_file, path = req.selector.split(':', 1)
        req.host = socket_file
        req.selector = path
        return self.do_open(self.SocketFileToHttpConnectionAdaptor, req)


# A single opener wired to handle http://, https:// and socket:// URLs.
better_urllib_get = OpenerDirector()
# Copy so later mutations of the opener's header list never touch DEFAULT_HEADERS.
better_urllib_get.addheaders = DEFAULT_HEADERS.copy()
better_urllib_get.add_handler(HTTPHandler())
better_urllib_get.add_handler(HTTPSHandler())
better_urllib_get.add_handler(SocketFileHandler())


# Util functions
#############################################################################################


@lru_cache()
def get_url(url):
    """GET `url` from the docker daemon and return (parsed JSON body, HTTP status).

    Responses are memoized per URL for the life of the process so repeated
    lookups (e.g. nodes/tasks across several services) only hit the daemon once.
    The connection timeout comes from the module-level `timeout` set by
    process_args().
    """
    response = better_urllib_get.open(url, timeout=timeout)
    return process_urllib_response(response), response.status


def process_urllib_response(response):
    """Read an HTTP response, log its body, and return the body parsed as JSON."""
    body = response.read().decode('utf-8')
    logger.debug(body)
    return json.loads(body)


def get_swarm_status():
    """Return just the HTTP status code from the daemon's swarm-info endpoint."""
    _, status = get_url(daemon + '/swarm')
    return status


def get_service_info(name):
    """Fetch one service's description; returns (parsed body, HTTP status)."""
    return get_url(daemon + '/services/' + name)


def get_service_tasks(name):
    """Return the list of tasks for the named service (HTTP status is discarded)."""
    task_filter = '{"name":{"%s":true}}' % name
    tasks, _ = get_url(daemon + '/tasks?filters=' + task_filter)
    return tasks


def get_nodes():
    """Return (node list, HTTP status) for every node in the swarm."""
    nodes_url = daemon + '/nodes'
    return get_url(nodes_url)


def get_services(names):
    """Resolve the requested name patterns against the services in the swarm.

    Each entry in `names` is treated as an anchored regex; the literal name
    'all' selects every service. Patterns that match nothing trigger a
    CRITICAL. Returns the set of matched service names ([] on API errors).
    """
    services_list, status = get_url(daemon + '/services')
    if status == 406:
        critical("Error checking service status, node is not in swarm mode")
        return []
    if status not in HTTP_GOOD_CODES:
        unknown("Could not retrieve service info")
        return []

    all_services_names = {service['Spec']['Name'] for service in services_list}
    if 'all' in names:
        return all_services_names

    filtered = set()
    not_found = []
    for matcher in names:
        hits = {candidate for candidate in all_services_names
                if re.match("^{}$".format(matcher), candidate)}
        if hits:
            filtered |= hits
        else:
            # Nothing in the swarm matched this regex.
            not_found.append(matcher)
    if not_found:
        critical("No services match {}".format(','.join(not_found)))
    return filtered


def set_rc(new_rc):
    """Escalate the global return code to new_rc; it is never lowered."""
    global rc
    rc = max(rc, new_rc)


def ok(message):
    """Record an OK-level result message."""
    set_rc(OK_RC)
    messages.append('OK: {}'.format(message))


def warning(message):
    """Record a WARNING-level result message."""
    set_rc(WARNING_RC)
    messages.append('WARNING: {}'.format(message))


def critical(message):
    """Record a CRITICAL-level result message."""
    set_rc(CRITICAL_RC)
    messages.append('CRITICAL: {}'.format(message))


def unknown(message):
    """Record an UNKNOWN-level result message."""
    set_rc(UNKNOWN_RC)
    messages.append('UNKNOWN: {}'.format(message))


# Checks
#############################################################################################
def check_swarm():
    """Report whether this node is participating in a swarm."""
    process_url_status(get_swarm_status(),
                       ok_msg='Node is in a swarm',
                       critical_msg='Node is not in a swarm',
                       unknown_msg='Error accessing swarm info')


def process_global_service(name, ignore_paused=False):
    """Check that a global service has a running task on every eligible node.

    A node is eligible unless its availability is 'drain' (or 'paused' when
    ignore_paused is set). Tasks in any state other than 'running' are
    treated as failures so services stuck in crash loops are caught.
    """
    bad_node_states = {'drain'}
    if ignore_paused:
        bad_node_states.add('paused')

    # Build the set of node IDs that should be running a task for this service.
    node_list, status = get_nodes()
    node_index = set()
    for node in node_list:
        if node['Spec']['Availability'] in bad_node_states:
            continue
        node_index.add(node['ID'])

    # Any non-running task is an immediate failure.
    # Note: this ignores tasks claiming to run on nodes not in the index.
    service_tasks = get_service_tasks(name)
    for task in service_tasks:
        if task['Status']['State'] != 'running':
            critical('Global service {service} has one or more tasks not running'.format(service=name))
            return
        node_index.discard(task['NodeID'])

    # Whatever remains in the index is an eligible node with no task.
    if len(node_index) > 0:
        # BUG FIX: report the number of uncovered nodes (was len(node_list)),
        # and return so an OK message is not also appended for a failing service.
        critical('Global service {service} has {count} tasks not running'.format(service=name,
                                                                                 count=len(node_index)))
        return

    ok('Global service {service} OK'.format(service=name))


def process_replicated_service(name, replicas_desired):
    """Check that a replicated service has exactly the desired number of running tasks."""
    # Only 'running' tasks count, so services in crash loops are not ignored.
    running_count = sum(1 for task in get_service_tasks(name)
                        if task['Status']['State'] == 'running')
    if running_count != replicas_desired:
        critical('Replicated service {service} has {num_tasks} tasks, {replicas_desired} desired'.
                 format(service=name, num_tasks=running_count, replicas_desired=replicas_desired))
    else:
        ok('Replicated service {service} OK'.format(service=name))


def check_service(name, ignore_paused=False):
    """Dispatch a service check based on the service's scheduling mode."""
    service_info, status = get_service_info(name)
    mode_info = service_info['Spec']['Mode']

    if 'Global' in mode_info:
        # Global: expect one running task per eligible node.
        process_global_service(name=name, ignore_paused=ignore_paused)
    elif 'Replicated' in mode_info:
        # Replicated: expect the configured replica count.
        process_replicated_service(name=name, replicas_desired=mode_info['Replicated']['Replicas'])
    # NOTE(review): any other mode falls through without a message — confirm intended.


def process_url_status(status, ok_msg=None, critical_msg=None, unknown_msg=None):
    """Translate an HTTP status code into an OK / CRITICAL / UNKNOWN result."""
    if status in HTTP_GOOD_CODES:
        return ok(ok_msg)
    if status in (503, 404, 406):
        return critical(critical_msg)
    return unknown(unknown_msg)


def process_args(args):
    """Parse command-line arguments.

    Side effects: sets the module globals `timeout`, `daemon` and
    `connection_type` that the check functions rely on.

    Returns the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description='Check docker swarm.')

    # Connect to local socket or ip address
    connection_group = parser.add_mutually_exclusive_group()
    connection_group.add_argument('--connection',
                                  dest='connection',
                                  action='store',
                                  default=DEFAULT_SOCKET,
                                  type=str,
                                  metavar='[/<path to>/docker.socket|<ip/host address>:<port>]',
                                  help='Where to find docker daemon socket. (default: %(default)s)')

    connection_group.add_argument('--secure-connection',
                                  dest='secure_connection',
                                  action='store',
                                  type=str,
                                  metavar='[<ip/host address>:<port>]',
                                  help='Where to find TLS protected docker daemon socket.')

    # Connection timeout
    parser.add_argument('--timeout',
                        dest='timeout',
                        action='store',
                        type=float,
                        default=DEFAULT_TIMEOUT,
                        help='Connection timeout in seconds. (default: %(default)s)')

    # Exactly one member of this group must be supplied.
    swarm_group = parser.add_mutually_exclusive_group(required=True)

    # Swarm
    swarm_group.add_argument('--swarm',
                             dest='swarm',
                             default=None,
                             action='store_true',
                             help='Check swarm status')

    # Service
    swarm_group.add_argument('--service',
                             dest='service',
                             action='store',
                             type=str,
                             nargs='+',
                             default=[],
                             help='One or more RegEx that match the names of the services(s) to check.')

    # NOTE(review): being inside the required mutually-exclusive group means
    # --ignore_paused cannot be combined with --service, yet it only affects
    # service checks — looks unintended; confirm against upstream.
    swarm_group.add_argument('--ignore_paused',
                             dest='ignore_paused',
                             action='store_true',
                             help="Don't require global services to be running on paused nodes")

    parser.add_argument('-V', action='version', version='%(prog)s {}'.format(__version__))

    # With no args, show usage; parse_args below will still exit with an
    # error because swarm_group is required.
    if len(args) == 0:
        parser.print_help()

    parsed_args = parser.parse_args(args=args)

    global timeout
    timeout = parsed_args.timeout

    # Derive the daemon base URL and remember which transport is in use.
    global daemon
    global connection_type
    if parsed_args.secure_connection:
        daemon = 'https://' + parsed_args.secure_connection
        connection_type = 'https'
    elif parsed_args.connection:
        if parsed_args.connection[0] == '/':
            # A leading slash means a socket file; the trailing ':' separates
            # the socket path from the request path (see SocketFileHandler).
            daemon = 'socket://' + parsed_args.connection + ':'
            connection_type = 'socket'
        else:
            daemon = 'http://' + parsed_args.connection
            connection_type = 'http'

    return parsed_args


def socketfile_permissions_failure(parsed_args):
    """Return True when a socket connection is configured but the socket file
    is missing, not actually a socket, or not readable+writable."""
    if connection_type != 'socket':
        return False
    path = parsed_args.connection
    if not os.path.exists(path):
        return True
    usable = (stat.S_ISSOCK(os.stat(path).st_mode)
              and os.access(path, os.R_OK)
              and os.access(path, os.W_OK))
    return not usable


def print_results():
    """Emit every accumulated message on a single line, Nagios style."""
    summary = '; '.join(messages)
    print(summary)


def perform_checks(raw_args):
    """Parse arguments, run the selected check, and print the combined results."""
    args = process_args(raw_args)
    if socketfile_permissions_failure(args):
        unknown("Cannot access docker socket file. User ID={}, socket file={}".format(os.getuid(), args.connection))
    else:
        # Here is where all the work happens
        #############################################################################################
        try:
            if args.swarm:
                check_swarm()
            elif args.service:
                # get_services() itself flags a CRITICAL for any name that
                # matched no service.
                for service in get_services(args.service):
                    check_service(name=service, ignore_paused=args.ignore_paused)
        except Exception as e:
            # Any unexpected failure becomes an UNKNOWN with a traceback on stderr.
            traceback.print_exc()
            unknown("Exception raised during check: {}".format(repr(e)))

    print_results()


def main():
    """Script entry point: run all checks and exit with the Nagios return code."""
    perform_checks(argv[1:])
    # Raise SystemExit directly instead of calling the site-provided exit()
    # builtin, which is not guaranteed to exist in all runtime contexts
    # (behavior is otherwise identical: exit(rc) just raises SystemExit(rc)).
    raise SystemExit(rc)


# Standard guard so the module can be imported (e.g. by the test suite)
# without running the check.
if __name__ == '__main__':
    main()