cool_django_auth_ldap/config.py
# Copyright (c) 2019, Artem Vasin
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# - Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
This module contains classes that will be needed for configuration of LDAP
authentication. Unlike backend.py, this is safe to import into settings.py.
Please see the docstring on the backend module for more information, including
notes on naming conventions.
"""
import logging
import pprint
import ldap
import ldap.filter
from django.utils.tree import Node
# pylint: disable=missing-class-docstring
class ConfigurationWarning(UserWarning):
    """Warning category for problems with the LDAP authentication settings."""
class _LDAPConfig:
    """
    Private holder for process-wide objects: the python-ldap module and the
    package logger. Both are set up lazily and cached on the class.
    """

    # Cached logger instance; created on the first get_logger() call.
    logger = None
    # Flips to True once global LDAP options have been applied.
    _ldap_configured = False

    @classmethod
    def get_ldap(cls, global_options=None):
        """
        Return the ldap module.

        When global_options (a mapping of option -> value) is supplied and no
        options have been applied before, each pair is passed to
        ldap.set_option() first. Later calls never re-apply options.
        """
        if cls._ldap_configured or global_options is None:
            return ldap

        for opt, value in global_options.items():
            ldap.set_option(opt, value)
        cls._ldap_configured = True

        return ldap

    @classmethod
    def get_logger(cls):
        """
        Return the package logger, creating it on first use.

        A NullHandler is attached at creation time so the logger always has at
        least one handler.
        """
        log = cls.logger
        if log is None:
            log = logging.getLogger("cool_django_auth_ldap")
            log.addHandler(logging.NullHandler())
            cls.logger = log
        return log
# Module-wide logger, fetched once at import time from _LDAPConfig.get_logger()
# and used by the search classes below to report LDAP errors and results.
logger = _LDAPConfig.get_logger()
class LDAPSearch:
    """
    Public holder for one set of LDAP search parameters.

    Instances should be treated as immutable: the refinement methods return
    new objects instead of modifying the receiver. Only the initializer is
    documented for configuration purposes; the remaining methods are used
    internally to refine and run the search.
    """

    def __init__(self, base_dn, scope, filterstr="(objectClass=*)", attrlist=None):
        """
        base_dn, scope, filterstr and attrlist are stored as given and are
        later handed, in that order, to the connection's search_s() / search()
        call.
        """
        self.base_dn, self.scope = base_dn, scope
        self.filterstr, self.attrlist = filterstr, attrlist
        self.ldap = _LDAPConfig.get_ldap()

    def __repr__(self):
        cls_name = self.__class__.__name__
        return "<{}: {}>".format(cls_name, self.base_dn)

    def _with_filter(self, filterstr):
        """
        Builds a sibling search that differs from this one only in its filter.
        """
        return self.__class__(
            self.base_dn, self.scope, filterstr, attrlist=self.attrlist
        )

    def search_with_additional_terms(self, term_dict, escape=True):
        """
        Returns a new search whose filter is this search's filter and-ed with
        one "(name=value)" term per item of term_dict, which maps attribute
        names to assertion values. Values are escaped for use in a filter
        unless escape=False is passed.
        """

        def render(attr, assertion):
            if escape:
                assertion = self.ldap.filter.escape_filter_chars(assertion)
            return "({}={})".format(attr, assertion)

        pieces = [self.filterstr]
        pieces.extend(render(attr, value) for attr, value in term_dict.items())
        return self._with_filter("(&{})".format("".join(pieces)))

    def search_with_additional_term_string(self, filterstr):
        """
        Returns a new search whose filter is this search's filter and-ed with
        filterstr. No escaping is done here; the caller must pass a string
        that is already properly escaped.
        """
        return self._with_filter("(&{}{})".format(self.filterstr, filterstr))

    def execute(self, connection, filterargs=(), escape=True):
        """
        Runs the search synchronously on connection (an LDAPObject) and
        returns the processed results.

        filterargs is interpolated into the filter string with the %
        operator; when escape is True its values are escaped first.

        An ldap.LDAPError from search_s() is logged and treated as an empty
        result set. Result strings are decoded from utf-8 so that callers
        receive Unicode text.
        """
        args = self._escape_filterargs(filterargs) if escape else filterargs

        try:
            filterstr = self.filterstr % args
            raw = connection.search_s(
                self.base_dn, self.scope, filterstr, self.attrlist
            )
        except ldap.LDAPError as e:
            logger.error(
                "search_s('{}', {}, '{}') raised {}".format(
                    self.base_dn, self.scope, filterstr, pprint.pformat(e)
                )
            )
            raw = []

        return self._process_results(raw)

    def _begin(self, connection, filterargs=(), escape=True):
        """
        Starts an asynchronous search and returns its message id, to be
        passed to _results() later. Returns None if starting the search
        raised ldap.LDAPError (the error is logged).

        filterargs is interpolated into the filter string with the %
        operator; when escape is True its values are escaped first.
        """
        args = self._escape_filterargs(filterargs) if escape else filterargs

        try:
            filterstr = self.filterstr % args
            return connection.search(
                self.base_dn, self.scope, filterstr, self.attrlist
            )
        except ldap.LDAPError as e:
            logger.error(
                "search('{}', {}, '{}') raised {}".format(
                    self.base_dn, self.scope, filterstr, pprint.pformat(e)
                )
            )
            return None

    def _results(self, connection, msgid):
        """
        Collects and processes the results of an asynchronous search started
        by _begin(). Responses of an unexpected kind, and ldap.LDAPError
        failures (which are logged), both yield an empty result set.
        """
        try:
            kind, payload = connection.result(msgid)
        except ldap.LDAPError as e:
            logger.error("result({}) raised {}".format(msgid, pprint.pformat(e)))
            payload = []
        else:
            if kind not in (ldap.RES_SEARCH_ENTRY, ldap.RES_SEARCH_RESULT):
                payload = []

        return self._process_results(payload)

    def _escape_filterargs(self, filterargs):
        """
        Returns a copy of filterargs with every value escaped for use in a
        filter string.

        filterargs must be something the % string formatting operator
        accepts here, i.e. a tuple or a dict; the result has the same type.
        Anything else raises TypeError.
        """
        if isinstance(filterargs, tuple):
            return tuple(
                self.ldap.filter.escape_filter_chars(item) for item in filterargs
            )

        if isinstance(filterargs, dict):
            return {
                name: self.ldap.filter.escape_filter_chars(item)
                for name, item in filterargs.items()
            }

        raise TypeError("filterargs must be a tuple or dict.")

    def _process_results(self, results):
        """
        Returns a cleaned-up copy of raw LDAP results: entries whose first
        element is None are dropped (this scrubs out references), byte
        strings are decoded as utf-8, and each DN is lower-cased.
        """
        entries = [row for row in results if row[0] is not None]
        decoded = _DeepStringCoder("utf-8").decode(entries)

        # The normal form of a DN is lower case.
        normalized = [(row[0].lower(), row[1]) for row in decoded]

        dns = [dn for dn, _ in normalized]
        logger.debug(
            "search_s('{}', {}, '{}') returned {} objects: {}".format(
                self.base_dn,
                self.scope,
                self.filterstr,
                len(dns),
                "; ".join(dns),
            )
        )

        return normalized
class LDAPSearchUnion:
    """
    A compound search object that returns the union of the results. Instantiate
    it with one or more LDAPSearch objects.
    """

    def __init__(self, *args):
        # The component searches, kept in the order they were given.
        self.searches = args
        self.ldap = _LDAPConfig.get_ldap()

    def search_with_additional_terms(self, term_dict, escape=True):
        """
        Returns a new union in which every component search has had
        term_dict and-ed to its filter. escape is forwarded unchanged to each
        component's search_with_additional_terms().
        """
        searches = [
            s.search_with_additional_terms(term_dict, escape) for s in self.searches
        ]

        return self.__class__(*searches)

    def search_with_additional_term_string(self, filterstr):
        """
        Returns a new union in which every component search has had filterstr
        and-ed to its filter. The caller is responsible for escaping.
        """
        searches = [
            s.search_with_additional_term_string(filterstr) for s in self.searches
        ]

        return self.__class__(*searches)

    # pylint: disable=protected-access
    def execute(self, connection, filterargs=(), escape=True):
        """
        Runs every component search on connection and returns the merged
        results as a dict items view of (dn, attrs) pairs.

        All searches are started first and their results collected
        afterwards. Results are merged by DN; if the same DN is returned by
        several searches, the entry from the later search wins. Searches
        that failed to start (message id None) are skipped.

        filterargs and escape have the same meaning as in
        LDAPSearch.execute(): filterargs is expanded into each filter string,
        and its values are escaped first unless escape is False.
        """
        # Forward escape so a union honours it exactly like a single search.
        msgids = [
            search._begin(connection, filterargs, escape) for search in self.searches
        ]
        results = {}

        for search, msgid in zip(self.searches, msgids):
            if msgid is not None:
                results.update(dict(search._results(connection, msgid)))

        return results.items()
class _DeepStringCoder:
    """
    Decodes byte strings found anywhere inside a nested structure of lists,
    tuples, and dicts. This is helpful when interacting with the
    Unicode-unaware python-ldap.
    """

    def __init__(self, encoding):
        self.encoding = encoding
        self.ldap = _LDAPConfig.get_ldap()

    def decode(self, value):
        """
        Decode value according to its type.

        bytes are decoded with self.encoding; lists, tuples and dicts are
        decoded recursively; any other value is returned untouched. A value
        that raises UnicodeDecodeError is returned as it was passed in.
        """
        try:
            if isinstance(value, bytes):
                return value.decode(self.encoding)
            if isinstance(value, list):
                return self._decode_list(value)
            if isinstance(value, tuple):
                return tuple(self._decode_list(value))
            if isinstance(value, dict):
                return self._decode_dict(value)
        except UnicodeDecodeError:
            # Best effort: leave undecodable data as it is.
            pass

        return value

    def _decode_list(self, value):
        """Decode every element of a sequence, returning a new list."""
        return [self.decode(item) for item in value]

    def _decode_dict(self, value):
        """Decode keys and values of a dict into a case-insensitive dict."""
        # Attribute dictionaries should be case-insensitive. python-ldap
        # defines this, although for some reason, it doesn't appear to use it
        # for search results.
        result = self.ldap.cidict.cidict()

        for key, item in value.items():
            result[self.decode(key)] = self.decode(item)

        return result
class LDAPGroupType:
    """
    Abstract base for the classes that work out LDAP group membership.

    "Group" can mean many different things in LDAP, so each grouping
    mechanism gets its own concrete subclass. Clients may subclass this
    themselves when their mechanism is not covered by a built-in
    implementation.

    name_attr is the name of the LDAP attribute from which the Django group
    name is taken.

    Subclasses in this file must use self.ldap to access the python-ldap
    module. This will be a mock object during unit tests.
    """

    def __init__(self, name_attr="cn"):
        self.name_attr = name_attr
        self.ldap = _LDAPConfig.get_ldap()

    # pylint: disable=unused-argument, no-self-use
    def user_groups(self, ldap_user, group_search):
        """
        Returns a list of group_info structures, one for each group that
        ldap_user belongs to.

        group_search is an LDAPSearch object that returns all of the groups
        the user might belong to; a typical implementation narrows it with
        extra filters and returns the results of that search. ldap_user
        represents the user and has the following three properties:

        dn: the distinguished name
        attrs: a dictionary of LDAP attributes (with lists of values)
        connection: an LDAPObject that has been bound with credentials

        This is the primitive method in the API and must be implemented. The
        base implementation returns an empty list.
        """
        return []

    def is_member(self, ldap_user, group_dn):
        """
        Optional shortcut for checking membership of a single group without
        loading all of the user's groups.

        Subclasses that can answer directly return True or False. ldap_user
        is as described for user_groups(); group_dn is the distinguished name
        of the group in question.

        The base implementation returns None, meaning "not enough
        information": the caller then has to call user_groups() and look for
        group_dn in the results.
        """
        return None

    def group_name_from_info(self, group_info):
        """
        Maps the (DN, attrs) 2-tuple of an LDAP group to the name of the
        Django group, or to None when the LDAP group has no corresponding
        Django group.

        The base implementation returns the first value of the attribute
        named by name_attr (cn unless another name was given to __init__),
        or None if that attribute is missing or has no values.
        """
        try:
            return group_info[1][self.name_attr][0]
        except (KeyError, IndexError):
            return None
class PosixGroupType(LDAPGroupType):
    """
    LDAPGroupType implementation for groups of class posixGroup.
    """

    def user_groups(self, ldap_user, group_search):
        """
        Finds every group that is the user's primary group (matching
        gidNumber) or that lists the user's uid in memberUid.

        Returns an empty list when a KeyError or IndexError occurs, e.g.
        because the user has no uid value.
        """
        try:
            attrs = ldap_user.attrs
            uid = attrs["uid"][0]
            escape = self.ldap.filter.escape_filter_chars

            if "gidNumber" in attrs:
                gid = attrs["gidNumber"][0]
                term = "(|(gidNumber={})(memberUid={}))".format(
                    escape(gid), escape(uid)
                )
            else:
                term = "(memberUid={})".format(escape(uid))

            narrowed = group_search.search_with_additional_term_string(term)
            return narrowed.execute(ldap_user.connection)
        except (KeyError, IndexError):
            return []

    def is_member(self, ldap_user, group_dn):
        """
        Returns a true value if the user's uid is listed in the group's
        memberUid attribute or if the group's gidNumber equals the user's
        gidNumber (the primary group).

        Missing user attributes, and compare operations that fail because the
        attribute is undefined or absent, count as "not a member".
        """
        not_comparable = (ldap.UNDEFINED_TYPE, ldap.NO_SUCH_ATTRIBUTE)

        try:
            uid = ldap_user.attrs["uid"][0]

            try:
                listed = ldap_user.connection.compare_s(
                    group_dn, "memberUid", uid.encode()
                )
            except not_comparable:
                listed = False

            if listed:
                return listed

            # Not listed explicitly; fall back to the primary group check.
            try:
                gid = ldap_user.attrs["gidNumber"][0]
                return ldap_user.connection.compare_s(
                    group_dn, "gidNumber", gid.encode()
                )
            except not_comparable:
                return False
        except (KeyError, IndexError):
            return False
class MemberDNGroupType(LDAPGroupType):
    """
    Group type for groups whose members are stored as a list of
    distinguished names.
    """

    def __init__(self, member_attr, name_attr="cn"):
        """
        member_attr names the attribute on the group object that holds the
        member DNs; name_attr is passed on to LDAPGroupType.
        """
        self.member_attr = member_attr
        super().__init__(name_attr)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return "<{}: {}>".format(cls_name, self.member_attr)

    def user_groups(self, ldap_user, group_search):
        """
        Returns the groups from group_search whose member attribute contains
        the user's DN.
        """
        membership = {self.member_attr: ldap_user.dn}
        narrowed = group_search.search_with_additional_terms(membership)

        return narrowed.execute(ldap_user.connection)

    def is_member(self, ldap_user, group_dn):
        """
        Asks the server whether the user's DN is a value of the member
        attribute of group_dn. Returns the result of compare_s(), or 0 when
        the attribute is undefined or absent on the group.
        """
        try:
            return ldap_user.connection.compare_s(
                group_dn, self.member_attr, ldap_user.dn.encode()
            )
        except (ldap.UNDEFINED_TYPE, ldap.NO_SUCH_ATTRIBUTE):
            return 0
class NestedMemberDNGroupType(LDAPGroupType):
    """
    Group type for groups whose members are stored as a list of
    distinguished names and which may themselves be members of other groups.
    There is no shortcut for is_member in this case, so it is left
    unimplemented.
    """

    def __init__(self, member_attr, name_attr="cn"):
        """
        member_attr names the attribute on the group object that holds the
        member DNs; name_attr is passed on to LDAPGroupType.
        """
        self.member_attr = member_attr
        super().__init__(name_attr)

    def user_groups(self, ldap_user, group_search):
        """
        Collects the user's groups from the bottom up: the groups containing
        the user, then the groups containing those groups, and so on.
        Circular references are detected and pruned because no DN is ever
        searched for twice.
        """
        found = {}  # group DN -> group_info for every group seen so far
        pending = {ldap_user.dn}  # member DNs to search for in this round
        searched = set()  # member DNs already used in a search

        while pending:
            infos = self.find_groups_with_any_member(
                pending, group_search, ldap_user.connection
            )
            fresh = {info[0]: info for info in infos}
            found.update(fresh)
            searched |= pending

            # Only groups that have not been searched for yet go into the
            # next round; this is what breaks membership cycles.
            pending = set(fresh) - searched

        return found.values()

    def find_groups_with_any_member(self, member_dn_set, group_search, connection):
        """
        Returns the groups from group_search whose member attribute contains
        at least one of the DNs in member_dn_set.
        """
        escape = self.ldap.filter.escape_filter_chars
        clauses = "".join(
            "({}={})".format(self.member_attr, escape(dn)) for dn in member_dn_set
        )
        any_member = "(|{})".format(clauses)

        narrowed = group_search.search_with_additional_term_string(any_member)
        return narrowed.execute(connection)
class GroupOfNamesType(MemberDNGroupType):
    """
    Group type for groups of class groupOfNames; member DNs are read from the
    "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)
