Skip to content

Commit

Permalink
API change: Rework the observer interface
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre-dejoue committed Aug 12, 2023
1 parent 66b4eb0 commit 9cc042a
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 82 deletions.
10 changes: 5 additions & 5 deletions src/gui/src/grid_window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
#include <utility>


GridWindow::LineEvent::LineEvent(picross::ObserverEvent event, const picross::Line* line, unsigned int misc, const ObserverGrid& grid)
GridWindow::LineEvent::LineEvent(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& grid)
: m_event(event)
, m_line_id()
, m_misc(misc)
, m_data(data)
, m_grid(grid)
{
if (line) { m_line_id = std::make_optional<picross::LineId>(*line); }
Expand Down Expand Up @@ -310,7 +310,7 @@ void GridWindow::reset_solutions()
text_buffer->buffer.appendf("Grid %s\n", picross::str_input_grid_size(grid).c_str());
}

void GridWindow::observer_callback(picross::ObserverEvent event, const picross::Line* line, unsigned int, unsigned int misc, const ObserverGrid& l_grid)
void GridWindow::observer_callback(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& l_grid)
{
// Filter out events useless to the GUI
if (event != picross::ObserverEvent::DELTA_LINE && event != picross::ObserverEvent::SOLVED_GRID && event != picross::ObserverEvent::PROGRESS)
Expand All @@ -329,7 +329,7 @@ void GridWindow::observer_callback(picross::ObserverEvent event, const picross::
|| this->abort_solver_thread();
});
}
line_events.emplace_back(event, line, misc, l_grid);
line_events.emplace_back(event, line, data, l_grid);
}

unsigned int GridWindow::process_line_events(std::vector<LineEvent>& events)
Expand All @@ -343,7 +343,7 @@ unsigned int GridWindow::process_line_events(std::vector<LineEvent>& events)
if (event.m_event == picross::ObserverEvent::PROGRESS)
{
// Only indicative
solver_progress = reinterpret_cast<const float&>(static_cast<const std::uint32_t&>(event.m_misc));
solver_progress = event.m_data.m_misc_f;
continue;
}
assert(event.m_event == picross::ObserverEvent::DELTA_LINE || event.m_event == picross::ObserverEvent::SOLVED_GRID);
Expand Down
6 changes: 3 additions & 3 deletions src/gui/src/grid_window.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ class GridWindow : public GridObserver
public:
struct LineEvent
{
LineEvent(picross::ObserverEvent event, const picross::Line* line, unsigned int misc, const ObserverGrid& grid);
LineEvent(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& grid);

picross::ObserverEvent m_event;
std::optional<picross::LineId> m_line_id;
unsigned int m_misc;
picross::ObserverData m_data;
ObserverGrid m_grid;
};
public:
Expand All @@ -42,7 +42,7 @@ class GridWindow : public GridObserver

private:
void reset_solutions();
void observer_callback(picross::ObserverEvent event, const picross::Line* line, unsigned int, unsigned int, const ObserverGrid& grid) override;
void observer_callback(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& grid) override;
unsigned int process_line_events(std::vector<LineEvent>& events);
void solve_picross_grid();
void save_grid();
Expand Down
32 changes: 19 additions & 13 deletions src/picross/include/picross/picross_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
******************************************************************************/
#pragma once

