OpenJij/OpenJij

View on GitHub
openjij/utils/benchmark.py

Summary

Maintainability
D
2 days
Test Coverage
B
84%
# Copyright 2023 Jij Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import inspect

from logging import getLogger

import numpy as np

# Module-level logger; solver_benchmark reports its start through this logger.
logger = getLogger(__name__)


def solver_benchmark(
    solver,
    time_list,
    solutions=None,
    args=None,
    p_r=0.99,
    ref_energy=0,
    measure_with_energy=False,
    time_name="execution_time",
):
    """Calculate 'success probability', 'TTS', 'Residual energy', 'Standard Error' with computation time.

    The solver is called once per entry of ``time_list`` as
    ``solver(time, **args)``. All statistics for that entry are computed from
    the single response it returns.

    Args:
        solver (callable): returns openjij.Response, and solver has arguments 'time' and '**args'.
            The response must expose ``info[time_name]`` and ``energies``. When
            ``measure_with_energy`` is False it must also expose ``states`` (or
            ``samples()`` for dict-form solutions).
        time_list (list): values passed, one at a time, as the first positional
            argument of ``solver``.
        solutions (list(list(int)), list(int), optional): true solution, or list of
            solutions if the ground state is degenerate. Required unless
            ``measure_with_energy`` is True. Defaults to None, which means no solutions.
        args (dict, optional): extra keyword arguments for ``solver``. Defaults to None,
            which means no extra arguments.
        p_r (float): threshold probability for the time to solution.
        ref_energy (float): reference (ground) energy. It is used for the residual energy
            and, when ``measure_with_energy`` is True, for the success criterion.
        measure_with_energy (bool): if True, a sample counts as a success when its energy
            is <= ``ref_energy`` instead of when its state matches a solution.
        time_name (str): key of ``response.info`` that is read as the computation time.
    Returns:
        dict: dictionary which has the following keys:

        * **time**: list of computation time
        * **success_prob**: list of success probability at each computation time
        * **tts**: list of time to solution at each computation time
        * **residual_energy**: list of residual energy at each computation time
        * **se_lower_tts**: list of tts's lower standard error at each computation time
        * **se_upper_tts**: list of tts's upper standard error at each computation time
        * **se_success_prob**: list of success probability's standard error at each computation time
        * **se_residual_energy**: list of residual_energy's standard error at each computation time
        * **info** (dict): Parameter information for the benchmark
    Raises:
        ValueError: if ``measure_with_energy`` is False and no solutions are given.
    """
    # None sentinels: a literal [] / {} default would be one shared object across calls.
    if solutions is None:
        solutions = []
    if args is None:
        args = {}

    # len() works for lists, tuples and numpy arrays alike; ``solutions == []``
    # would compare elementwise (or raise) for an ndarray.
    if not measure_with_energy and len(solutions) == 0:
        raise ValueError("need input 'solutions': (list(list))")

    logger.info("function %s start", inspect.currentframe().f_code.co_name)

    computation_times = []
    success_probabilities = []
    tts_list = []
    residual_energies = []

    se_lower_tts_list = []
    se_upper_tts_list = []
    se_success_prob_list = []
    se_residual_energy_list = []

    for time_limit in time_list:
        response = solver(time_limit, **args)

        comp_time = response.info[time_name]
        computation_times.append(comp_time)

        ps = success_probability(response, solutions, ref_energy, measure_with_energy)
        tts = time_to_solution(ps, comp_time, p_r)

        success_probabilities.append(ps)
        tts_list.append(tts)
        residual_energies.append(residual_energy(response, ref_energy))

        se_ps = se_success_probability(
            response, solutions, ref_energy, measure_with_energy
        )

        se_success_prob_list.append(se_ps)
        se_lower_tts_list.append(se_lower_tts(tts, ps, comp_time, p_r, se_ps))
        se_upper_tts_list.append(se_upper_tts(tts, ps, comp_time, p_r, se_ps))
        se_residual_energy_list.append(se_residual_energy(response, ref_energy))

    return {
        "time": computation_times,
        "success_prob": success_probabilities,
        "tts": tts_list,
        "residual_energy": residual_energies,
        "se_lower_tts": se_lower_tts_list,
        "se_upper_tts": se_upper_tts_list,
        "se_success_prob": se_success_prob_list,
        "se_residual_energy": se_residual_energy_list,
        "info": {
            "tts_threshold_prob": p_r,
            "ref_energy": ref_energy,
            "measure_with_energy": measure_with_energy,
        },
    }


