Found, fixed and tested simple circular connection

This commit is contained in:
Przemyslaw Kaminski 2015-04-24 12:39:05 +02:00
parent 87068e67ae
commit f9686f7c98
2 changed files with 31 additions and 5 deletions

View File

@ -19,9 +19,14 @@ class Connections(object):
if src not in emitter.args: if src not in emitter.args:
return return
# TODO: implement general circular detection, this one is simple
if [emitter.name, src] in CLIENTS.get(receiver.name, {}).get(dst, []):
raise Exception('Attempted to create cycle in dependencies. Not nice.')
CLIENTS.setdefault(emitter.name, {}) CLIENTS.setdefault(emitter.name, {})
CLIENTS[emitter.name].setdefault(src, []) CLIENTS[emitter.name].setdefault(src, [])
CLIENTS[emitter.name][src].append((receiver.name, dst)) if [receiver.name, dst] not in CLIENTS[emitter.name][src]:
CLIENTS[emitter.name][src].append([receiver.name, dst])
utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS) utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS)
@ -29,7 +34,7 @@ class Connections(object):
def remove(emitter, src, receiver, dst): def remove(emitter, src, receiver, dst):
CLIENTS[emitter.name][src] = [ CLIENTS[emitter.name][src] = [
destination for destination in CLIENTS[emitter.name][src] destination for destination in CLIENTS[emitter.name][src]
if destination != (receiver.name, dst) if destination != [receiver.name, dst]
] ]
utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS) utils.save_to_config_file(CLIENTS_CONFIG_KEY, CLIENTS)
@ -45,8 +50,8 @@ class Connections(object):
for emitter_input, destinations in dest_dict.items(): for emitter_input, destinations in dest_dict.items():
for receiver_name, receiver_input in destinations: for receiver_name, receiver_input in destinations:
receiver = db.get_resource(receiver_name) receiver = db.get_resource(receiver_name)
receiver.args[receiver_input].subscribe( emitter.args[emitter_input].subscribe(
emitter.args[emitter_input]) receiver.args[receiver_input])
@staticmethod @staticmethod
def clear(): def clear():
@ -151,7 +156,7 @@ def assign_connections(receiver, connections):
mappings = defaultdict(list) mappings = defaultdict(list)
for key, dest in connections.iteritems(): for key, dest in connections.iteritems():
resource, r_key = dest.split('.') resource, r_key = dest.split('.')
mappings[resource].append((r_key, key)) mappings[resource].append([r_key, key])
for resource, r_mappings in mappings.iteritems(): for resource, r_mappings in mappings.iteritems():
connect(resource, receiver, r_mappings) connect(resource, receiver, r_mappings)

View File

@ -120,6 +120,27 @@ input:
sample1.update({'ip': '10.0.0.3'}) sample1.update({'ip': '10.0.0.3'})
self.assertEqual(sample2.args['ip'], sample.args['ip']) self.assertEqual(sample2.args['ip'], sample.args['ip'])
def test_circular_connection_prevention(self):
# TODO: more complex cases
sample_meta_dir = self.make_resource_meta("""
id: sample
handler: ansible
version: 1.0.0
input:
ip:
""")
sample1 = self.create_resource(
'sample1', sample_meta_dir, {'ip': '10.0.0.1'}
)
sample2 = self.create_resource(
'sample2', sample_meta_dir, {'ip': '10.0.0.2'}
)
xs.connect(sample1, sample2)
with self.assertRaises(Exception):
xs.connect(sample2, sample1)
class TestListInput(base.BaseResourceTest): class TestListInput(base.BaseResourceTest):
def test_list_input_single(self): def test_list_input_single(self):