ionelmc/pytest-benchmark

src/pytest_benchmark/utils.py

import argparse
import json
import netrc
import os
import platform
import re
import subprocess
import sys
import types
from datetime import datetime
from datetime import timezone
from decimal import Decimal
from functools import partial
from os.path import basename
from os.path import dirname
from os.path import exists
from os.path import join
from os.path import split
from subprocess import CalledProcessError
from subprocess import check_output
from urllib.parse import parse_qs
from urllib.parse import urlparse

from .compat import PY38
from .compat import PY311

# This is here (in the utils module) because it might be used by
# various other modules.
try:
    from pathlib2 import Path
except ImportError:
    from pathlib import Path  # noqa: F401


TIME_UNITS = {'': 'Seconds', 'm': 'Milliseconds (ms)', 'u': 'Microseconds (us)', 'n': 'Nanoseconds (ns)'}
ALLOWED_COLUMNS = ['min', 'max', 'mean', 'stddev', 'median', 'iqr', 'ops', 'outliers', 'rounds', 'iterations']


class SecondsDecimal(Decimal):
    def __float__(self):
        return float(super().__str__())

    def __str__(self):
        return f'{format_time(float(super().__str__()))}s'

    @property
    def as_string(self):
        return super().__str__()


class NameWrapper:
    def __init__(self, target):
        self.target = target

    def __str__(self):
        name = self.target.__module__ + '.' if hasattr(self.target, '__module__') else ''
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return 'NameWrapper(%s)' % repr(self.target)


def get_tag(project_name=None):
    info = get_commit_info(project_name)
    parts = [info['id'], get_current_time()]
    if info['dirty']:
        parts.append('uncommitted-changes')
    return '_'.join(parts)
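# Illustrative note (not in the original source): get_tag() yields a string of
# the form '<commit-id>_<YYYYmmdd_HHMMSS>', with an extra suffix appended when
# the working tree has uncommitted changes.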


def get_machine_id():
    return '{}-{}-{}-{}'.format(
        platform.system(),
        platform.python_implementation(),
        '.'.join(platform.python_version_tuple()[:2]),
        platform.architecture()[0],
    )
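# Example output (values depend on the host): 'Linux-CPython-3.12-64bit', i.e.
# system, Python implementation, major.minor version and architecture joined
# with dashes.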


class Fallback:
    def __init__(self, fallback, exceptions):
        self.fallback = fallback
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                value = func(*args, **kwargs)
            except self.exceptions:
                continue
            else:
                if value:
                    return value
        return self.fallback(*args, **kwargs)

    def register(self, other):
        self.functions.append(other)
        return self


@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    return basename(os.getcwd())


@get_project_name.register
def get_project_name_git():
    is_git = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if is_git:
        project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
        if isinstance(project_address, bytes):
            project_address = project_address.decode()
        project_name = [i for i in re.split(r'[/:\s\\]|\.git', project_address) if i][-1]
        return project_name.strip()


@get_project_name.register
def get_project_name_hg():
    with open(os.devnull, 'w') as devnull:
        project_address = check_output(['hg', 'path', 'default'], stderr=devnull)
    project_address = project_address.decode()
    project_name = project_address.split('/')[-1]
    return project_name.strip()


def in_any_parent(name, path=None):
    prev = None
    if not path:
        path = os.getcwd()
    while path and prev != path and not exists(join(path, name)):
        prev = path
        path = dirname(path)
    return exists(join(path, name))


def subprocess_output(cmd):
    return check_output(cmd.split(), stderr=subprocess.STDOUT, universal_newlines=True).strip()


def get_commit_info(project_name=None):
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            commit = desc[-1].strip('g')
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            desc = subprocess_output('hg id --id --debug')
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': f'CalledProcessError({exc.returncode}, {exc.output!r})' if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }


def get_current_time():
    return datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')


def first_or_value(obj, value):
    if obj:
        (value,) = obj

    return value


def short_filename(path, machine_id=None):
    parts = []
    try:
        last = len(path.parts) - 1
    except AttributeError:
        return str(path)
    for pos, part in enumerate(path.parts):
        if not pos and part == machine_id:
            continue
        if pos == last:
            part = part.rsplit('.', 1)[0]
            # if len(part) > 16:
            #     part = "%.13s..." % part
        parts.append(part)
    return '/'.join(parts)
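# Behaviour sketch (hypothetical path): given machine_id 'Linux-CPython-3.12-64bit',
# short_filename(Path('Linux-CPython-3.12-64bit/0001_abc.json'), machine_id) drops
# the leading machine directory and the file extension, returning '0001_abc'.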


def load_timer(string):
    if '.' not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit('.', 1)
    if mod == 'pep418':
        import time

        return NameWrapper(getattr(time, attr))
    else:
        __import__(mod)
        mod = sys.modules[mod]
        return NameWrapper(getattr(mod, attr))


class RegressionCheck:
    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        val = self.compute(current, compared)
        if val > self.threshold:
            return f'Field {self.field!r} has failed {self.__class__.__name__}: {val:.9f} > {self.threshold:.9f}'


class PercentageRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        val = compared[self.field]
        if not val:
            return float('inf')
        return current[self.field] / val * 100 - 100
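# Worked example (hypothetical numbers): for field 'mean', current['mean'] = 0.012
# and compared['mean'] = 0.010 give 0.012 / 0.010 * 100 - 100 = 20.0, i.e. the
# current run is 20% slower; with a threshold of 10, fails() reports a regression.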


class DifferenceRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        return current[self.field] - compared[self.field]


def parse_compare_fail(
    string,
    rex=re.compile(
        r'^(?P<field>min|max|mean|median|stddev|iqr):'
        r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
        r'0-9]+)?))$'
    ),
):
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError('Could not parse value: %r.' % string)
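# Accepted forms, per the regex above: 'min:5%' produces
# PercentageRegressionCheck('min', 5), 'mean:0.001' produces
# DifferenceRegressionCheck('mean', 0.001), and anything else raises
# ArgumentTypeError.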


def parse_warmup(string):
    string = string.lower().strip()
    if string == 'auto':
        return platform.python_implementation() == 'PyPy'
    elif string in ['off', 'false', 'no']:
        return False
    elif string in ['on', 'true', 'yes', '']:
        return True
    else:
        raise argparse.ArgumentTypeError('Could not parse value: %r.' % string)


def name_formatter_short(bench):
    name = bench['name']
    if bench['source']:
        name = '{} ({:.4})'.format(name, split(bench['source'])[-1])
    if name.startswith('test_'):
        name = name[5:]
    return name


def name_formatter_normal(bench):
    name = bench['name']
    if bench['source']:
        parts = bench['source'].split('/')
        parts[-1] = parts[-1][:12]
        name = '{} ({})'.format(name, '/'.join(parts))
    return name


def name_formatter_long(bench):
    if bench['source']:
        return '{fullname} ({source})'.format(**bench)
    else:
        return bench['fullname']


def name_formatter_trial(bench):
    if bench['source']:
        return '%.4s' % split(bench['source'])[-1]
    else:
        return '????'


NAME_FORMATTERS = {
    'short': name_formatter_short,
    'normal': name_formatter_normal,
    'long': name_formatter_long,
    'trial': name_formatter_trial,
}


def parse_name_format(string):
    string = string.lower().strip()
    if string in NAME_FORMATTERS:
        return string
    else:
        raise argparse.ArgumentTypeError('Could not parse value: %r.' % string)


def parse_timer(string):
    return str(load_timer(string))


def parse_sort(string):
    string = string.lower().strip()
    if string not in ('min', 'max', 'mean', 'stddev', 'name', 'fullname'):
        raise argparse.ArgumentTypeError(
            "Unacceptable value: %r. "
            "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
            "'stddev', 'name', 'fullname'." % string
        )
    return string


def parse_columns(string):
    columns = [str.strip(s) for s in string.lower().split(',')]
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if invalid:
        # there are extra items in columns!
        msg = 'Invalid column name(s): %s. ' % ', '.join(invalid)
        msg += 'The only valid column names are: %s' % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns


def parse_rounds(string):
    try:
        value = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc) from None
    else:
        if value < 1:
            raise argparse.ArgumentTypeError('Value for --benchmark-rounds must be at least 1.')
        return value


def parse_seconds(string):
    try:
        return SecondsDecimal(string).as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError(f'Invalid decimal value {string!r}: {exc!r}') from None


def parse_save(string):
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    illegal = ''.join(c for c in r'\/:*?<>|' if c in string)
    if illegal:
        raise argparse.ArgumentTypeError('Must not contain any of these characters: /:*?<>|\\ (it has %r)' % illegal)
    return string


def _parse_hosts(storage_url, netrc_file):
    # load creds from netrc file
    path = os.path.expanduser(netrc_file)
    creds = None
    if netrc_file and os.path.isfile(path):
        creds = netrc.netrc(path)

    # add creds to urls
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ''
        if creds and '@' not in netloc:
            host = netloc.split(':').pop(0)
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = f'{user}:{secret}@'
        url = f'{storage_url.scheme}://{auth}{netloc}'
        urls.append(url)
    return urls
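# Example (hypothetical hosts and credentials, assuming an 'http' scheme): for a
# netloc of 'host1:9200,host2:9200' with a netrc entry only for 'host1', this
# returns ['http://user:secret@host1:9200', 'http://host2:9200']; credentials are
# never injected when the netloc already contains '@'.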


def parse_elasticsearch_storage(string, default_index='benchmark', default_doctype='benchmark', netrc_file=''):
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index = default_index
    doctype = default_doctype
    if storage_url.path and storage_url.path != '/':
        splitted = storage_url.path.strip('/').split('/')
        index = splitted[0]
        if len(splitted) >= 2:
            doctype = splitted[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query['project_name'][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
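# Example (hypothetical URL, assuming no matching netrc entry): parsing
# 'http://es.example.org/myindex/mydoc?project_name=myproj' (the part left after
# load_storage strips the 'elasticsearch+' prefix) yields
# hosts=['http://es.example.org'], index='myindex', doctype='mydoc' and
# project_name='myproj'; without a path or query the defaults and
# get_project_name() are used instead.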


def load_storage(storage, **kwargs):
    if '://' not in storage:
        storage = 'file://' + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage
    if storage.startswith('file://'):
        from .storage.file import FileStorage

        return FileStorage(storage[len('file://') :], **kwargs)
    elif storage.startswith('elasticsearch+'):
        from .storage.elasticsearch import ElasticsearchStorage

        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len('elasticsearch+') :], netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError('Storage must be in the form file://path or elasticsearch+http[s]://host1,host2/index/doctype')


def time_unit(value):
    if value < 1e-6:
        return 'n', 1e9
    elif value < 1e-3:
        return 'u', 1e6
    elif value < 1:
        return 'm', 1e3
    else:
        return '', 1.0


def operations_unit(value):
    if value > 1e6:
        return 'M', 1e-6
    if value > 1e3:
        return 'K', 1e-3
    return '', 1.0


def format_time(value):
    unit, adjustment = time_unit(value)
    return f'{value * adjustment:.2f}{unit:s}'
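# Example: format_time(0.0005) picks the microseconds unit ('u', 1e6) from
# time_unit() and returns '500.00u', while format_time(2.5) stays in seconds
# and returns '2.50'.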


class cached_property:
    def __init__(self, func):
        self.__doc__ = func.__doc__
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value


def funcname(f):
    try:
        if isinstance(f, partial):
            return f.func.__name__
        else:
            return f.__name__
    except AttributeError:
        return str(f)


# from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.
    """
    # first of all, we clone the code object
    if not hasattr(f, '__code__'):
        return f
    co = f.__code__
    args = [
        co.co_argcount,
        co.co_kwonlyargcount,
        co.co_nlocals,
        co.co_stacksize,
        co.co_flags,
        co.co_code,
        co.co_consts,
        co.co_names,
        co.co_varnames,
        co.co_filename,
        co.co_name,
        co.co_firstlineno,
        co.co_lnotab,
        co.co_freevars,
        co.co_cellvars,
    ]
    if PY38:
        args.insert(1, co.co_posonlyargcount)
    if PY311:
        args.insert(12, co.co_qualname)
        args.insert(15, co.co_exceptiontable)
    co2 = types.CodeType(*args)
    #
    # then, we clone the function itself, using the new co2
    f2 = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    return f2
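# Minimal usage sketch (hypothetical benchmark code): cloning a workload before
# timing it in different scenarios keeps PyPy from reusing the same JIT-compiled
# assembler for both.
#
#     def workload():
#         return sum(range(1000))
#
#     workload_a = clonefunc(workload)
#     workload_b = clonefunc(workload)
#     # workload_a.__code__ is not workload_b.__code__, so PyPy traces
#     # them independently even though they do the same thing.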


def format_dict(obj):
    return '{%s}' % ', '.join(f'{k}: {json.dumps(v)}' for k, v in sorted(obj.items()))


class SafeJSONEncoder(json.JSONEncoder):
    def default(self, o):
        return 'UNSERIALIZABLE[%r]' % o


def safe_dumps(obj, **kwargs):
    return json.dumps(obj, cls=SafeJSONEncoder, **kwargs)


def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    total = len(iterable)

    def progress_reporting_wrapper():
        for pos, item in enumerate(iterable):
            string = format_string.format(pos=pos + 1, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(string, black=True, bold=True)
            yield string, item

    return progress_reporting_wrapper()


def report_noprogress(iterable, *args, **kwargs):
    for item in iterable:
        yield '', item


def report_online_progress(progress_reporter, tr, line):
    next(progress_reporter([line], tr, '{value}'))


def slugify(name):
    for c in r'\/:*?<>| ':
        name = name.replace(c, '_').replace('__', '_')
    return name


def get_cprofile_functions(stats):
    """
    Convert a pstats structure into a list of dicts describing each profiled function.
    """
    result = []
    # this assumes that you run pytest from the project root dir
    project_dir_parent = dirname(os.getcwd())

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent) :].lstrip('/')
        function_name = f'{file_path}:{function_info[1]}({function_info[2]})'

        # if the function is recursive, report calls as 'total calls/primitive calls'
        if run_info[0] == run_info[1]:
            calls = str(run_info[0])
        else:
            calls = f'{run_info[1]}/{run_info[0]}'

        result.append(
            {
                'ncalls_recursion': calls,
                'ncalls': run_info[1],
                'tottime': run_info[2],
                'tottime_per': run_info[2] / run_info[0] if run_info[0] > 0 else 0,
                'cumtime': run_info[3],
                'cumtime_per': run_info[3] / run_info[0] if run_info[0] > 0 else 0,
                'function_name': function_name,
            }
        )

    return result
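# Usage sketch (hypothetical profiling session): build a pstats.Stats from a
# cProfile run and pass it in, e.g.
#
#     import cProfile
#     import pstats
#
#     profiler = cProfile.Profile()
#     profiler.runcall(sum, range(1000))
#     functions = get_cprofile_functions(pstats.Stats(profiler))
#     # each entry carries 'function_name', 'ncalls', 'tottime', 'cumtime', ...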