class NestedGroupOfNamesType(NestedMemberDNGroupType):
    """
    Group type for groups of class groupOfNames that may contain nested group
    references; member DNs are read from the "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)
class GroupOfUniqueNamesType(MemberDNGroupType):
    """
    Group type for groups of class groupOfUniqueNames; member DNs are read
    from the "uniqueMember" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="uniqueMember", name_attr=name_attr)
class NestedGroupOfUniqueNamesType(NestedMemberDNGroupType):
    """
    Group type for groups of class groupOfUniqueNames that may contain nested
    group references; member DNs are read from the "uniqueMember" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="uniqueMember", name_attr=name_attr)
class ActiveDirectoryGroupType(MemberDNGroupType):
    """
    Group type for Active Directory groups; member DNs are read from the
    "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)
class NestedActiveDirectoryGroupType(NestedMemberDNGroupType):
    """
    Group type for Active Directory groups that may contain nested group
    references; member DNs are read from the "member" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="member", name_attr=name_attr)
class OrganizationalRoleGroupType(MemberDNGroupType):
    """
    Group type for groups of class organizationalRole; member DNs are read
    from the "roleOccupant" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="roleOccupant", name_attr=name_attr)
class NestedOrganizationalRoleGroupType(NestedMemberDNGroupType):
    """
    Group type for groups of class organizationalRole that may contain nested
    group references; member DNs are read from the "roleOccupant" attribute.
    """

    def __init__(self, name_attr="cn"):
        super().__init__(member_attr="roleOccupant", name_attr=name_attr)
