include/openjij/sampler/sa_sampler.hpp
// Copyright 2023 Jij Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "openjij/graph/all.hpp"
#include "openjij/updater/all.hpp"
#include "openjij/system/all.hpp"
namespace openjij {
namespace sampler {
//! @brief Class for executing simulated annealing.
//! @tparam ModelType The type of models.
template<class ModelType>
class SASampler {
//! @brief The value type.
using ValueType = typename ModelType::ValueType;
//! @brief The variable type
using VariableType = typename ModelType::VariableType;
public:
//! @brief Constructor for SASampler class.
//! @param model The model.
SASampler(const ModelType &model): model_(model) {}
//! @brief Set the number of sweeps.
//! @param num_sweeps The number of sweeps, which must be larger than zero.
void SetNumSweeps(const std::int32_t num_sweeps) {
if (num_sweeps <= 0) {
throw std::runtime_error("num_sweeps must be larger than zero.");
}
num_sweeps_ = num_sweeps;
}
//! @brief Set the number of samples.
//! @param num_reads The number of samples, which must be larger than zero.
void SetNumReads(const std::int32_t num_reads) {
if (num_reads <= 0) {
throw std::runtime_error("num_reads must be larger than zero.");
}
num_reads_ = num_reads;
}
//! @brief Set the number of threads in the calculation.
//! @param num_threads The number of threads in the calculation, which must be larger than zero.
void SetNumThreads(const std::int32_t num_threads) {
if (num_threads <= 0) {
throw std::runtime_error("num_threads must be non-negative integer.");
}
num_threads_ = num_threads;
}
//! @brief Set the minimum inverse temperature.
//! @param beta_min The minimum inverse temperature, which must be larger than zero.
void SetBetaMin(const ValueType beta_min) {
if (beta_min < 0) {
throw std::runtime_error("beta_min must be positive number");
}
beta_min_ = beta_min;
}
//! @brief Set the minimum inverse temperature.
//! @param beta_max The minimum inverse temperature, which must be larger than zero.
void SetBetaMax(const ValueType beta_max) {
if (beta_max < 0) {
throw std::runtime_error("beta_max must be positive number");
}
beta_max_ = beta_max;
}
//! @brief Set the minimum inverse temperature automatically.
void SetBetaMinAuto() {
beta_min_ = std::log(2.0)/model_.GetEstimatedMaxEnergyDifference();
}
//! @brief Set the maximum inverse temperature automatically.
void SetBetaMaxAuto() {
beta_max_ = std::log(100.0)/model_.GetEstimatedMinEnergyDifference();
}
//! @brief Set update method used in the state update.
//! @param update_method The update method.
void SetUpdateMethod(const algorithm::UpdateMethod update_method) {
update_method_ = update_method;
}
//! @brief Set random number engine for updating initializing state.
//! @param random_number_engine The random number engine.
void SetRandomNumberEngine(const algorithm::RandomNumberEngine random_number_engine) {
random_number_engine_ = random_number_engine;
}
//! @brief Set the cooling schedule.
//! @param schedule The cooling schedule.
void SetTemperatureSchedule(const utility::TemperatureSchedule schedule) {
schedule_ = schedule;
}
//! @brief Get the model.
//! @return The model.
const ModelType &GetModel() const {
return model_;
}
//! @brief Get the number of sweeps.
//! @return The number of sweeps.
std::int32_t GetNumSweeps() const {
return num_sweeps_;
}
//! @brief Get the number of reads.
//! @return The number of reads.
std::int32_t GetNumReads() const {
return num_reads_;
}
//! @brief Get the number of threads.
//! @return The number of threads.
std::int32_t GetNumThreads() const {
return num_threads_;
}
//! @brief Get the minimum inverse temperature.
//! @return The minimum inverse temperature.
ValueType GetBetaMin() const {
return beta_min_;
}
//! @brief Get the maximum inverse temperature.
//! @return The maximum inverse temperature.
ValueType GetBetaMax() const {
return beta_max_;
}
//! @brief Get the update method used in the state update.
//! @return The update method used in the state update.
algorithm::UpdateMethod GetUpdateMethod() const {
return update_method_;
}
//! @brief Get the random number engine for updating and initializing state.
//! @return The random number engine for updating and initializing state.
algorithm::RandomNumberEngine GetRandomNumberEngine() const {
return random_number_engine_;
}
//! @brief Get the temperature schedule.
//! @return The temperature schedule.
utility::TemperatureSchedule GetTemperatureSchedule() const {
return schedule_;
}
//! @brief Get the seed to be used in the calculation.
//! @return The seed.
std::uint64_t GetSeed() const {
return seed_;
}
const std::vector<typename ModelType::IndexType> &GetIndexList() const {
return model_.GetIndexList();
}
//! @brief Get the samples.
//! @return The samples.
const std::vector<std::vector<VariableType>> &GetSamples() const {
return samples_;
}
std::vector<ValueType> CalculateEnergies() const {
if (samples_.size() == 0) {
throw std::runtime_error("The sample size is zero. It seems that sampling has not been carried out.");
}
std::vector<ValueType> energies(num_reads_);
try {
#pragma omp parallel for schedule(guided) num_threads(num_threads_)
for (std::int32_t i = 0; i < num_reads_; ++i) {
energies[i] = model_.CalculateEnergy(samples_[i]);
}
}
catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
}
return energies;
}
//! @brief Execute sampling.
//! Seed to be used in the calculation will be set automatically.
void Sample() {
Sample(std::random_device()());
}
//! @brief Execute sampling.
//! @param seed The seed to be used in the calculation.
void Sample(const std::uint64_t seed) {
seed_ = seed;
samples_.clear();
samples_.shrink_to_fit();
samples_.resize(num_reads_);
if (random_number_engine_ == algorithm::RandomNumberEngine::XORSHIFT) {
TemplateSampler<system::SASystem<ModelType, utility::Xorshift>, utility::Xorshift>();
}
else if (random_number_engine_ == algorithm::RandomNumberEngine::MT) {
TemplateSampler<system::SASystem<ModelType, std::mt19937>, std::mt19937>();
}
else if (random_number_engine_ == algorithm::RandomNumberEngine::MT_64) {
TemplateSampler<system::SASystem<ModelType, std::mt19937_64>, std::mt19937_64>();
}
else {
throw std::runtime_error("Unknown RandomNumberEngine");
}
}
private:
//! @brief The model.
const ModelType model_;
//! @brief The number of sweeps.
std::int32_t num_sweeps_ = 1000;
//! @brief The number of reads (samples).
std::int32_t num_reads_ = 1;
//! @brief The number of threads in the calculation.
std::int32_t num_threads_ = 1;
//! @brief The start inverse temperature.
ValueType beta_min_ = 1;
//! @brief The end inverse temperature.
ValueType beta_max_ = 1;
//! @brief The update method used in the state update.
algorithm::UpdateMethod update_method_ = algorithm::UpdateMethod::METROPOLIS;
//! @brief Random number engine for updating and initializing state.
algorithm::RandomNumberEngine random_number_engine_ = algorithm::RandomNumberEngine::XORSHIFT;
//! @brief Cooling schedule.
utility::TemperatureSchedule schedule_ = utility::TemperatureSchedule::GEOMETRIC;
//! @brief The seed to be used in the calculation.
std::uint64_t seed_ = std::random_device()();
//! @brief The samples.
std::vector<std::vector<VariableType>> samples_;
template<typename RandType>
std::vector<std::pair<typename RandType::result_type, typename RandType::result_type>>
GenerateSeedPairList(const typename RandType::result_type seed, const std::int32_t num_reads) const {
RandType random_number_engine(seed);
std::vector<std::pair<typename RandType::result_type, typename RandType::result_type>> seed_pair_list(num_reads);
for (std::int32_t i = 0; i < num_reads; ++i) {
seed_pair_list[i].first = random_number_engine();
seed_pair_list[i].second = random_number_engine();
}
return seed_pair_list;
}
template<class SystemType, class RandType>
void TemplateSampler() {
const auto seed_pair_list = GenerateSeedPairList<RandType>(static_cast<typename RandType::result_type>(seed_), num_reads_);
std::vector<ValueType> beta_list = utility::GenerateBetaList(schedule_, beta_min_, beta_max_, num_sweeps_);
#pragma omp parallel for schedule(guided) num_threads(num_threads_)
for (std::int32_t i = 0; i < num_reads_; ++i) {
auto system = SystemType{model_, seed_pair_list[i].first};
updater::SingleFlipUpdater<SystemType, RandType>(&system, num_sweeps_, beta_list, seed_pair_list[i].second, update_method_);
samples_[i] = system.ExtractSample();
}
}
};
template<class ModelType>
auto make_sa_sampler(const ModelType &model) {
return SASampler<ModelType>{model};
};
} //sampler
} //openjij