diff --git a/libra/worker/controller.py b/libra/worker/controller.py index ffdcc14c..7589e924 100644 --- a/libra/worker/controller.py +++ b/libra/worker/controller.py @@ -24,6 +24,7 @@ class LBaaSController(object): RESPONSE_SUCCESS = "PASS" ACTION_FIELD = 'hpcs_action' RESPONSE_FIELD = 'hpcs_response' + LBLIST_FIELD = 'loadbalancers' def __init__(self, logger, driver, json_msg): self.logger = logger @@ -61,7 +62,15 @@ class LBaaSController(object): return self.msg def _action_create(self): - """ Create a Load Balancer. """ + """ + Create a Load Balancer. + + This is the only method (so far) that actually parses the contents + of the JSON message (other than the ACTION_FIELD field). Modifying + the JSON message structure likely means this method will need to + be modified, unless the change involves fields that are ignored. + """ + try: self.driver.init() except NotImplementedError: @@ -71,78 +80,96 @@ class LBaaSController(object): self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE return self.msg - if 'nodes' not in self.msg: - return BadRequest("Missing 'nodes' element").to_json() + if self.LBLIST_FIELD not in self.msg: + return BadRequest( + "Missing '%s' element" % self.LBLIST_FIELD + ).to_json() - if 'protocol' in self.msg: - port = None - if 'port' in self.msg: - port = self.msg['port'] - try: - self.driver.set_protocol(self.msg['protocol'], port) - except NotImplementedError: - self.logger.error( - "Selected driver does not support setting protocol." - ) - self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE - return self.msg - except Exception as e: - self.logger.error("Failure trying to set protocol: %s, %s" % - (e.__class__, e)) - self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE - return self.msg + lb_list = self.msg['loadbalancers'] - if 'algorithm' in self.msg: - algo = self.msg['algorithm'].upper() - if algo == 'ROUND_ROBIN': - algo = LoadBalancerDriver.ROUNDROBIN - elif algo == 'LEAST_CONNECTIONS': - algo = LoadBalancerDriver.LEASTCONN + for current_lb in lb_list: + if 'nodes' not in current_lb: + return BadRequest("Missing 'nodes' element").to_json() + + if 'protocol' not in current_lb: + return BadRequest( + "Missing required 'protocol' value." + ).to_json() else: - self.logger.error("Invalid algorithm: %s" % algo) - self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE - return self.msg + port = None + if 'port' in current_lb: + port = current_lb['port'] + try: + self.driver.add_protocol(current_lb['protocol'], port) + except NotImplementedError: + self.logger.error( + "Selected driver does not support setting protocol." + ) + self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE + return self.msg + except Exception as e: + self.logger.error( + "Failure trying to set protocol: %s, %s" % + (e.__class__, e) + ) + self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE + return self.msg - try: - self.driver.set_algorithm(algo) - except NotImplementedError: - self.logger.error( - "Selected driver does not support setting algorithm." - ) - self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE - return self.msg - except Exception as e: - self.logger.error("Selected driver failed setting algorithm.") - self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE - return self.msg + if 'algorithm' in current_lb: + algo = current_lb['algorithm'].upper() + if algo == 'ROUND_ROBIN': + algo = LoadBalancerDriver.ROUNDROBIN + elif algo == 'LEAST_CONNECTIONS': + algo = LoadBalancerDriver.LEASTCONN + else: + self.logger.error("Invalid algorithm: %s" % algo) + self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE + return self.msg - for lb_node in self.msg['nodes']: - port, address = None, None + try: + self.driver.set_algorithm(current_lb['protocol'], algo) + except NotImplementedError: + self.logger.error( + "Selected driver does not support setting algorithm." + ) + self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE + return self.msg + except Exception as e: + self.logger.error( + "Selected driver failed setting algorithm." + ) + self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE + return self.msg - if 'port' in lb_node: - port = lb_node['port'] - else: - return BadRequest("Missing 'port' element.").to_json() + for lb_node in current_lb['nodes']: + port, address = None, None - if 'address' in lb_node: - address = lb_node['address'] - else: - return BadRequest("Missing 'address' element.").to_json() + if 'port' in lb_node: + port = lb_node['port'] + else: + return BadRequest("Missing 'port' element.").to_json() - try: - self.driver.add_server(address, port) - except NotImplementedError: - self.logger.error( - "Selected driver does not support adding a server." - ) - lb_node['condition'] = self.NODE_ERR - except Exception as e: - self.logger.error("Failure trying adding server: %s, %s" % - (e.__class__, e)) - lb_node['condition'] = self.NODE_ERR - else: - self.logger.debug("Added server: %s:%s" % (address, port)) - lb_node['condition'] = self.NODE_OK + if 'address' in lb_node: + address = lb_node['address'] + else: + return BadRequest("Missing 'address' element.").to_json() + + try: + self.driver.add_server(current_lb['protocol'], + address, + port) + except NotImplementedError: + self.logger.error( + "Selected driver does not support adding a server." + ) + lb_node['condition'] = self.NODE_ERR + except Exception as e: + self.logger.error("Failure trying adding server: %s, %s" % + (e.__class__, e)) + lb_node['condition'] = self.NODE_ERR + else: + self.logger.debug("Added server: %s:%s" % (address, port)) + lb_node['condition'] = self.NODE_OK try: self.driver.create() @@ -150,8 +177,9 @@ class LBaaSController(object): self.logger.error( "Selected driver does not support CREATE action." ) - for lb_node in self.msg['nodes']: - lb_node['condition'] = self.NODE_ERR + for current_lb in lb_list: + for lb_node in current_lb['nodes']: + lb_node['condition'] = self.NODE_ERR self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE except Exception as e: self.logger.error("CREATE failed: %s, %s" % (e.__class__, e)) diff --git a/libra/worker/drivers/base.py b/libra/worker/drivers/base.py index 21642b3e..9e85671b 100644 --- a/libra/worker/drivers/base.py +++ b/libra/worker/drivers/base.py @@ -28,6 +28,12 @@ class LoadBalancerDriver(object): Generally, an appliance driver should queue up any changes made via these API calls until the create() method is called. + + This design allows for a single load balancer to support multiple + protocols simultaneously. Each protocol added via the add_protocol() + method is assumed to be unique, and one protocol per port. This same + protocol is then supplied to other methods (e.g., add_server() and + set_algorithm()) to make changes for that specific protocol. """ # Load balancer algorithms @@ -38,16 +44,16 @@ class LoadBalancerDriver(object): """ Allows the driver to do any initialization for a new config. """ raise NotImplementedError() - def add_server(self, host, port): - """ Add a server for which we will proxy. """ + def add_protocol(self, protocol, port): + """ Add a supported protocol and listening port for the instance. """ raise NotImplementedError() - def set_protocol(self, protocol, port): - """ Set the protocol of the instance. """ + def add_server(self, protocol, host, port): + """ Add a server for the protocol for which we will proxy. """ raise NotImplementedError() - def set_algorithm(self, algo): - """ Set the algorithm used by the load balancer. """ + def set_algorithm(self, protocol, algo): + """ Set the algorithm used by the load balancer for this protocol. """ raise NotImplementedError() def create(self): diff --git a/libra/worker/drivers/haproxy/driver.py b/libra/worker/drivers/haproxy/driver.py index 61b195ad..ea650dff 100644 --- a/libra/worker/drivers/haproxy/driver.py +++ b/libra/worker/drivers/haproxy/driver.py @@ -29,12 +29,10 @@ class HAProxyDriver(LoadBalancerDriver): def _init_config(self): self._config = dict() - self.set_protocol('HTTP', 80) - self.set_algorithm(self.ROUNDROBIN) - def _bind(self, address, port): - self._config['bind_address'] = address - self._config['bind_port'] = port + def _bind(self, protocol, address, port): + self._config[protocol]['bind_address'] = address + self._config[protocol]['bind_port'] = port def _config_to_string(self): """ @@ -54,7 +52,6 @@ class HAProxyDriver(LoadBalancerDriver): ) output.append('defaults') output.append(' log global') - output.append(' mode %s' % self._config['mode']) output.append(' option httplog') output.append(' option dontlognull') output.append(' option redispatch') @@ -63,18 +60,24 @@ class HAProxyDriver(LoadBalancerDriver): output.append(' timeout connect 5000ms') output.append(' timeout client 50000ms') output.append(' timeout server 5000ms') - output.append(' balance %s' % self._config['algorithm']) output.append(' cookie SERVERID rewrite') - output.append('frontend http-in') - output.append(' bind %s:%s' % (self._config['bind_address'], - self._config['bind_port'])) - output.append(' default_backend servers') - output.append('backend servers') serv_num = 1 - for (addr, port) in self._config['servers']: - output.append(' server server%d %s:%s' % (serv_num, addr, port)) - serv_num += 1 + + for proto in self._config: + protocfg = self._config[proto] + output.append('frontend %s-in' % proto) + output.append(' mode %s' % proto) + output.append(' bind %s:%s' % (protocfg['bind_address'], + protocfg['bind_port'])) + output.append(' default_backend %s-servers' % proto) + output.append('backend %s-servers' % proto) + output.append(' balance %s' % protocfg['algorithm']) + + for (addr, port) in protocfg['servers']: + output.append(' server server%d %s:%s' % + (serv_num, addr, port)) + serv_num += 1 return '\n'.join(output) + '\n' @@ -85,32 +88,37 @@ class HAProxyDriver(LoadBalancerDriver): def init(self): self._init_config() - def add_server(self, host, port): - if 'servers' not in self._config: - self._config['servers'] = [] - self._config['servers'].append((host, port)) - - def set_protocol(self, protocol, port=None): + def add_protocol(self, protocol, port=None): proto = protocol.lower() if proto not in ('tcp', 'http', 'health'): - raise Exception("Invalid protocol: %s" % protocol) - self._config['mode'] = proto + raise Exception("Unsupported protocol: %s" % protocol) + if proto in self._config: + raise Exception("Protocol '%s' is already defined." % protocol) + else: + self._config[proto] = dict() if port is None: if proto == 'tcp': raise Exception('Port is required for TCP protocol.') elif proto == 'http': - self._bind('0.0.0.0', 80) + self._bind(proto, '0.0.0.0', 80) else: - self._bind('0.0.0.0', port) + self._bind(proto, '0.0.0.0', port) - def set_algorithm(self, algo): + def add_server(self, protocol, host, port): + proto = protocol.lower() + if 'servers' not in self._config[proto]: + self._config[proto]['servers'] = [] + self._config[proto]['servers'].append((host, port)) + + def set_algorithm(self, protocol, algo): + proto = protocol.lower() if algo == self.ROUNDROBIN: - self._config['algorithm'] = 'roundrobin' + self._config[proto]['algorithm'] = 'roundrobin' elif algo == self.LEASTCONN: - self._config['algorithm'] = 'leastconn' + self._config[proto]['algorithm'] = 'leastconn' else: - raise Exception('Invalid algorithm') + raise Exception('Invalid algorithm: %s' % protocol) def create(self): self.ossvc.write_config() diff --git a/tests/test_haproxy_driver.py b/tests/test_haproxy_driver.py index adb10b46..58e842a3 100644 --- a/tests/test_haproxy_driver.py +++ b/tests/test_haproxy_driver.py @@ -13,45 +13,47 @@ class TestHAProxyDriver(unittest.TestCase): """ Test the HAProxy init() method """ self.driver.init() self.assertIsInstance(self.driver._config, dict) - self.assertEqual(self.driver._config['mode'], 'http') - self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') - self.assertEqual(self.driver._config['bind_port'], 80) - def testSetProtocol(self): + def testAddProtocol(self): """ Test the HAProxy set_protocol() method """ - self.driver.set_protocol('http', None) - self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') - self.assertEqual(self.driver._config['bind_port'], 80) - self.assertEqual(self.driver._config['mode'], 'http') + proto = 'http' + self.driver.add_protocol(proto, None) + self.assertIn(proto, self.driver._config) + self.assertEqual(self.driver._config[proto]['bind_address'], '0.0.0.0') + self.assertEqual(self.driver._config[proto]['bind_port'], 80) - self.driver.set_protocol('http', 8080) - self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') - self.assertEqual(self.driver._config['bind_port'], 8080) - self.assertEqual(self.driver._config['mode'], 'http') - - self.driver.set_protocol('tcp', 443) - self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') - self.assertEqual(self.driver._config['bind_port'], 443) - self.assertEqual(self.driver._config['mode'], 'tcp') + proto = 'tcp' + self.driver.add_protocol(proto, 443) + self.assertIn(proto, self.driver._config) + self.assertEqual(self.driver._config[proto]['bind_address'], '0.0.0.0') + self.assertEqual(self.driver._config[proto]['bind_port'], 443) + def testAddTCPRequiresPort(self): with self.assertRaises(Exception): - self.driver.set_protocol('tcp', None) + self.driver.add_protocol('tcp', None) def testAddServer(self): """ Test the HAProxy add_server() method """ - self.driver.add_server('1.2.3.4', 7777) - self.driver.add_server('5.6.7.8', 8888) - self.assertIn('servers', self.driver._config) - servers = self.driver._config['servers'] + proto = 'http' + self.driver.add_protocol(proto, None) + self.driver.add_server(proto, '1.2.3.4', 7777) + self.driver.add_server(proto, '5.6.7.8', 8888) + self.assertIn(proto, self.driver._config) + self.assertIn('servers', self.driver._config[proto]) + servers = self.driver._config[proto]['servers'] self.assertEqual(len(servers), 2) self.assertEqual(servers[0], ('1.2.3.4', 7777)) self.assertEqual(servers[1], ('5.6.7.8', 8888)) def testSetAlgorithm(self): """ Test the HAProxy set_algorithm() method """ - self.driver.set_algorithm(self.driver.ROUNDROBIN) - self.assertEqual(self.driver._config['algorithm'], 'roundrobin') - self.driver.set_algorithm(self.driver.LEASTCONN) - self.assertEqual(self.driver._config['algorithm'], 'leastconn') + proto = 'http' + self.driver.add_protocol(proto, None) + self.driver.set_algorithm(proto, self.driver.ROUNDROBIN) + self.assertIn(proto, self.driver._config) + self.assertIn('algorithm', self.driver._config[proto]) + self.assertEqual(self.driver._config[proto]['algorithm'], 'roundrobin') + self.driver.set_algorithm(proto, self.driver.LEASTCONN) + self.assertEqual(self.driver._config[proto]['algorithm'], 'leastconn') with self.assertRaises(Exception): - self.driver.set_protocol(99) + self.driver.set_algorithm(proto, 99) diff --git a/tests/test_lbaas_worker.py b/tests/test_lbaas_worker.py deleted file mode 100644 index f8fc2aa8..00000000 --- a/tests/test_lbaas_worker.py +++ /dev/null @@ -1,155 +0,0 @@ -import json -import logging -import unittest -import mock -from libra.worker.worker import lbaas_task -from libra.worker.drivers.base import LoadBalancerDriver - - -class FakeDriver(LoadBalancerDriver): - pass - - -class FakeJob(object): - def __init__(self, data): - """ - data: JSON object to convert to a string - """ - self.data = data - - -class FakeWorker(object): - def __init__(self): - self.logger = logging.getLogger('lbaas_worker_test') - self.driver = FakeDriver() - - -class TestLBaaSTask(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - - def testLBaaSTask(self): - """ Test the lbaas_task() function """ - - worker = FakeWorker() - data = { - "hpcs_action": "create", - "name": "a-new-loadbalancer", - "nodes": [ - { - "address": "10.1.1.1", - "port": "80" - }, - { - "address": "10.1.1.2", - "port": "81" - } - ] - } - - job = FakeJob(data) - r = lbaas_task(worker, job) - - self.assertEqual(r["name"], data["name"]) - self.assertEqual(len(r["nodes"]), 2) - self.assertEqual(r["nodes"][0]["address"], data["nodes"][0]["address"]) - self.assertEqual(r["nodes"][0]["port"], data["nodes"][0]["port"]) - self.assertIn("condition", r["nodes"][0]) - self.assertEqual(r["nodes"][1]["address"], data["nodes"][1]["address"]) - self.assertEqual(r["nodes"][1]["port"], data["nodes"][1]["port"]) - self.assertIn("condition", r["nodes"][1]) - - def testMissingAction(self): - """ Test invalid messages: missing hpcs_action """ - worker = FakeWorker() - data = { - "name": "a-new-loadbalancer", - "nodes": [ - { - "address": "10.1.1.1", - "port": "80" - }, - { - "address": "10.1.1.2", - "port": "81" - } - ] - } - job = FakeJob(data) - r = lbaas_task(worker, job) - self.assertIn("hpcs_response", r) - self.assertEqual("FAIL", r["hpcs_response"]) - - def testInvalidAction(self): - """ Test invalid messages: invalid hpcs_action """ - worker = FakeWorker() - data = { - "action": "invalid", - "name": "a-new-loadbalancer", - "nodes": [ - { - "address": "10.1.1.1", - "port": "80" - }, - { - "address": "10.1.1.2", - "port": "81" - } - ] - } - job = FakeJob(data) - r = lbaas_task(worker, job) - self.assertIn("hpcs_response", r) - self.assertEqual("FAIL", r["hpcs_response"]) - - def testMissingNodes(self): - """ Test invalid messages: missing nodes """ - - worker = FakeWorker() - data = { - "hpcs_action": "create", - "name": "a-new-loadbalancer" - } - job = FakeJob(data) - r = lbaas_task(worker, job) - self.assertIn("badRequest", r) - self.assertIn("validationErrors", r["badRequest"]) - - def testMissingPort(self): - """ Test invalid messages: missing port """ - - worker = FakeWorker() - data = { - "hpcs_action": "create", - "name": "a-new-loadbalancer", - "nodes": [ - { - "address": "10.1.1.1" - } - ] - } - job = FakeJob(data) - r = lbaas_task(worker, job) - self.assertIn("badRequest", r) - self.assertIn("validationErrors", r["badRequest"]) - - def testMissingAddress(self): - """ Test invalid messages: missing address """ - - worker = FakeWorker() - data = { - "hpcs_action": "create", - "name": "a-new-loadbalancer", - "nodes": [ - { - "port": "80" - } - ] - } - job = FakeJob(data) - r = lbaas_task(worker, job) - self.assertIn("badRequest", r) - self.assertIn("validationErrors", r["badRequest"]) diff --git a/tests/test_worker_controller.py b/tests/test_worker_controller.py index f2ca79c8..41d40553 100644 --- a/tests/test_worker_controller.py +++ b/tests/test_worker_controller.py @@ -28,10 +28,15 @@ class TestWorkerController(unittest.TestCase): def testCreate(self): msg = { c.ACTION_FIELD: 'CREATE', - 'nodes': [ + 'loadbalancers': [ { - 'address': '10.0.0.1', - 'port': 80 + 'protocol': 'http', + 'nodes': [ + { + 'address': '10.0.0.1', + 'port': 80 + } + ] } ] } @@ -43,10 +48,15 @@ class TestWorkerController(unittest.TestCase): def testUpdate(self): msg = { c.ACTION_FIELD: 'CREATE', - 'nodes': [ + 'loadbalancers': [ { - 'address': '10.0.0.1', - 'port': 80 + 'protocol': 'http', + 'nodes': [ + { + 'address': '10.0.0.1', + 'port': 80 + } + ] } ] } @@ -82,7 +92,7 @@ class TestWorkerController(unittest.TestCase): self.assertIn(c.RESPONSE_FIELD, response) self.assertEquals(response[c.RESPONSE_FIELD], c.RESPONSE_SUCCESS) - def testCreateMissingNodes(self): + def testCreateMissingLBs(self): msg = { c.ACTION_FIELD: 'CREATE' } @@ -90,14 +100,46 @@ class TestWorkerController(unittest.TestCase): response = controller.run() self.assertIn('badRequest', response) + def testCreateMissingNodes(self): + msg = { + c.ACTION_FIELD: 'CREATE', + 'loadbalancers': [ { 'protocol': 'http' } ] + } + controller = c(self.logger, self.driver, msg) + response = controller.run() + self.assertIn('badRequest', response) + + def testCreateMissingProto(self): + msg = { + c.ACTION_FIELD: 'CREATE', + 'loadbalancers': [ + { + 'nodes': [ + { + 'address': '10.0.0.1', + 'port': 80 + } + ] + } + ] + } + controller = c(self.logger, self.driver, msg) + response = controller.run() + self.assertIn('badRequest', response) + def testBadAlgorithm(self): msg = { c.ACTION_FIELD: 'CREATE', - 'algorithm': 'BOGUS', - 'nodes': [ + 'loadbalancers': [ { - 'address': '10.0.0.1', - 'port': 80 + 'protocol': 'http', + 'algorithm': 'BOGUS', + 'nodes': [ + { + 'address': '10.0.0.1', + 'port': 80 + } + ] } ] }