melvin_ob/scheduling/score_grid.rs
1use std::fmt::Debug;
2
/// A 2D grid of `i32` scores indexed by `(energy, state)`, stored as a flat array.
///
/// Cell `(e, s)` lives at flat index `e * s_len + s` (row-major: one row per energy
/// level, one column per state). The constructors allocate exactly `e_len * s_len`
/// cells, so `score.len() == e_len * s_len` holds for every grid they build.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScoreGrid {
    /// The length of the energy dimension (number of rows).
    e_len: usize,
    /// The length of the state dimension (number of columns).
    s_len: usize,
    /// Row-major flattened scores, `e_len * s_len` entries long.
    score: Box<[i32]>,
}
13
14impl ScoreGrid {
15 /// The minimum score used to initialize unwanted final states
16 pub const MIN_SCORE: i32 = i32::MIN + 2;
17 /// Creates a new [`ScoreGrid`] with specified dimensions, initializing all values to `0`.
18 ///
19 /// # Arguments
20 /// * `e_len` - The length of the energy dimension (number of rows).
21 /// * `s_len` - The length of the state dimension (number of columns).
22 ///
23 /// # Returns
24 /// A [`ScoreGrid`] instance with all scores initialized to `0`.
25 pub fn new(e_len: usize, s_len: usize) -> Self {
26 Self { e_len, s_len, score: vec![0i32; e_len * s_len].into_boxed_slice() }
27 }
28
29 /// Creates a [`ScoreGrid`] and initializes scores based on the specified condition.
30 ///
31 /// # Arguments
32 /// * `e_len` - The length of the energy dimension (number of rows).
33 /// * `s_len` - The length of the state dimension (number of columns).
34 /// * `(end_s, end_min_e)` - A tuple indicating the start column (`end_s`)
35 /// and the minimum row (`end_min_e`) from which to initialize the scores.
36 ///
37 /// # Returns
38 /// A [`ScoreGrid`] with scores set to `i32::MIN` by default, and rows
39 /// starting at `(end_min_e, end_s)` initialized to `0`.
40 pub fn new_from_condition(
41 e_len: usize,
42 s_len: usize,
43 (end_s, end_min_e): (Option<usize>, usize),
44 ) -> Self {
45 let (end_st, step) = if let Some(s) = end_s { (s, s_len) } else { (0, 1) };
46 let mut min_score = vec![Self::MIN_SCORE; e_len * s_len].into_boxed_slice();
47 if end_st < s_len && end_min_e < e_len {
48 let start_ind = end_min_e * s_len + end_st;
49 let end_ind = s_len * e_len;
50 for i in (start_ind..end_ind).step_by(step) {
51 min_score[i] = 0;
52 }
53 }
54
55 Self { e_len, s_len, score: min_score }
56 }
57
58 /// Retrieves the score at a specific position in the grid.
59 ///
60 /// # Arguments
61 /// * `e` - The index along the energy dimension (row).
62 /// * `s` - The index along the state dimension (column).
63 ///
64 /// # Returns
65 /// The score at the specified position.
66 pub fn get(&self, e: usize, s: usize) -> i32 { self.score[e * self.s_len + s] }
67
68 /// Retrieves the state with the maximum score at a specific energy level.
69 ///
70 /// # Arguments
71 /// * `e` - The index along the energy dimension (row).
72 ///
73 /// # Returns
74 /// The score at the specified position.
75 pub fn get_max_s(&self, e: usize) -> usize {
76 (0..self.s_len).max_by_key(|&i| self.score[e * self.s_len + i]).unwrap()
77 }
78
79 /// Sets the score at a specific position in the grid.
80 ///
81 /// # Arguments
82 /// * `e` - The index along the energy dimension (row).
83 /// * `s` - The index along the state dimension (column).
84 /// * `score` - The value to set at the specified position.
85 pub fn set(&mut self, e: usize, s: usize, score: i32) {
86 self.score[e * self.s_len + s] = score;
87 }
88
89 /// Returns the length of the energy dimension (number of rows).
90 ///
91 /// # Returns
92 /// The length of the energy dimension (`e_len`).
93 pub fn e_len(&self) -> usize { self.e_len }
94
95 /// Returns the length of the state dimension (number of columns).
96 ///
97 /// # Returns
98 /// The length of the state dimension (`s_len`).
99 pub fn s_len(&self) -> usize { self.s_len }
100}