Beam Search
(heuristic/beam-search.hpp)
Template
/**
 * Skeleton of the State type consumed by BeamSearch; fill in the
 * problem-specific parts.
 *
 * BeamSearch copies states and orders them with operator<. It keeps them in a
 * std::priority_queue, a max-heap, so a larger `score` is treated as better.
 * It also uses `root_action` to remember the first move that led to a state.
 */
struct State {
    using score_t = double;

private:
    // Problem-specific state goes here.

public:
    score_t score = 0;     // evaluation of this state; larger is better
    int root_action = -1;  // first action taken from the search root; -1 = unset

    State() {}

    // Whether the search can stop at this state.
    // Placeholder returns false: falling off the end of a value-returning
    // function is undefined behavior, so every stub must return a value.
    bool is_finished() const { return false; }

    // Recompute `score` for the current state. Placeholder: constant 0.
    void eval_score() { this->score = 0; }

    // Apply `action` to this state in place. Placeholder: no-op.
    void advance(int action) {}

    // Actions available from this state. Placeholder: none.
    std::vector<int> next_actions() const { return {}; }

    friend bool operator<(const State& lhs, const State& rhs) {
        return lhs.score < rhs.score;
    }
};
Depends on
Code
#pragma once
#include "heuristic/time-keeper.hpp"
/**
 * Time-bounded beam search.
 *
 * At each depth, up to `beam_width` highest-scoring states (per State's
 * operator<; std::priority_queue is a max-heap) are expanded with every
 * action from next_actions(). The children form the next beam.
 *
 * @param state          start state (copied, never modified).
 * @param beam_width     maximum number of states expanded per depth.
 * @param time_threshold time budget handed to TimeKeeper.
 * @return root_action of the best state of the last fully built depth. If no
 *         depth was completed (time ran out immediately, or the start state has
 *         no actions), this is the root_action of a default-constructed State.
 */
template <class State>
int BeamSearch(const State& state, const int beam_width, int time_threshold) {
    auto time_keeper = TimeKeeper(time_threshold);
    std::priority_queue<State> now_beam;
    State best_state;
    now_beam.push(state);
    for (int t = 0;; t++) {
        std::priority_queue<State> next_beam;
        for (int i = 0; i < beam_width; i++) {
            if (time_keeper.is_time_over()) return best_state.root_action;
            if (now_beam.empty()) break;
            State now_state = now_beam.top();
            now_beam.pop();
            for (const auto& action : now_state.next_actions()) {
                State next_state = now_state;
                next_state.advance(action);
                next_state.eval_score();
                // Children of the root record their own action; deeper
                // descendants inherit it through the copy above.
                if (t == 0) next_state.root_action = action;
                next_beam.push(next_state);
            }
        }
        // Nothing could be expanded (dead ends everywhere, or beam_width <= 0).
        // Calling top() on an empty priority_queue is undefined, so stop here
        // with the best state found so far.
        if (next_beam.empty()) break;
        now_beam.swap(next_beam);  // O(1), avoids copying the whole heap
        best_state = now_beam.top();
        if (best_state.is_finished()) break;
    }
    return best_state.root_action;
}
/**
* @brief Beam Search
* @docs docs/heuristic/beam-search.md
*/
#line 2 "heuristic/time-keeper.hpp"
/**
 * Elapsed-time budget tracker: records its construction time and reports when
 * at least `time_threshold` milliseconds have passed since then.
 */
class TimeKeeper {
private:
    // steady_clock is monotonic. high_resolution_clock may be an alias of
    // system_clock, which can jump (NTP sync, manual changes) and corrupt
    // elapsed-time measurements.
    std::chrono::steady_clock::time_point start_time_;
    int time_threshold_;  // budget in milliseconds

public:
    /// @param time_threshold time budget in milliseconds, starting now.
    TimeKeeper(int time_threshold)
        : start_time_(std::chrono::steady_clock::now()), time_threshold_(time_threshold) {}

    /// @return true once the elapsed time, truncated to whole milliseconds,
    ///         is >= time_threshold (immediately true for thresholds <= 0).
    bool is_time_over() const {
        const auto elapsed = std::chrono::steady_clock::now() - this->start_time_;
        return std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count() >= time_threshold_;
    }
};
#line 3 "heuristic/beam-search.hpp"
/**
 * Time-bounded beam search.
 *
 * At each depth, up to `beam_width` highest-scoring states (per State's
 * operator<; std::priority_queue is a max-heap) are expanded with every
 * action from next_actions(). The children form the next beam.
 *
 * @param state          start state (copied, never modified).
 * @param beam_width     maximum number of states expanded per depth.
 * @param time_threshold time budget handed to TimeKeeper.
 * @return root_action of the best state of the last fully built depth. If no
 *         depth was completed (time ran out immediately, or the start state has
 *         no actions), this is the root_action of a default-constructed State.
 */
template <class State>
int BeamSearch(const State& state, const int beam_width, int time_threshold) {
    auto time_keeper = TimeKeeper(time_threshold);
    std::priority_queue<State> now_beam;
    State best_state;
    now_beam.push(state);
    for (int t = 0;; t++) {
        std::priority_queue<State> next_beam;
        for (int i = 0; i < beam_width; i++) {
            if (time_keeper.is_time_over()) return best_state.root_action;
            if (now_beam.empty()) break;
            State now_state = now_beam.top();
            now_beam.pop();
            for (const auto& action : now_state.next_actions()) {
                State next_state = now_state;
                next_state.advance(action);
                next_state.eval_score();
                // Children of the root record their own action; deeper
                // descendants inherit it through the copy above.
                if (t == 0) next_state.root_action = action;
                next_beam.push(next_state);
            }
        }
        // Nothing could be expanded (dead ends everywhere, or beam_width <= 0).
        // Calling top() on an empty priority_queue is undefined, so stop here
        // with the best state found so far.
        if (next_beam.empty()) break;
        now_beam.swap(next_beam);  // O(1), avoids copying the whole heap
        best_state = now_beam.top();
        if (best_state.is_finished()) break;
    }
    return best_state.root_action;
}
/**
* @brief Beam Search
* @docs docs/heuristic/beam-search.md
*/
Back to top page