def residual_energy(response, ref_energy):
    """Return the mean sampled energy measured relative to a reference energy.

    Args:
        response (openjij.Response): response from solver (or sampler); only its
            ``energies`` attribute is read.
        ref_energy (float): reference energy :math:`E_0` (usually the ground energy).
    Returns:
        float: :math:`\\langle E \\rangle - E_0`, where :math:`\\langle E \\rangle` is the
        arithmetic mean of ``response.energies``.
    """
    mean_energy = np.mean(response.energies)
    return mean_energy - ref_energy


def se_residual_energy(response, ref_energy):
    """Return the spread of the sampled energies used as the residual-energy error bar.

    Despite the name, this is the sample standard deviation (``ddof=1``) of
    ``response.energies``; it is not divided by sqrt(N). With fewer than two
    energies NumPy yields nan.

    Args:
        response (openjij.Response): response from solver (or sampler); only its
            ``energies`` attribute is read.
        ref_energy (float): unused here. A constant shift does not change a
            standard deviation; the parameter is presumably kept for symmetry
            with residual_energy.
    Returns:
        float: sample standard deviation of the measured energies.
    """
    sampled_energies = np.asarray(response.energies)
    return sampled_energies.std(ddof=1)


def _normalize_solutions(solutions):
    """Return ``solutions`` as a list of target states comparable with sampled states.

    Accepted forms:

    * a list of solutions (lists, tuples, numpy arrays, or dicts);
    * a single flat solution such as ``[1, -1, 1]`` (its first item is a scalar);
    * a single dict solution such as ``{0: 1, 1: -1}``.

    Tuples and numpy arrays are converted to lists so that ``list(state) in targets``
    performs whole-state equality rather than elementwise comparison.
    An empty ``solutions`` sequence raises IndexError.
    """
    if isinstance(solutions, dict):
        return [solutions]
    first = solutions[0]  # an empty sequence raises IndexError here
    if np.isscalar(first):
        # A single flat state was passed instead of a list of states.
        return [list(solutions)]
    return [
        list(solution) if isinstance(solution, (tuple, np.ndarray)) else solution
        for solution in solutions
    ]


def _success_indicators(response, solutions, ref_energy, measure_with_energy):
    """Return a float array holding 1.0 for each sample counted as a success, else 0.0.

    * measure_with_energy True: success means ``energy <= ref_energy``; ``solutions``
      is not inspected.
    * measure_with_energy False: success means the sampled state equals one of
      ``solutions``. Dict solutions are compared with ``dict(sample)`` over
      ``response.samples()``; otherwise ``list(state)`` over ``response.states``.
    """
    if measure_with_energy:
        return (np.asarray(response.energies) <= ref_energy).astype(float)

    targets = _normalize_solutions(solutions)
    if isinstance(targets[0], dict):
        hits = [dict(state) in targets for state in response.samples()]
    else:
        hits = [list(state) in targets for state in response.states]
    return np.asarray(hits, dtype=float)


def success_probability(response, solutions, ref_energy=0, measure_with_energy=False):
    """Calculate success probability from openjij.response

    Args:
        response (openjij.Response): response from solver (or sampler).
        solutions (list(list(int)), list(int), dict, list(dict)): true solution(s).
            A single solution may be passed directly instead of a one-element list.
        ref_energy (float): reference energy used when ``measure_with_energy`` is True.
        measure_with_energy (bool): choose the success criterion (see below).
    Returns:
        float: Success probability (fraction of samples that succeed; nan for an
        empty response).

        * When measure_with_energy is False, success is defined as getting the same state as solutions.
        * When measure_with_energy is True, success is defined as getting a state whose energy is at or below the reference energy.
    """
    return np.mean(
        _success_indicators(response, solutions, ref_energy, measure_with_energy)
    )


