diff --git a/neat/db.py b/neat/db.py index 4823c82..718926f 100644 --- a/neat/db.py +++ b/neat/db.py @@ -220,3 +220,27 @@ class Database(object): for k, v in hosts.items()] self.connection.execute( self.host_states.insert(), to_insert) + + @contract + def select_host_states(self): + """ Select the current states of all the hosts. + + :return: A dict of host names to states. + :rtype: dict(str: int) + """ + hs1 = self.host_states + hs2 = self.host_states.alias() + sel = select([hs1.c.host_id, hs1.c.state], from_obj=[ + hs1.outerjoin(hs2, and_( + hs1.c.host_id == hs2.c.host_id, + hs1.c.id < hs2.c.id))]). \ + where(hs2.c.id == None) + data = dict(self.connection.execute(sel).fetchall()) + host_ids = self.select_host_ids() + host_states = {} + for host, id in host_ids.items(): + if id in data: + host_states[str(host)] = int(data[id]) + else: + host_states[str(host)] = 1 + return host_states diff --git a/tests/test_db.py b/tests/test_db.py index ad2dc1b..3574e8d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -185,3 +185,22 @@ class Db(TestCase): lambda x: x[1] == hosts['host2'], result), key=lambda x: x[0])] self.assertEqual(host1, [0, 0, 1]) + + @qc(1) + def select_host_states( + hosts=dict_( + keys=str_(of='abc123', min_length=1, max_length=5), + values=list_(of=int_(min=0, max=1), + min_length=0, max_length=10), + min_length=0, max_length=3 + ) + ): + db = db_utils.init_db('sqlite:///:memory:') + res = {} + for host, data in hosts.items(): + db.update_host(host, 1, 1, 1) + for state in data: + db.insert_host_states({host: state}) + if data: + res[host] = data[-1] + assert db.select_host_states() == res