# kameris/subcommands/summarize.py (stephensolis/kameris)

from __future__ import (
    absolute_import, division, print_function, unicode_literals)

import base64
import collections
import json
import os
import re
import subprocess

from six import iteritems
from tabulate import tabulate

from ..job_steps._classifiers import classifier_names
from ..utils import fs_utils


all_classifiers = set(classifier_names)
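# Every classifier kameris knows about; used later to report which ones never
# produced results for a given experiment.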


def natural_sort_key(string):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string)]
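# Example: sorted(['run-k=10', 'run-k=2'], key=natural_sort_key) yields
# ['run-k=2', 'run-k=10'], since the embedded digits compare numerically.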


def run(args):
    accuracy_key = 'top{}'.format(args.top_n)

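    # Each classification-*.json file is expected to map classifier names to
    # result dicts; the 'top<n>' entry holds the fraction of correctly
    # classified samples, reported below as a percentage.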
    def accuracy_for_classifier(classifier_results):
        return classifier_results[accuracy_key]['accuracy'] * 100

    run_stats = {}

    for run_name in os.listdir(args.job_dir):
        curr_path = os.path.join(args.job_dir, run_name)
        if not os.path.isdir(curr_path):
            continue

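        # Run directories are expected to be named '<experiment>-k=<k>'; the
        # k parameter is stripped off so all runs of the same experiment are
        # grouped together under base_run_name.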
        base_run_name = re.sub('-k=[0-9]+', '', run_name)
        run_k = re.search('-k=([0-9]+)', run_name)
        if run_k is None:
            raise RuntimeError('Run name {} does not include a k parameter; '
                               'runs without k are currently not supported'
                               .format(run_name))
        run_k = int(run_k.group(1))

        if base_run_name not in run_stats:
            with open(os.path.join(curr_path, 'metadata.json')) as f:
                metadata = json.load(f)
                groups = (x['group'] for x in metadata)

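            # Distance names are taken from the run's
            # 'classification-<dist>.json' files.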
            dists = sorted(
                filename[len('classification-'):-len('.json')]
                for filename in os.listdir(curr_path)
                if filename.startswith('classification-')
                and filename.endswith('.json')
                and os.path.isfile(os.path.join(curr_path, filename)))

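            # One record per experiment: class sizes, distance names, how many
            # times each classifier produced results, the set of k values
            # seen, and the best classifier found so far (overall and per k).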
            run_stats[base_run_name] = {
                'classes': collections.Counter(groups),
                'dists': dists,
                'classifier_counts': collections.defaultdict(int),
                'ks': set(),
                'best_classifier': {'accuracy': 0, 'k': float('inf')},
                'best_classifier_by_k': {
                    dist: collections.defaultdict(lambda: {'accuracy': 0})
                    for dist in dists
                }
            }

        curr_stats = run_stats[base_run_name]
        curr_stats['ks'].add(run_k)

        dist_results = {}
        for dist_name in curr_stats['dists']:
            with open(os.path.join(curr_path, 'classification-{}.json'
                                              .format(dist_name))) as f:
                dist_results[dist_name] = json.load(f)

        for dist_name, results in iteritems(dist_results):
            for classifier, classifier_results in iteritems(results):
                if accuracy_key not in classifier_results:
                    continue

                curr_stats['classifier_counts'][classifier] += 1
                accuracy = accuracy_for_classifier(classifier_results)

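                # Keep the single best result across all runs of this
                # experiment; ties in accuracy go to the smaller k.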
                if (accuracy > curr_stats['best_classifier']['accuracy'] or
                    (accuracy == curr_stats['best_classifier']['accuracy']
                     and run_k < curr_stats['best_classifier']['k'])):
                    curr_stats['best_classifier'] = {
                        'accuracy': accuracy,
                        'confusion_matrix':
                            classifier_results['confusion_matrix'],

                        'class_order': classifier_results['classes'],
                        'dist': dist_name,
                        'k': run_k,
                        'classifier': classifier,

                        'metadata_file': os.path.join(
                            curr_path, 'metadata.json'),
                        'classification_file': os.path.join(
                            curr_path,
                            'classification-{}.json'.format(dist_name)),
                        'mds_file': os.path.join(
                            curr_path, 'mds10-{}.json'.format(dist_name))
                    }

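                    # Snapshot every classifier's accuracy for the run that
                    # currently holds the best result, keyed by distance.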
                    curr_stats['best_k_classifiers'] = {
                        curr_dist: {
                            curr_classifier: accuracy_for_classifier(
                                curr_classifier_results
                            )
                            for curr_classifier, curr_classifier_results
                            in iteritems(curr_results)
                            if accuracy_key in curr_classifier_results
                        }
                        for curr_dist, curr_results
                        in iteritems(dist_results)
                    }

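                # Independently track the best classifier for every
                # (distance, k) combination.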
                best_by_k_stats = curr_stats['best_classifier_by_k'][dist_name]
                if accuracy > best_by_k_stats[run_k]['accuracy']:
                    best_by_k_stats[run_k] = {
                        'accuracy': accuracy,
                        'classifier': classifier
                    }

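    # Print one report per experiment: classifier coverage, class sizes, the
    # best result with its confusion matrix, per-k and per-classifier tables,
    # and (optionally) plots.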
    exp_names = sorted(run_stats.keys(), key=natural_sort_key)
    for exp_name in exp_names:
        curr_stats = run_stats[exp_name]
        best_stats = curr_stats['best_classifier']
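        # Experiments with at most top_n classes are skipped: top-N accuracy
        # would be trivially 100% when there are no more than N classes.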
        if len(curr_stats['classes']) <= args.top_n:
            continue

        print()
        print('Experiment:', exp_name)
        print()

        exp_classifiers = set(curr_stats['classifier_counts'].keys())
        always_classifiers = {
            name for name, count in iteritems(curr_stats['classifier_counts'])
            if count == len(curr_stats['ks'])*len(curr_stats['dists'])
        }
        print('These classifiers ran every time: [{}]'
              .format(', '.join(sorted(always_classifiers))))
        print('These classifiers ran sometimes but not always: [{}]'
              .format(', '.join(
                  sorted(exp_classifiers - always_classifiers))))
        print('These classifiers did not run: [{}]'
              .format(', '.join(sorted(all_classifiers - exp_classifiers))))
        print()

        print('Classes:')
        for class_name in best_stats['class_order']:
            print('{} ({})'
                  .format(class_name, curr_stats['classes'][class_name]))
        print()

        print('Best accuracy: {accuracy:.2f}% (k={k}, {dist}, {classifier})'
              .format(**best_stats))
        print('Confusion matrix:')
        print(tabulate(
            best_stats['confusion_matrix']
        ))
        print()

        best_by_k = curr_stats['best_classifier_by_k']
        print('Best classifier by k:')
        print(tabulate(
            [[k] + [val for dist_name in curr_stats['dists']
                    for val in ([best_by_k[dist_name][k]['accuracy'],
                                 best_by_k[dist_name][k]['classifier']]
                                if 'classifier' in best_by_k[dist_name][k]
                                else ['N/A', 'N/A'])]
             for k in sorted(curr_stats['ks'])],
            ['k'] + [header for dist_name in curr_stats['dists']
                     for header in [dist_name + '-accuracy',
                                    dist_name + '-classifier']],
            floatfmt='.2f'
        ))
        print()

        best_for_k = curr_stats['best_k_classifiers']
        best_classifiers = set(c for classifier_results in best_for_k.values()
                               for c in classifier_results.keys())
        print('Classifiers for k={}:'
              .format(best_stats['k']))
        print(tabulate(
            [[c] + [best_for_k[dist_name][c]
                    if c in best_for_k[dist_name] else 'N/A'
                    for dist_name in curr_stats['dists']]
             for c in sorted(best_classifiers)],
            ['classifier'] + curr_stats['dists'],
            floatfmt='.2f'
        ))
        print()

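        # Plots are produced by the bundled make_plots.wls Wolfram Language
        # script, invoked through wolframscript; its parameters are passed as
        # a single base64-encoded JSON argument.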
        if args.plot_output_dir is not None:
            num_classes = len(curr_stats['classes'])
            if num_classes > 10:
                print('Warning: skipping plot generation because there are '
                      'too many classes ({} > 10)'.format(num_classes))
            else:
                base_output_path = os.path.join(
                    args.plot_output_dir, os.path.basename(args.job_dir)
                )
                fs_utils.mkdir_p(base_output_path)
                base_output_filename = os.path.join(
                    base_output_path, '{}-k={k}-{dist}-{classifier}'
                                      .format(exp_name, **best_stats)
                )
                subprocess.call(
                    'wolframscript "{}" {}'.format(
                        os.path.normpath(os.path.join(
                            os.path.dirname(__file__), '..', 'scripts',
                            'make_plots.wls'
                        )),
                        base64.b64encode(json.dumps({
                            'accuracy_type': accuracy_key,
                            'classifier_name': best_stats['classifier'],
                            'metadata_file': best_stats['metadata_file'],
                            'classification_file':
                                best_stats['classification_file'],
                            'mds_file': best_stats['mds_file'],
                            'output_file': base_output_filename + '-plots.nb',
                            'svg_output_file':
                                base_output_filename + '-plot2d.svg',
                            'png_output_file':
                                base_output_filename + '-plot2d.png'
                        }).encode('utf-8')).decode('ascii')
                    ),
                    shell=True
                )

        print('='*80)

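    # Close with a one-line summary per experiment: its best accuracy and the
    # run (k, distance, classifier) that achieved it.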
    print()
    print('Experiment summary:')
    print(tabulate(
        [[exp_name, run_stats[exp_name]['best_classifier']['accuracy'],
          'k={k}, {dist}, {classifier}'
          .format(**run_stats[exp_name]['best_classifier'])]
         for exp_name in exp_names
         if len(run_stats[exp_name]['classes']) > args.top_n],
        ['experiment', 'best accuracy', 'run info'],
        floatfmt='.2f'
    ))
    print()
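

# A minimal sketch of how run() could be exercised directly. The flag names
# and defaults below are illustrative assumptions, not the actual kameris CLI,
# which wires this module up as a subcommand elsewhere.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Summarize kameris classification runs')
    parser.add_argument('job_dir',
                        help='directory containing one subdirectory per run')
    parser.add_argument('--top-n', dest='top_n', type=int, default=1,
                        help='report top-N accuracy (assumed default: 1)')
    parser.add_argument('--plot-output-dir', dest='plot_output_dir',
                        default=None,
                        help='if set, write plots here (needs wolframscript)')
    run(parser.parse_args())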