diff --git a/neat/locals/overload/mhod/core.py b/neat/locals/overload/mhod/core.py index 58f02cb..be0e2c7 100644 --- a/neat/locals/overload/mhod/core.py +++ b/neat/locals/overload/mhod/core.py @@ -118,7 +118,7 @@ def mhod(state_config, otf, window_sizes, bruteforce_step, learning_steps, total_time = len(utilization) max_window_size = max(window_sizes) state_vector = build_state_vector(state_config, utilization) - state = current_state(state_vector) + current_state = get_current_state(state_vector) selected_windows = estimation.select_window( state['variances'], state['acceptable_variances'], @@ -144,13 +144,13 @@ def mhod(state_config, otf, window_sizes, bruteforce_step, learning_steps, state['acceptable_variances'], state['estimate_windows'], state['previous_state']) - state['previous_state'] = state + state['previous_state'] = current_state if len(utilization) >= learning_steps: state_history = utilization_to_states(state_config, utilization) time_in_states = total_time time_in_state_n = get_time_in_state_n(state_config, state_history) - tmp = set(p[state]) + tmp = set(p[current_state]) if len(tmp) != 1 or tmp[0] != 0: policy = bruteforce.optimize( step, 1.0, otf, (migration_time / time_step), ls, @@ -199,7 +199,7 @@ def utilization_to_state(state_config, utilization): @contract -def current_state(state_vector): +def get_current_state(state_vector): """ Get the current state corresponding to the state probability vector. :param state_vector: The state PMF vector. diff --git a/tests/locals/overload/mhod/test_core.py b/tests/locals/overload/mhod/test_core.py index 4095ec7..2dfb477 100644 --- a/tests/locals/overload/mhod/test_core.py +++ b/tests/locals/overload/mhod/test_core.py @@ -82,10 +82,10 @@ class Core(TestCase): self.assertEqual(c.build_state_vector(state_config, [0.0, 1.0]), [0, 0, 1]) - def test_current_state(self): - self.assertEqual(c.current_state([1, 0, 0]), 0) - self.assertEqual(c.current_state([0, 1, 0]), 1) - self.assertEqual(c.current_state([0, 0, 1]), 2) + def test_get_current_state(self): + self.assertEqual(c.get_current_state([1, 0, 0]), 0) + self.assertEqual(c.get_current_state([0, 1, 0]), 1) + self.assertEqual(c.get_current_state([0, 0, 1]), 2) def test_utilization_to_states(self): state_config = [0.4, 0.7]