Skip to the content.

:warning: Beam Search
(heuristic/beam-search.hpp)

テンプレート

// Skeleton of the state type consumed by BeamSearch. Copy it and replace the
// placeholder bodies with problem-specific logic. BeamSearch copies states,
// calls next_actions(), advance(), eval_score() and is_finished() on them,
// reads and writes root_action, and orders states with operator<.
struct State {
  using score_t = double;

 private:
  // Problem-specific data goes here.

 public:
  score_t score = 0;     // Evaluation value; operator< compares states by it.
  int root_action = -1;  // Stays -1 until a caller assigns an action to it.
  State() {}

  // Placeholder: reports "not finished". Returns a value explicitly because
  // flowing off the end of a non-void function is undefined behaviour.
  bool is_finished() const { return false; }

  // Placeholder: resets the score to 0.
  void eval_score() { this->score = 0; }

  // Placeholder: applying an action currently changes nothing.
  void advance(int action) {}

  // Placeholder: reports no available actions (explicit return, see above).
  std::vector<int> next_actions() const { return {}; }

  // Orders states by score, so a std::priority_queue<State> yields the
  // highest-scoring state first.
  friend bool operator<(const State& lhs, const State& rhs) {
    return lhs.score < rhs.score;
  }
};

Depends on: heuristic/time-keeper.hpp

Code

#pragma once
#include "heuristic/time-keeper.hpp"

// Time-limited beam search. Starting from `state`, each depth expands up to
// `beam_width` of the highest-ranked states (ranking comes from State's
// operator< via std::priority_queue) and pushes every successor into the next
// beam. Successors created at depth 0 get `root_action` set to the action that
// produced them; deeper states are copies of their parent and are presumably
// expected to carry that value along (depends on State's copy semantics).
//
// state          : initial state; it is copied, never modified.
// beam_width     : maximum number of states expanded per depth.
// time_threshold : budget handed to TimeKeeper (milliseconds in the bundled
//                  TimeKeeper implementation).
//
// Returns the `root_action` of the best state of the last fully processed
// depth. If no depth was completed (time ran out immediately, the initial
// state has no successors, or beam_width <= 0) this is the `root_action` of a
// default-constructed State.
template <class State>
int BeamSearch(const State& state, const int beam_width, int time_threshold) {
  auto time_keeper = TimeKeeper(time_threshold);
  std::priority_queue<State> now_beam;
  State best_state;
  now_beam.push(state);
  for (int t = 0;; t++) {
    std::priority_queue<State> next_beam;
    for (int i = 0; i < beam_width; i++) {
      // A depth interrupted by the deadline is discarded: the answer comes
      // from the last depth that was expanded completely.
      if (time_keeper.is_time_over()) return best_state.root_action;
      if (now_beam.empty()) break;
      State now_state = now_beam.top();
      now_beam.pop();
      const auto actions = now_state.next_actions();
      for (const auto& action : actions) {
        State next_state = now_state;
        next_state.advance(action);
        next_state.eval_score();
        if (t == 0) next_state.root_action = action;
        next_beam.push(next_state);
      }
    }
    // No successor was generated (dead end or non-positive beam width).
    // Calling top() on an empty queue would be undefined behaviour, so keep
    // the best state found so far and stop.
    if (next_beam.empty()) break;
    // Swap instead of copy-assigning: the leftover states of the old beam are
    // not needed any more and are destroyed together with next_beam.
    now_beam.swap(next_beam);
    best_state = now_beam.top();
    if (best_state.is_finished()) break;
  }
  return best_state.root_action;
}
/**
 * @brief Beam Search
 * @docs docs/heuristic/beam-search.md
 */
#line 2 "heuristic/time-keeper.hpp"

// Tracks a time budget in milliseconds, measured from construction.
class TimeKeeper {
 private:
  // steady_clock is monotonic, so the measured interval cannot jump or go
  // backwards when the system time is adjusted; high_resolution_clock gives no
  // such guarantee because it may be an alias of system_clock.
  std::chrono::steady_clock::time_point start_time_;
  int time_threshold_;  // Budget in milliseconds.

 public:
  // Starts the timer immediately; `time_threshold` is the budget in milliseconds.
  TimeKeeper(int time_threshold)
      : start_time_(std::chrono::steady_clock::now()), time_threshold_(time_threshold) {}

  // Returns true once at least `time_threshold` whole milliseconds have
  // elapsed since construction (always true for a threshold <= 0).
  bool is_time_over() const {
    const auto elapsed = std::chrono::steady_clock::now() - this->start_time_;
    return std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count() >= time_threshold_;
  }
};
#line 3 "heuristic/beam-search.hpp"

// Time-limited beam search. Starting from `state`, each depth expands up to
// `beam_width` of the highest-ranked states (ranking comes from State's
// operator< via std::priority_queue) and pushes every successor into the next
// beam. Successors created at depth 0 get `root_action` set to the action that
// produced them; deeper states are copies of their parent and are presumably
// expected to carry that value along (depends on State's copy semantics).
//
// state          : initial state; it is copied, never modified.
// beam_width     : maximum number of states expanded per depth.
// time_threshold : budget handed to TimeKeeper, in milliseconds.
//
// Returns the `root_action` of the best state of the last fully processed
// depth. If no depth was completed (time ran out immediately, the initial
// state has no successors, or beam_width <= 0) this is the `root_action` of a
// default-constructed State.
template <class State>
int BeamSearch(const State& state, const int beam_width, int time_threshold) {
  auto time_keeper = TimeKeeper(time_threshold);
  std::priority_queue<State> now_beam;
  State best_state;
  now_beam.push(state);
  for (int t = 0;; t++) {
    std::priority_queue<State> next_beam;
    for (int i = 0; i < beam_width; i++) {
      // A depth interrupted by the deadline is discarded: the answer comes
      // from the last depth that was expanded completely.
      if (time_keeper.is_time_over()) return best_state.root_action;
      if (now_beam.empty()) break;
      State now_state = now_beam.top();
      now_beam.pop();
      const auto actions = now_state.next_actions();
      for (const auto& action : actions) {
        State next_state = now_state;
        next_state.advance(action);
        next_state.eval_score();
        if (t == 0) next_state.root_action = action;
        next_beam.push(next_state);
      }
    }
    // No successor was generated (dead end or non-positive beam width).
    // Calling top() on an empty queue would be undefined behaviour, so keep
    // the best state found so far and stop.
    if (next_beam.empty()) break;
    // Swap instead of copy-assigning: the leftover states of the old beam are
    // not needed any more and are destroyed together with next_beam.
    now_beam.swap(next_beam);
    best_state = now_beam.top();
    if (best_state.is_finished()) break;
  }
  return best_state.root_action;
}
/**
 * @brief Beam Search
 * @docs docs/heuristic/beam-search.md
 */
Back to top page