NoNameItem/cool-django-auth-ldap

View on GitHub
cool_django_auth_ldap/config.py

Summary

Maintainability
A
1 hr
Test Coverage
B
86%
# Copyright (c) 2019, Artem Vasin
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# - Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
This module contains classes that will be needed for configuration of LDAP
authentication. Unlike backend.py, this is safe to import into settings.py.
Please see the docstring on the backend module for more information, including
notes on naming conventions.
"""

import logging
import pprint

import ldap
import ldap.filter
from django.utils.tree import Node


# pylint: disable=missing-class-docstring
class ConfigurationWarning(UserWarning):
    """Warning category declared by this module; adds nothing to UserWarning."""


class _LDAPConfig:
    """
    Private holder for process-wide objects: the python-ldap module and the
    package logger. Everything is stored on the class itself.
    """

    # Cached logger instance; created on the first get_logger() call.
    logger = None

    # Flips to True after global options have been pushed into python-ldap.
    _ldap_configured = False

    @classmethod
    def get_ldap(cls, global_options=None):
        """
        Returns the ldap module.

        The first call that supplies a non-None global_options mapping applies
        each (option, value) pair through ldap.set_option. Subsequent calls
        leave the options alone.
        """
        must_configure = global_options is not None and not cls._ldap_configured

        if must_configure:
            for option_key, option_value in global_options.items():
                ldap.set_option(option_key, option_value)

            cls._ldap_configured = True

        return ldap

    @classmethod
    def get_logger(cls):
        """
        Returns the "cool_django_auth_ldap" logger, creating it and attaching a
        NullHandler the first time this is called.
        """
        if cls.logger is not None:
            return cls.logger

        instance = logging.getLogger("cool_django_auth_ldap")
        instance.addHandler(logging.NullHandler())
        cls.logger = instance

        return cls.logger


# Our global logger: obtained once at import time from _LDAPConfig.get_logger()
# and used by the classes below for error and debug messages.
logger = _LDAPConfig.get_logger()


class LDAPSearch:
    """
    Public holder for one set of LDAP search parameters. Treat instances as
    immutable: the refinement methods build new objects instead of changing
    this one. Only the initializer is meant for configuration; the remaining
    methods are used internally to refine and run the search.
    """

    def __init__(self, base_dn, scope, filterstr="(objectClass=*)", attrlist=None):
        """
        base_dn, scope, filterstr and attrlist are handed, in that order, to
        the connection's search_s / search methods when the search runs.
        filterstr may contain % placeholders that are expanded at execution
        time.
        """
        self.ldap = _LDAPConfig.get_ldap()
        self.base_dn = base_dn
        self.scope = scope
        self.filterstr = filterstr
        self.attrlist = attrlist

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self.base_dn)

    def search_with_additional_terms(self, term_dict, escape=True):
        """
        Builds a new search whose filter is this filter and-ed with one
        (name=value) term per item of term_dict. Values are escaped for use in
        a filter unless escape=False is passed.
        """

        def render(name, value):
            # Escaping is skipped entirely when the caller opts out.
            if escape:
                value = self.ldap.filter.escape_filter_chars(value)
            return "({}={})".format(name, value)

        pieces = [self.filterstr]
        pieces.extend(render(name, value) for name, value in term_dict.items())
        combined = "(&{})".format("".join(pieces))

        return type(self)(self.base_dn, self.scope, combined, attrlist=self.attrlist)

    def search_with_additional_term_string(self, filterstr):
        """
        Builds a new search whose filter is this filter and-ed with filterstr.
        No escaping is done here; the caller must pass a properly escaped
        string.
        """
        combined = "(&{}{})".format(self.filterstr, filterstr)

        return type(self)(self.base_dn, self.scope, combined, attrlist=self.attrlist)

    def execute(self, connection, filterargs=(), escape=True):
        """
        Runs the search synchronously on connection (an LDAPObject).

        filterargs is used to expand the filter string with the % operator;
        its values are escaped first when escape is True. An LDAPError is
        logged and treated as an empty result set.

        The raw results are passed through _process_results, so strings come
        back decoded and DNs lower-cased.
        """
        args = self._escape_filterargs(filterargs) if escape else filterargs

        try:
            filterstr = self.filterstr % args
            raw = connection.search_s(
                self.base_dn, self.scope, filterstr, self.attrlist
            )
        except ldap.LDAPError as e:
            logger.error(
                "search_s('{}', {}, '{}') raised {}".format(
                    self.base_dn, self.scope, filterstr, pprint.pformat(e)
                )
            )
            raw = []

        return self._process_results(raw)

    def _begin(self, connection, filterargs=(), escape=True):
        """
        Starts an asynchronous search and returns its message id, or None if
        the connection raised an LDAPError (which is logged).

        filterargs and escape have the same meaning as in execute().
        """
        args = self._escape_filterargs(filterargs) if escape else filterargs

        try:
            filterstr = self.filterstr % args
            return connection.search(
                self.base_dn, self.scope, filterstr, self.attrlist
            )
        except ldap.LDAPError as e:
            logger.error(
                "search('{}', {}, '{}') raised {}".format(
                    self.base_dn, self.scope, filterstr, pprint.pformat(e)
                )
            )
            return None

    def _results(self, connection, msgid):
        """
        Collects and post-processes the result of an asynchronous search
        previously started with _begin(). Unexpected result kinds and
        LDAPErrors both yield an empty result set.
        """
        try:
            kind, raw = connection.result(msgid)
        except ldap.LDAPError as e:
            logger.error("result({}) raised {}".format(msgid, pprint.pformat(e)))
            raw = []
        else:
            # Anything other than search entries/results is discarded.
            if kind not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_RESULT):
                raw = []

        return self._process_results(raw)

    def _escape_filterargs(self, filterargs):
        """
        Returns a copy of filterargs with every value escaped for use in a
        filter string.

        filterargs is whatever will be given to the % string formatting
        operator, so it must be a tuple or a dict; any other type raises
        TypeError.
        """

        def escaped(value):
            return self.ldap.filter.escape_filter_chars(value)

        if isinstance(filterargs, tuple):
            return tuple(escaped(value) for value in filterargs)

        if isinstance(filterargs, dict):
            return {key: escaped(value) for key, value in filterargs.items()}

        raise TypeError("filterargs must be a tuple or dict.")

    def _process_results(self, results):
        """
        Produces a cleaned-up copy of raw LDAP results: entries whose DN is
        None are dropped, byte strings are decoded as utf-8 and DNs are
        lower-cased. The outcome is also written to the debug log.
        """
        with_dn = [entry for entry in results if entry[0] is not None]
        decoded = _DeepStringCoder("utf-8").decode(with_dn)

        # The normal form of a DN is lower case.
        normalized = [(entry[0].lower(), entry[1]) for entry in decoded]

        dns = [entry[0] for entry in normalized]
        logger.debug(
            "search_s('{}', {}, '{}') returned {} objects: {}".format(
                self.base_dn,
                self.scope,
                self.filterstr,
                len(dns),
                "; ".join(dns),
            )
        )

        return normalized


class LDAPSearchUnion:
    """
    A compound search object that returns the union of the results. Instantiate
    it with one or more LDAPSearch objects.

    It offers the same refinement and execution methods as LDAPSearch, so the
    two can be used interchangeably.
    """

    def __init__(self, *args):
        """
        args are the individual search objects whose results will be merged.
        """
        self.searches = args
        self.ldap = _LDAPConfig.get_ldap()

    def search_with_additional_terms(self, term_dict, escape=True):
        """
        Returns a new union in which every member search has been refined with
        search_with_additional_terms(term_dict, escape).
        """
        searches = [
            s.search_with_additional_terms(term_dict, escape) for s in self.searches
        ]

        return self.__class__(*searches)

    def search_with_additional_term_string(self, filterstr):
        """
        Returns a new union in which every member search has been refined with
        search_with_additional_term_string(filterstr).
        """
        searches = [
            s.search_with_additional_term_string(filterstr) for s in self.searches
        ]

        return self.__class__(*searches)

    # pylint: disable=protected-access
    def execute(self, connection, filterargs=(), escape=True):
        """
        Starts every member search asynchronously on connection, then gathers
        the results and merges them by DN.

        filterargs is used for expansion of each member's filter string. If
        escape is True (the default), values in filterargs are escaped first;
        this mirrors the signature of LDAPSearch.execute so callers can pass
        escape=False to either kind of search object.

        Returns the (dn, attrs) items of the merged mapping. When the same DN
        is returned by several searches, the entry from the later search wins.
        Searches that failed to start (message id None) are skipped.
        """
        msgids = [
            search._begin(connection, filterargs, escape=escape)
            for search in self.searches
        ]
        results = {}

        for search, msgid in zip(self.searches, msgids):
            if msgid is not None:
                result = search._results(connection, msgid)
                results.update(dict(result))

        return results.items()


class _DeepStringCoder:
    """
    Decodes byte strings found anywhere inside nested lists, tuples and dicts,
    such as the result structures handed back by python-ldap.
    """

    def __init__(self, encoding):
        self.encoding = encoding
        self.ldap = _LDAPConfig.get_ldap()

    def decode(self, value):
        """
        Returns value with byte strings decoded using self.encoding.

        bytes, list, tuple and dict values are handled (containers
        recursively); any other type is returned unchanged. A value that
        cannot be decoded is returned as it was given.
        """
        try:
            if isinstance(value, bytes):
                return value.decode(self.encoding)
            if isinstance(value, list):
                return self._decode_list(value)
            if isinstance(value, tuple):
                return tuple(self._decode_list(value))
            if isinstance(value, dict):
                return self._decode_dict(value)
        except UnicodeDecodeError:
            # Undecodable data falls through and is returned untouched.
            pass

        return value

    def _decode_list(self, value):
        return list(map(self.decode, value))

    def _decode_dict(self, value):
        # Attribute dictionaries should be case-insensitive, so the decoded
        # items are collected in python-ldap's cidict rather than a plain dict.
        decoded = self.ldap.cidict.cidict()

        for raw_key, raw_value in value.items():
            decoded[self.decode(raw_key)] = self.decode(raw_value)

        return decoded


class LDAPGroupType:
    """
    Abstract base for classes that decide LDAP group membership. "Group" can
    mean many different things in LDAP, so each grouping mechanism gets its own
    concrete subclass. Clients may add subclasses of their own for mechanisms
    the built-in implementations do not cover.

    name_attr names the LDAP attribute that supplies the Django group name.

    Subclasses in this file must reach the python-ldap module through
    self.ldap. This will be a mock object during unit tests.
    """

    def __init__(self, name_attr="cn"):
        self.ldap = _LDAPConfig.get_ldap()
        self.name_attr = name_attr

    # pylint: disable=unused-argument, no-self-use
    def user_groups(self, ldap_user, group_search):
        """
        Returns a list of group_info structures, one for every group that
        ldap_user belongs to. group_search is an LDAPSearch object matching all
        the groups the user might be in; implementations typically narrow it
        with extra filters and return what the search finds. ldap_user exposes
        three properties:

        dn: the distinguished name
        attrs: a dictionary of LDAP attributes (with lists of values)
        connection: an LDAPObject that has been bound with credentials

        This is the primitive method of the API and must be implemented by
        subclasses; the base version reports no groups.
        """
        return list()

    def is_member(self, ldap_user, group_dn):
        """
        Optional shortcut for checking membership of a single group without
        loading every group of the user. Subclasses that can answer directly
        return True or False. ldap_user is as described for user_groups();
        group_dn is the distinguished name of the group being tested.

        The base version returns None, meaning "unknown": the caller then has
        to fall back to user_groups() and look for group_dn in its results.
        """
        return None

    def group_name_from_info(self, group_info):
        """
        Maps the (DN, attrs) 2-tuple of an LDAP group to a Django group name.
        None may be returned to signal that the LDAP group has no Django
        counterpart.

        The base version returns the first value of the attribute chosen by
        the name_attr argument of __init__ ("cn" by default), or None when that
        attribute is missing or empty.
        """
        try:
            return group_info[1][self.name_attr][0]
        except (KeyError, IndexError):
            return None


class PosixGroupType(LDAPGroupType):
    """
    LDAPGroupType implementation for groups of object class posixGroup.
    """

    def user_groups(self, ldap_user, group_search):
        """
        Finds every group that is the user's primary group (matching
        gidNumber) or lists the user's uid in memberUid. Returns an empty list
        when the needed user attributes are missing.
        """
        try:
            uid = ldap_user.attrs["uid"][0]

            if "gidNumber" not in ldap_user.attrs:
                membership = "(memberUid={})".format(
                    self.ldap.filter.escape_filter_chars(uid)
                )
            else:
                gid = ldap_user.attrs["gidNumber"][0]
                membership = "(|(gidNumber={})(memberUid={}))".format(
                    self.ldap.filter.escape_filter_chars(gid),
                    self.ldap.filter.escape_filter_chars(uid),
                )

            refined = group_search.search_with_additional_term_string(membership)
            return refined.execute(ldap_user.connection)
        except (KeyError, IndexError):
            return []

    def is_member(self, ldap_user, group_dn):
        """
        Tells whether the user's uid appears in the group's memberUid
        attribute or, failing that, whether the group's gidNumber equals the
        user's. Missing user attributes count as "not a member".
        """

        def matches(attribute, value):
            # A group lacking the attribute simply does not match.
            try:
                return ldap_user.connection.compare_s(
                    group_dn, attribute, value.encode()
                )
            except (ldap.UNDEFINED_TYPE, ldap.NO_SUCH_ATTRIBUTE):
                return False

        try:
            uid = ldap_user.attrs["uid"][0]
            found = matches("memberUid", uid)

            if not found:
                gid = ldap_user.attrs["gidNumber"][0]
                found = matches("gidNumber", gid)

            return found
        except (KeyError, IndexError):
            return False


class MemberDNGroupType(LDAPGroupType):
    """
    Group type for groups that list their members as distinguished names.
    """

    def __init__(self, member_attr, name_attr="cn"):
        """
        member_attr names the group attribute holding the member DNs;
        name_attr is forwarded to LDAPGroupType.
        """
        super().__init__(name_attr)

        self.member_attr = member_attr

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self.member_attr)

    def user_groups(self, ldap_user, group_search):
        """
        Returns the groups from group_search whose member attribute contains
        the user's DN.
        """
        refined = group_search.search_with_additional_terms(
            {self.member_attr: ldap_user.dn}
        )
        return refined.execute(ldap_user.connection)

    def is_member(self, ldap_user, group_dn):
        """
        Asks the server whether the user's DN is a value of the member
        attribute of group_dn. Returns the comparison result, or 0 when the
        group does not have that attribute.
        """
        try:
            return ldap_user.connection.compare_s(
                group_dn, self.member_attr, ldap_user.dn.encode()
            )
        except (ldap.UNDEFINED_TYPE, ldap.NO_SUCH_ATTRIBUTE):
            return 0


class NestedMemberDNGroupType(LDAPGroupType):
    """
    Group type for groups that list their members as distinguished names and
    may contain other groups. No shortcut exists for is_member here, so the
    inherited "unknown" implementation is kept.
    """

    def __init__(self, member_attr, name_attr="cn"):
        """
        member_attr names the group attribute holding the member DNs;
        name_attr is forwarded to LDAPGroupType.
        """
        super().__init__(name_attr)

        self.member_attr = member_attr

    def user_groups(self, ldap_user, group_search):
        """
        Walks group membership from the bottom up: first the groups containing
        the user, then the groups containing those groups, and so on. A DN is
        never searched with twice, which prunes circular references.
        """
        found = {}  # group_dn -> group_info for every group discovered so far
        pending = {ldap_user.dn}  # member DNs to search with in the next round
        searched = set()  # member DNs that have already been searched with

        while pending:
            batch = {
                info[0]: info
                for info in self.find_groups_with_any_member(
                    pending, group_search, ldap_user.connection
                )
            }
            found.update(batch)
            searched |= pending

            # Next round: only the newly found groups we have not yet used as
            # a member DN.
            pending = batch.keys() - searched

        return found.values()

    def find_groups_with_any_member(self, member_dn_set, group_search, connection):
        """
        Returns the groups from group_search whose member attribute contains
        at least one of the DNs in member_dn_set.
        """
        alternatives = "".join(
            "({}={})".format(self.member_attr, self.ldap.filter.escape_filter_chars(dn))
            for dn in member_dn_set
        )
        any_member = "(|{})".format(alternatives)

        return group_search.search_with_additional_term_string(any_member).execute(
            connection
        )


class GroupOfNamesType(MemberDNGroupType):
    """
    Group type for groups of object class groupOfNames; member DNs are read
    from the "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)