#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
Expand All @@ -20,19 +21,19 @@ namespace picross
//
// The observer is a function object with the following signature:
//
// void observer(Event event, const Line* line, unsigned int depth, unsigned int misc);
// void observer(Event event, const Line* line, { uint32_t depth, uint32_t misc_i, float misc_f });
//
// * event = KNOWN_LINE line = known_line depth is set misc = nb_alternatives (before)
// * event = DELTA_LINE line = delta depth is set misc = nb_alternatives (after)
// * event = KNOWN_LINE line = known_line depth is set misc_i = nb_alternatives (before)
// * event = DELTA_LINE line = delta depth is set misc_i = nb_alternatives (after)
//
// A line of the output grid has been updated, the delta between the previous value of that line
// and the new one is given in event DELTA_LINE.
// The depth is set to zero initially, then it is the same value as that of the last BRANCHING event.
// The number of alternatives before (resp. after) the line solver are given in the event KNOWN_LINE
// (resp. DELTA_LINE).
//
// * event = BRANCHING line = known_line depth >= 0 misc = nb_alternatives (NODE)
// * event = BRANCHING line = nullptr depth > 0 misc = 0 (EDGE)
// * event = BRANCHING line = known_line depth >= 0 misc_i = nb_alternatives (NODE)
// * event = BRANCHING line = nullptr depth > 0 misc_i = 0 (EDGE)
//
// This event occurs when the algorithm is branching between several alternative solutions, or
// when it is going back to an earlier branch. Upon starting a new branch the depth is increased
Expand All @@ -43,31 +44,36 @@ namespace picross
// branching line and the number of alternatives, and the EDGE event each time an alternative
// of the branching line is being tested.
//
// * event = SOLVED_GRID line = nullptr depth is set misc = 0
// * event = SOLVED_GRID line = nullptr depth is set misc_i = 0
//
// A solution grid has been found. The sum of all the delta lines up until that stage is the solved grid.
//
// * event = INTERNAL_STATE line = nullptr depth is set misc = state
// * event = INTERNAL_STATE line = nullptr depth is set misc_i = internal state
//
// Internal state of the solver given as an integer
//
// * event = PROGRESS line = nullptr deptb is set misc = reinterpret_cast<uint32_t&>(progress);
// * event = PROGRESS line = nullptr depth is set misc_f = progress_ratio
//
// Progress is a float between 0.f and 1.f, whose binary representation is copied in the 'misc' integer.
//
enum class ObserverEvent {
enum class ObserverEvent
{
KNOWN_LINE,
DELTA_LINE,
BRANCHING,
SOLVED_GRID,
INTERNAL_STATE,
PROGRESS
};
struct ObserverData
{
std::uint32_t m_depth = 0u;
std::uint32_t m_misc_i = 0u;
float m_misc_f = 0.f;
};
class Line;
using Observer = std::function<void(ObserverEvent,const Line*,unsigned int,unsigned int)>;
using Observer = std::function<void(ObserverEvent,const Line*,const ObserverData&)>;

std::ostream& operator<<(std::ostream& out, ObserverEvent event);

std::string str_solver_internal_state(unsigned int internal_state);
std::string str_solver_internal_state(std::uint32_t internal_state);

} // namespace picross
2 changes: 1 addition & 1 deletion src/picross/src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ std::ostream& operator<<(std::ostream& out, ObserverEvent event)
return out;
}

std::string str_solver_internal_state(unsigned int internal_state)
std::string str_solver_internal_state(std::uint32_t internal_state)
{
std::stringstream ss;
ss << static_cast<WorkGridState>(internal_state);
Expand Down
81 changes: 57 additions & 24 deletions src/picross/src/work_grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,13 @@ void update_line_range(LineRange& range, const std::vector<bool>& source)
assert(range.m_begin <= range.m_end);
}

template <typename RET_UINT>
RET_UINT progress_bar(const std::pair<float, float>& progress_bar, LineAlternatives::NbAlt progress, LineAlternatives::NbAlt nb_alternatives)
float progress_bar(const std::pair<float, float>& progress_bar, LineAlternatives::NbAlt progress, LineAlternatives::NbAlt nb_alternatives)
{
static_assert(std::is_integral_v<RET_UINT> && std::is_unsigned_v<RET_UINT>);
static_assert(std::numeric_limits<RET_UINT>::digits >= std::numeric_limits<std::uint32_t>::digits);
assert(progress <= nb_alternatives);
const float ratio_f = static_cast<float>(progress) / static_cast<float>(nb_alternatives);
assert(0.f <= ratio_f && ratio_f <= 1.f);
const auto progress_f = std::make_unique<float>(progress_bar.first + (progress_bar.second - progress_bar.first) * ratio_f);
// reinterpret as integer to go through the observer interface
const std::uint32_t ratio = *reinterpret_cast<const std::uint32_t*>(progress_f.get());
return static_cast<RET_UINT>(ratio);
const float ratio = static_cast<float>(progress) / static_cast<float>(nb_alternatives);
assert(0.f <= ratio && ratio <= 1.f);
const float progress_f = progress_bar.first + (progress_bar.second - progress_bar.first) * ratio;
return progress_f;
}