def se_success_probability(
    response, solutions, ref_energy=0, measure_with_energy=False
):
    """Calculate success probability's standard error from openjij.response

    The returned value is the population standard deviation (``ddof=0``) of the
    per-sample 0/1 success indicators, i.e. ``sqrt(p * (1 - p))``. It is not
    divided by sqrt(N). Both success criteria use this same formula.

    Args:
        response (openjij.Response): response from solver (or sampler).
        solutions (list(list(int)), list(int), dict, list(dict)): true solution(s).
            A single solution may be passed directly instead of a one-element list.
        ref_energy (float): reference energy used when ``measure_with_energy`` is True.
        measure_with_energy (bool): choose the success criterion (see below).
    Returns:
        float: Success probability's standard error.

        * When measure_with_energy is False, success is defined as getting the same state as solutions.
        * When measure_with_energy is True, success is defined as getting a state whose energy is at or below the reference energy.
    """
    return np.std(
        _success_indicators(response, solutions, ref_energy, measure_with_energy)
    )


def time_to_solution(success_prob, computation_time, p_r):
    """Return the time to solution (TTS) for a single-run success probability.

    The TTS is :math:`\\tau \\log(1 - p_r) / \\log(1 - p_s)`, where :math:`\\tau`
    is ``computation_time``, :math:`p_r` is the threshold probability ``p_r``,
    and :math:`p_s` is ``success_prob``.

    Args:
        success_prob (float): success probability of one run.
        computation_time (float): time taken by one run.
        p_r (float): threshold probability to calculate time to solution.
    Returns:
        float: the TTS. It is 0.0 when ``success_prob == 1.0`` and ``inf`` when
        ``success_prob == 0.0``.
    """
    # Endpoints where log(1 - success_prob) is -inf or 0 are mapped explicitly.
    if success_prob == 1.0:
        return 0.0
    if success_prob == 0.0:
        return np.inf

    log_threshold_failure = np.log(1 - p_r)
    log_run_failure = np.log(1 - success_prob)
    return computation_time * log_threshold_failure / log_run_failure


def se_lower_tts(tts, success_prob, computation_time, p_r, se_success_prob):
    """Return the lower error bar of the time to solution.

    An optimistic TTS is evaluated with the success probability raised by one
    error, ``success_prob + se_success_prob``. The result is the distance from
    ``tts`` to that optimistic value.

    Args:
        tts (float): time to solution evaluated at ``success_prob``.
        success_prob (float): success probability.
        computation_time (float): time taken by one run.
        p_r (float): threshold probability to calculate time to solution.
        se_success_prob (float): error of ``success_prob``.
    Returns:
        float: ``abs(optimistic_tts - tts)``. The optimistic TTS is taken as 0.0
        when the raised probability reaches 1 or when ``success_prob == 0.0``.
    """
    optimistic_prob = success_prob + se_success_prob

    if 1 - optimistic_prob <= 0.0 or success_prob == 0.0:
        optimistic_tts = 0.0
    else:
        optimistic_tts = (
            computation_time * np.log(1 - p_r) / np.log(1 - optimistic_prob)
        )

    return abs(optimistic_tts - tts)


def se_upper_tts(tts, success_prob, computation_time, p_r, se_success_prob):
    """Return the upper error bar of the time to solution.

    A pessimistic TTS is evaluated with the success probability lowered by one
    error, ``success_prob - se_success_prob``. The result is the distance from
    ``tts`` to that pessimistic value.

    Args:
        tts (float): time to solution evaluated at ``success_prob``.
        success_prob (float): success probability.
        computation_time (float): time taken by one run.
        p_r (float): threshold probability to calculate time to solution.
        se_success_prob (float): error of ``success_prob``.
    Returns:
        float: ``abs(pessimistic_tts - tts)``.

        * 0.0 is used as the pessimistic TTS when ``success_prob`` is exactly 1.0 or 0.0.
        * ``inf`` is returned when the lowered probability is <= 0, because the TTS
          is unbounded there.
    """
    if success_prob == 1.0:
        tts_up_error = 0.0
    elif success_prob == 0.0:
        tts_up_error = 0.0
    elif success_prob - se_success_prob <= 0.0:
        # A non-positive success probability never reaches p_r, so the TTS is
        # unbounded. Without this guard, log(1 - x) would be >= 0, giving a
        # division by zero or a negative, meaningless TTS. This mirrors the
        # clamp that se_lower_tts applies when the raised probability reaches 1.
        tts_up_error = np.inf
    else:
        tts_up_error = (
            computation_time
            * np.log(1 - p_r)
            / np.log(1 - (success_prob - se_success_prob))
        )

    return abs(tts_up_error - tts)