class LDAPGroupQuery(Node):
    """
    A compound query for group membership.

    Primitive queries are constructed with a group DN as the only argument.
    They can then be combined with the ``&``, ``|``, and ``~`` operators to
    build an arbitrarily complex membership test using AND, OR, and NOT.

    :param str group_dn: The DN of a group to test for membership.
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    default = AND

    # Connectors accepted by _combine().
    _CONNECTORS = [AND, OR]

    def __init__(self, *args, **kwargs):
        # Positional arguments come first, followed by the (key, value)
        # pairs of any keyword arguments.
        super().__init__(children=[*args, *kwargs.items()])

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __invert__(self):
        # Wrap this query in a fresh node and negate the wrapper, so the
        # receiver itself stays unchanged.
        inverted = type(self)()
        inverted.add(self, self.AND)
        inverted.negate()

        return inverted

    def _combine(self, other, conn):
        """
        Returns a new query joining self and other with connector conn.

        Raises TypeError if other is not an LDAPGroupQuery and ValueError if
        conn is not one of the known connectors.
        """
        if not isinstance(other, LDAPGroupQuery):
            raise TypeError(other)
        if conn not in self._CONNECTORS:
            raise ValueError(conn)

        combined = type(self)()
        combined.connector = conn
        for operand in (self, other):
            combined.add(operand, conn)

        return combined

    def resolve(self, ldap_user, groups=None):
        """
        Evaluates the query for ldap_user and returns the boolean outcome.

        groups is the object used for the membership tests; when omitted it
        is obtained from ldap_user._get_groups().
        """
        if groups is None:
            groups = ldap_user._get_groups()  # pylint: disable=protected-access

        outcome = self.aggregator(self._resolve_children(ldap_user, groups))

        return not outcome if self.negated else outcome

    @property
    def aggregator(self):
        """
        The function used to fold the children's results: all for AND, any
        for OR. Any other connector raises ValueError.
        """
        if self.connector == self.AND:
            return all
        if self.connector == self.OR:
            return any

        raise ValueError(self.connector)

    def _resolve_children(self, ldap_user, groups):
        """
        Lazily yields the result of each child: nested queries are resolved
        recursively, anything else is passed to groups.is_member_of().
        """
        for child in self.children:
            yield (
                child.resolve(ldap_user, groups)
                if isinstance(child, LDAPGroupQuery)
                else groups.is_member_of(child)
            )