diff --git a/neat/db.py b/neat/db.py index 170624f..56df6b5 100644 --- a/neat/db.py +++ b/neat/db.py @@ -189,6 +189,20 @@ class Database(object): hosts_ram[hostname] = int(x[4]) return hosts_cpu_mhz, hosts_cpu_cores, hosts_ram + @contract + def select_host_id(self, hostname): + """ Select the ID of a host. + + :param hostname: A host name. + :type hostname: str + + :return: The ID of the host. + :rtype: int + """ + sel = select([self.hosts.c.id]). \ + where(self.hosts.c.hostname == hostname) + return int(self.connection.execute(sel).fetchone()['id']) + @contract def select_host_ids(self): """ Select the IDs of all the hosts. @@ -269,3 +283,17 @@ class Database(object): return [host for host, state in self.select_host_states().items() if state == 0] + + @contract + def insert_host_overload(self, hostname, overload): + """ Insert whether a host is overloaded. + + :param hostname: A host name. + :type hostname: str + + :param overload: Whether the host is overloaded. + :type overload: bool + """ + self.host_overload.insert().execute( + host_id=self.select_host_id(hostname), + overload=int(overload)) diff --git a/tests/test_db.py b/tests/test_db.py index e00206c..fe84d30 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -142,6 +142,22 @@ class Db(TestCase): {'host1': 4, 'host2': 8}, {'host1': 4000, 'host2': 8000}) + @qc(1) + def select_host_id(): + db = db_utils.init_db('sqlite:///:memory:') + host1_id = db.hosts.insert().execute( + hostname='host1', + cpu_mhz=1, + cpu_cores=1, + ram=1).inserted_primary_key[0] + host2_id = db.hosts.insert().execute( + hostname='host2', + cpu_mhz=1, + cpu_cores=1, + ram=1).inserted_primary_key[0] + assert db.select_host_id('host1') == host1_id + assert db.select_host_id('host2') == host2_id + @qc(1) def select_host_ids(): db = db_utils.init_db('sqlite:///:memory:') @@ -184,7 +200,7 @@ class Db(TestCase): host2 = [x[3] for x in sorted(filter( lambda x: x[1] == hosts['host2'], result), key=lambda x: x[0])] - self.assertEqual(host1, [0, 0, 1]) + self.assertEqual(host2, [1, 0, 1]) @qc(10) def select_host_states( @@ -244,3 +260,22 @@ class Db(TestCase): if data and data[-1] == 0: res.append(host) assert db.select_inactive_hosts() == res + + def test_insert_host_overload(self): + db = db_utils.init_db('sqlite:///:memory:') + hosts = {} + hosts['host1'] = db.update_host('host1', 1, 1, 1) + hosts['host2'] = db.update_host('host2', 1, 1, 1) + db.insert_host_overload('host2', False) + db.insert_host_overload('host1', True) + db.insert_host_overload('host1', False) + db.insert_host_overload('host2', True) + result = db.host_overload.select().execute().fetchall() + host1 = [x[3] for x in sorted(filter( + lambda x: x[1] == hosts['host1'], + result), key=lambda x: x[0])] + self.assertEqual(host1, [1, 0]) + host2 = [x[3] for x in sorted(filter( + lambda x: x[1] == hosts['host2'], + result), key=lambda x: x[0])] + self.assertEqual(host2, [0, 1])