class NestedGroupOfNamesType(NestedMemberDNGroupType):
    """
    Group type for groups of object class groupOfNames that may reference
    other groups; member DNs are read from the "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)


class GroupOfUniqueNamesType(MemberDNGroupType):
    """
    Group type for groups of object class groupOfUniqueNames; member DNs are
    read from the "uniqueMember" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="uniqueMember", name_attr=name_attr)


class NestedGroupOfUniqueNamesType(NestedMemberDNGroupType):
    """
    Group type for groups of object class groupOfUniqueNames that may
    reference other groups; member DNs are read from the "uniqueMember"
    attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="uniqueMember", name_attr=name_attr)


class ActiveDirectoryGroupType(MemberDNGroupType):
    """
    Group type for Active Directory groups; member DNs are read from the
    "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)


class NestedActiveDirectoryGroupType(NestedMemberDNGroupType):
    """
    Group type for Active Directory groups that may reference other groups;
    member DNs are read from the "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)


class OrganizationalRoleGroupType(MemberDNGroupType):
    """
    Group type for groups of object class organizationalRole; member DNs are
    read from the "roleOccupant" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="roleOccupant", name_attr=name_attr)


class NestedOrganizationalRoleGroupType(NestedMemberDNGroupType):
    """
    Group type for groups of object class organizationalRole that may
    reference other groups; member DNs are read from the "roleOccupant"
    attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="roleOccupant", name_attr=name_attr)


class LDAPGroupQuery(Node):
    """
    A compound group-membership query.

    Primitive queries are built from a single group DN. They can then be
    combined into arbitrarily complex expressions with the ``&`` (AND), ``|``
    (OR) and ``~`` (NOT) operators.

    :param str group_dn: The DN of a group to test for membership.

    """

    # Connection types
    AND = "AND"
    OR = "OR"
    default = AND

    _CONNECTORS = [AND, OR]

    def __init__(self, *args, **kwargs):
        children = list(args)
        children += list(kwargs.items())
        super().__init__(children=children)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __invert__(self):
        # Wrap this query in a fresh node and negate the wrapper.
        wrapper = type(self)()
        wrapper.add(self, self.AND)
        wrapper.negate()

        return wrapper

    def _combine(self, other, conn):
        """
        Returns a new query joining self and other with connector conn.
        Raises TypeError if other is not an LDAPGroupQuery and ValueError if
        conn is not one of the known connectors.
        """
        if not isinstance(other, LDAPGroupQuery):
            raise TypeError(other)
        if conn not in self._CONNECTORS:
            raise ValueError(conn)

        joined = type(self)()
        joined.connector = conn
        for operand in (self, other):
            joined.add(operand, conn)

        return joined

    def resolve(self, ldap_user, groups=None):
        """
        Evaluates the query for ldap_user and returns the boolean outcome.
        groups is the object answering is_member_of(); when omitted it is
        fetched from ldap_user.
        """
        if groups is None:
            groups = ldap_user._get_groups()  # pylint: disable=protected-access

        outcome = self.aggregator(self._resolve_children(ldap_user, groups))

        return not outcome if self.negated else outcome

    @property
    def aggregator(self):
        """
        The built-in used to fold child results: all for AND, any for OR.
        Any other connector raises ValueError.
        """
        if self.connector == self.AND:
            return all
        if self.connector == self.OR:
            return any

        raise ValueError(self.connector)

    def _resolve_children(self, ldap_user, groups):
        """
        Lazily yields the result of each child: nested queries are resolved
        recursively, anything else is looked up with groups.is_member_of().
        """
        for child in self.children:
            if not isinstance(child, LDAPGroupQuery):
                yield groups.is_member_of(child)
            else:
                yield child.resolve(ldap_user, groups)