std::pair<float, float> nested_progress_bar(const std::pair<float, float>& progress_bar, LineAlternatives::NbAlt progress, LineAlternatives::NbAlt nb_alternatives)
Expand Down Expand Up @@ -323,7 +318,10 @@ Solver::Status WorkGrid<SolverPolicy>::line_solve(const Solver::SolutionFound& s
{
if (m_observer)
{
m_observer(ObserverEvent::INTERNAL_STATE, nullptr, m_branching_depth, static_cast<unsigned int>(m_state));
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(m_state);
m_observer(ObserverEvent::INTERNAL_STATE, nullptr, data);
}

switch (m_state)
Expand Down Expand Up @@ -430,7 +428,10 @@ Solver::Status WorkGrid<SolverPolicy>::solve(const Solver::SolutionFound& soluti
assert(m_solver_policy.m_branching_allowed);
if (m_observer)
{
m_observer(ObserverEvent::INTERNAL_STATE, nullptr, m_branching_depth, static_cast<unsigned int>(m_state));
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(m_state);
m_observer(ObserverEvent::INTERNAL_STATE, nullptr, data);
}

// Make a guess (branch search)
Expand Down Expand Up @@ -513,15 +514,21 @@ bool WorkGrid<SolverPolicy>::update_line(const LineSpan& line, unsigned int nb_a

if (m_observer)
{
ObserverData data;
data.m_depth = m_branching_depth;

if (line_changed)
{
m_observer(ObserverEvent::KNOWN_LINE, &observer_original_line, m_branching_depth, observer_original_nb_alt);
data.m_misc_i = static_cast<std::uint32_t>(observer_original_nb_alt);
m_observer(ObserverEvent::KNOWN_LINE, &observer_original_line, data);
const Line delta = get_line(line_type, line_index) - observer_original_line;
m_observer(ObserverEvent::DELTA_LINE, &delta, m_branching_depth, nb_alt);
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
m_observer(ObserverEvent::DELTA_LINE, &delta, data);
}
else
{
m_observer(ObserverEvent::KNOWN_LINE, &observer_original_line, m_branching_depth, nb_alt);
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
m_observer(ObserverEvent::KNOWN_LINE, &observer_original_line, data);
}
}

Expand Down Expand Up @@ -764,8 +771,10 @@ typename WorkGrid<SolverPolicy>::PassStatus WorkGrid<SolverPolicy>::full_grid_pa
{
if (m_observer)
{
ObserverData data;
data.m_depth = m_branching_depth;
const Line contradictory_line = line_from_line_span(get_line(*it));
m_observer(ObserverEvent::KNOWN_LINE, &contradictory_line, m_branching_depth, 0);
m_observer(ObserverEvent::KNOWN_LINE, &contradictory_line, data);
}
break;
}
Expand Down Expand Up @@ -867,7 +876,10 @@ typename WorkGrid<SolverPolicy>::ProbingResult WorkGrid<SolverPolicy>::probe(Lin
if (m_observer)
{
const auto line_known_tiles = line_from_line_span(known_tiles);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, m_branching_depth, nb_alt);
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, data);
}

if (m_grid_stats != nullptr)
Expand All @@ -891,7 +903,9 @@ typename WorkGrid<SolverPolicy>::ProbingResult WorkGrid<SolverPolicy>::probe(Lin
probing_work_grid.configure(nested_solver_policy, WorkGridState::LINEAR_REDUCTION, nested_stats.get(), nested_progress.first, nested_progress.second);
if (m_observer)
{
m_observer(ObserverEvent::BRANCHING, nullptr, probing_work_grid.m_branching_depth, 0);
ObserverData data;
data.m_depth = probing_work_grid.m_branching_depth;
m_observer(ObserverEvent::BRANCHING, nullptr, data);
}

// Set one line in the new_grid according to the hypothesis we made. That line is then complete
Expand Down Expand Up @@ -940,7 +954,10 @@ typename WorkGrid<SolverPolicy>::ProbingResult WorkGrid<SolverPolicy>::probe(Lin
if (m_observer)
{
const auto line_known_tiles = line_from_line_span(known_tiles);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, m_branching_depth, nb_alt);
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, data);
}

if (!reduced_grid)
Expand Down Expand Up @@ -990,8 +1007,11 @@ Solver::Status WorkGrid<SolverPolicy>::branch(const Solver::SolutionFound& solut

if (m_observer)
{
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
const auto line_known_tiles = line_from_line_span(known_tiles);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, m_branching_depth, nb_alt);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, data);
}

if (m_grid_stats != nullptr)
Expand Down Expand Up @@ -1019,7 +1039,9 @@ Solver::Status WorkGrid<SolverPolicy>::branch(const Solver::SolutionFound& solut
branching_work_grid.configure(nested_solver_policy, WorkGridState::LINEAR_REDUCTION, nested_stats.get(), nested_progress.first, nested_progress.second);
if (m_observer)
{
m_observer(ObserverEvent::BRANCHING, nullptr, branching_work_grid.m_branching_depth, 0);
ObserverData data;
data.m_depth = branching_work_grid.m_branching_depth;
m_observer(ObserverEvent::BRANCHING, nullptr, data);
}

// Set one line in the new_grid according to the hypothesis we made. That line is then complete
Expand Down Expand Up @@ -1050,7 +1072,10 @@ Solver::Status WorkGrid<SolverPolicy>::branch(const Solver::SolutionFound& solut
progress++;
if (m_observer)
{
m_observer(ObserverEvent::PROGRESS, nullptr, branching_work_grid.m_branching_depth, progress_bar<unsigned int>(m_progress_bar, progress, nb_alt));
ObserverData data;
data.m_depth = branching_work_grid.m_branching_depth;
data.m_misc_f = progress_bar(m_progress_bar, progress, nb_alt);
m_observer(ObserverEvent::PROGRESS, nullptr, data);
}

if (status == Solver::Status::ABORTED)
Expand All @@ -1063,8 +1088,11 @@ Solver::Status WorkGrid<SolverPolicy>::branch(const Solver::SolutionFound& solut
// Repeat start branching message
if (m_observer)
{
ObserverData data;
data.m_depth = m_branching_depth;
data.m_misc_i = static_cast<std::uint32_t>(nb_alt);
const auto line_known_tiles = line_from_line_span(known_tiles);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, m_branching_depth, nb_alt);
m_observer(ObserverEvent::BRANCHING, &line_known_tiles, data);
}

return flag_solution_found ? Solver::Status::OK : Solver::Status::CONTRADICTORY_GRID;
Expand All @@ -1088,7 +1116,12 @@ bool WorkGrid<SolverPolicy>::found_solution(const Solver::SolutionFound& solutio
assert(is_valid_solution());
const auto adjusted_branching_depth = m_branching_depth + m_probing_depth_incr;
if (m_grid_stats != nullptr) { m_grid_stats->nb_solutions++; }
if (m_observer) { m_observer(ObserverEvent::SOLVED_GRID, nullptr, adjusted_branching_depth, 0); }
if (m_observer)
{
ObserverData data;
data.m_depth = adjusted_branching_depth;
m_observer(ObserverEvent::SOLVED_GRID, nullptr, data);
}

// Shallow copy of only the grid data
return solution_found(Solver::Solution{ OutputGrid(*this), adjusted_branching_depth, FULL_SOLUTION });
Expand Down
2 changes: 1 addition & 1 deletion src/utils/include/utils/console_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ConsoleObserver final : public GridObserver
void verify_against_goal(const picross::OutputGrid& goal);

private:
void observer_callback(picross::ObserverEvent event, const picross::Line* line, unsigned int depth, unsigned int misc, const ObserverGrid& grid) override;
void observer_callback(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& grid) override;

private:
std::ostream& m_ostream;
Expand Down
2 changes: 1 addition & 1 deletion src/utils/include/utils/console_progress_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ConsoleProgressObserver final
public:
explicit ConsoleProgressObserver(std::ostream& ostream);

void operator()(picross::ObserverEvent event, const picross::Line* line, unsigned int depth, unsigned int misc);
void operator()(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data);

private:
std::ostream& m_ostream;
Expand Down
4 changes: 2 additions & 2 deletions src/utils/include/utils/grid_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ class GridObserver
explicit GridObserver(const picross::InputGrid& grid);
virtual ~GridObserver() = default;

void operator()(picross::ObserverEvent event, const picross::Line* line, unsigned int depth, unsigned int misc);
void operator()(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data);

protected:
void observer_clear();

private:
virtual void observer_callback(picross::ObserverEvent event, const picross::Line* line, unsigned int depth, unsigned int misc, const ObserverGrid& grid) = 0;
virtual void observer_callback(picross::ObserverEvent event, const picross::Line* line, const picross::ObserverData& data, const ObserverGrid& grid) = 0;

private:
std::vector<ObserverGrid> m_grids;
Expand Down
Loading

0 comments on commit 9cc042a

Please sign in to comment.