Support multi-protocol LB messages.

This change allows for defining a load balancer supporting
multiple simultaneous protocols instead of one LB per protocol.

Redundant worker tests are also removed.

Change-Id: I9ef13771af7d0513997c675374fc171d515d4b43
This commit is contained in:
David Shrewsbury 2012-10-29 14:01:17 -04:00
parent 6f70deb0be
commit d54272bf57
6 changed files with 226 additions and 295 deletions

View File

@ -24,6 +24,7 @@ class LBaaSController(object):
RESPONSE_SUCCESS = "PASS" RESPONSE_SUCCESS = "PASS"
ACTION_FIELD = 'hpcs_action' ACTION_FIELD = 'hpcs_action'
RESPONSE_FIELD = 'hpcs_response' RESPONSE_FIELD = 'hpcs_response'
LBLIST_FIELD = 'loadbalancers'
def __init__(self, logger, driver, json_msg): def __init__(self, logger, driver, json_msg):
self.logger = logger self.logger = logger
@ -61,7 +62,15 @@ class LBaaSController(object):
return self.msg return self.msg
def _action_create(self): def _action_create(self):
""" Create a Load Balancer. """ """
Create a Load Balancer.
This is the only method (so far) that actually parses the contents
of the JSON message (other than the ACTION_FIELD field). Modifying
the JSON message structure likely means this method will need to
be modified, unless the change involves fields that are ignored.
"""
try: try:
self.driver.init() self.driver.init()
except NotImplementedError: except NotImplementedError:
@ -71,78 +80,96 @@ class LBaaSController(object):
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg return self.msg
if 'nodes' not in self.msg: if self.LBLIST_FIELD not in self.msg:
return BadRequest("Missing 'nodes' element").to_json() return BadRequest(
"Missing '%s' element" % self.LBLIST_FIELD
).to_json()
if 'protocol' in self.msg: lb_list = self.msg['loadbalancers']
port = None
if 'port' in self.msg:
port = self.msg['port']
try:
self.driver.set_protocol(self.msg['protocol'], port)
except NotImplementedError:
self.logger.error(
"Selected driver does not support setting protocol."
)
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
except Exception as e:
self.logger.error("Failure trying to set protocol: %s, %s" %
(e.__class__, e))
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
if 'algorithm' in self.msg: for current_lb in lb_list:
algo = self.msg['algorithm'].upper() if 'nodes' not in current_lb:
if algo == 'ROUND_ROBIN': return BadRequest("Missing 'nodes' element").to_json()
algo = LoadBalancerDriver.ROUNDROBIN
elif algo == 'LEAST_CONNECTIONS': if 'protocol' not in current_lb:
algo = LoadBalancerDriver.LEASTCONN return BadRequest(
"Missing required 'protocol' value."
).to_json()
else: else:
self.logger.error("Invalid algorithm: %s" % algo) port = None
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE if 'port' in current_lb:
return self.msg port = current_lb['port']
try:
self.driver.add_protocol(current_lb['protocol'], port)
except NotImplementedError:
self.logger.error(
"Selected driver does not support setting protocol."
)
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
except Exception as e:
self.logger.error(
"Failure trying to set protocol: %s, %s" %
(e.__class__, e)
)
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
try: if 'algorithm' in current_lb:
self.driver.set_algorithm(algo) algo = current_lb['algorithm'].upper()
except NotImplementedError: if algo == 'ROUND_ROBIN':
self.logger.error( algo = LoadBalancerDriver.ROUNDROBIN
"Selected driver does not support setting algorithm." elif algo == 'LEAST_CONNECTIONS':
) algo = LoadBalancerDriver.LEASTCONN
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE else:
return self.msg self.logger.error("Invalid algorithm: %s" % algo)
except Exception as e: self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
self.logger.error("Selected driver failed setting algorithm.") return self.msg
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
for lb_node in self.msg['nodes']: try:
port, address = None, None self.driver.set_algorithm(current_lb['protocol'], algo)
except NotImplementedError:
self.logger.error(
"Selected driver does not support setting algorithm."
)
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
except Exception as e:
self.logger.error(
"Selected driver failed setting algorithm."
)
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
return self.msg
if 'port' in lb_node: for lb_node in current_lb['nodes']:
port = lb_node['port'] port, address = None, None
else:
return BadRequest("Missing 'port' element.").to_json()
if 'address' in lb_node: if 'port' in lb_node:
address = lb_node['address'] port = lb_node['port']
else: else:
return BadRequest("Missing 'address' element.").to_json() return BadRequest("Missing 'port' element.").to_json()
try: if 'address' in lb_node:
self.driver.add_server(address, port) address = lb_node['address']
except NotImplementedError: else:
self.logger.error( return BadRequest("Missing 'address' element.").to_json()
"Selected driver does not support adding a server."
) try:
lb_node['condition'] = self.NODE_ERR self.driver.add_server(current_lb['protocol'],
except Exception as e: address,
self.logger.error("Failure trying adding server: %s, %s" % port)
(e.__class__, e)) except NotImplementedError:
lb_node['condition'] = self.NODE_ERR self.logger.error(
else: "Selected driver does not support adding a server."
self.logger.debug("Added server: %s:%s" % (address, port)) )
lb_node['condition'] = self.NODE_OK lb_node['condition'] = self.NODE_ERR
except Exception as e:
self.logger.error("Failure trying adding server: %s, %s" %
(e.__class__, e))
lb_node['condition'] = self.NODE_ERR
else:
self.logger.debug("Added server: %s:%s" % (address, port))
lb_node['condition'] = self.NODE_OK
try: try:
self.driver.create() self.driver.create()
@ -150,8 +177,9 @@ class LBaaSController(object):
self.logger.error( self.logger.error(
"Selected driver does not support CREATE action." "Selected driver does not support CREATE action."
) )
for lb_node in self.msg['nodes']: for current_lb in lb_list:
lb_node['condition'] = self.NODE_ERR for lb_node in current_lb['nodes']:
lb_node['condition'] = self.NODE_ERR
self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE self.msg[self.RESPONSE_FIELD] = self.RESPONSE_FAILURE
except Exception as e: except Exception as e:
self.logger.error("CREATE failed: %s, %s" % (e.__class__, e)) self.logger.error("CREATE failed: %s, %s" % (e.__class__, e))

View File

@ -28,6 +28,12 @@ class LoadBalancerDriver(object):
Generally, an appliance driver should queue up any changes made Generally, an appliance driver should queue up any changes made
via these API calls until the create() method is called. via these API calls until the create() method is called.
This design allows for a single load balancer to support multiple
protocols simultaneously. Each protocol added via the add_protocol()
method is assumed to be unique, and one protocol per port. This same
protocol is then supplied to other methods (e.g., add_server() and
set_algorithm()) to make changes for that specific protocol.
""" """
# Load balancer algorithms # Load balancer algorithms
@ -38,16 +44,16 @@ class LoadBalancerDriver(object):
""" Allows the driver to do any initialization for a new config. """ """ Allows the driver to do any initialization for a new config. """
raise NotImplementedError() raise NotImplementedError()
def add_server(self, host, port): def add_protocol(self, protocol, port):
""" Add a server for which we will proxy. """ """ Add a supported protocol and listening port for the instance. """
raise NotImplementedError() raise NotImplementedError()
def set_protocol(self, protocol, port): def add_server(self, protocol, host, port):
""" Set the protocol of the instance. """ """ Add a server for the protocol for which we will proxy. """
raise NotImplementedError() raise NotImplementedError()
def set_algorithm(self, algo): def set_algorithm(self, protocol, algo):
""" Set the algorithm used by the load balancer. """ """ Set the algorithm used by the load balancer for this protocol. """
raise NotImplementedError() raise NotImplementedError()
def create(self): def create(self):

View File

@ -29,12 +29,10 @@ class HAProxyDriver(LoadBalancerDriver):
def _init_config(self): def _init_config(self):
self._config = dict() self._config = dict()
self.set_protocol('HTTP', 80)
self.set_algorithm(self.ROUNDROBIN)
def _bind(self, address, port): def _bind(self, protocol, address, port):
self._config['bind_address'] = address self._config[protocol]['bind_address'] = address
self._config['bind_port'] = port self._config[protocol]['bind_port'] = port
def _config_to_string(self): def _config_to_string(self):
""" """
@ -54,7 +52,6 @@ class HAProxyDriver(LoadBalancerDriver):
) )
output.append('defaults') output.append('defaults')
output.append(' log global') output.append(' log global')
output.append(' mode %s' % self._config['mode'])
output.append(' option httplog') output.append(' option httplog')
output.append(' option dontlognull') output.append(' option dontlognull')
output.append(' option redispatch') output.append(' option redispatch')
@ -63,18 +60,24 @@ class HAProxyDriver(LoadBalancerDriver):
output.append(' timeout connect 5000ms') output.append(' timeout connect 5000ms')
output.append(' timeout client 50000ms') output.append(' timeout client 50000ms')
output.append(' timeout server 5000ms') output.append(' timeout server 5000ms')
output.append(' balance %s' % self._config['algorithm'])
output.append(' cookie SERVERID rewrite') output.append(' cookie SERVERID rewrite')
output.append('frontend http-in')
output.append(' bind %s:%s' % (self._config['bind_address'],
self._config['bind_port']))
output.append(' default_backend servers')
output.append('backend servers')
serv_num = 1 serv_num = 1
for (addr, port) in self._config['servers']:
output.append(' server server%d %s:%s' % (serv_num, addr, port)) for proto in self._config:
serv_num += 1 protocfg = self._config[proto]
output.append('frontend %s-in' % proto)
output.append(' mode %s' % proto)
output.append(' bind %s:%s' % (protocfg['bind_address'],
protocfg['bind_port']))
output.append(' default_backend %s-servers' % proto)
output.append('backend %s-servers' % proto)
output.append(' balance %s' % protocfg['algorithm'])
for (addr, port) in protocfg['servers']:
output.append(' server server%d %s:%s' %
(serv_num, addr, port))
serv_num += 1
return '\n'.join(output) + '\n' return '\n'.join(output) + '\n'
@ -85,32 +88,37 @@ class HAProxyDriver(LoadBalancerDriver):
def init(self): def init(self):
self._init_config() self._init_config()
def add_server(self, host, port): def add_protocol(self, protocol, port=None):
if 'servers' not in self._config:
self._config['servers'] = []
self._config['servers'].append((host, port))
def set_protocol(self, protocol, port=None):
proto = protocol.lower() proto = protocol.lower()
if proto not in ('tcp', 'http', 'health'): if proto not in ('tcp', 'http', 'health'):
raise Exception("Invalid protocol: %s" % protocol) raise Exception("Unsupported protocol: %s" % protocol)
self._config['mode'] = proto if proto in self._config:
raise Exception("Protocol '%s' is already defined." % protocol)
else:
self._config[proto] = dict()
if port is None: if port is None:
if proto == 'tcp': if proto == 'tcp':
raise Exception('Port is required for TCP protocol.') raise Exception('Port is required for TCP protocol.')
elif proto == 'http': elif proto == 'http':
self._bind('0.0.0.0', 80) self._bind(proto, '0.0.0.0', 80)
else: else:
self._bind('0.0.0.0', port) self._bind(proto, '0.0.0.0', port)
def set_algorithm(self, algo): def add_server(self, protocol, host, port):
proto = protocol.lower()
if 'servers' not in self._config[proto]:
self._config[proto]['servers'] = []
self._config[proto]['servers'].append((host, port))
def set_algorithm(self, protocol, algo):
proto = protocol.lower()
if algo == self.ROUNDROBIN: if algo == self.ROUNDROBIN:
self._config['algorithm'] = 'roundrobin' self._config[proto]['algorithm'] = 'roundrobin'
elif algo == self.LEASTCONN: elif algo == self.LEASTCONN:
self._config['algorithm'] = 'leastconn' self._config[proto]['algorithm'] = 'leastconn'
else: else:
raise Exception('Invalid algorithm') raise Exception('Invalid algorithm: %s' % protocol)
def create(self): def create(self):
self.ossvc.write_config() self.ossvc.write_config()

View File

@ -13,45 +13,47 @@ class TestHAProxyDriver(unittest.TestCase):
""" Test the HAProxy init() method """ """ Test the HAProxy init() method """
self.driver.init() self.driver.init()
self.assertIsInstance(self.driver._config, dict) self.assertIsInstance(self.driver._config, dict)
self.assertEqual(self.driver._config['mode'], 'http')
self.assertEqual(self.driver._config['bind_address'], '0.0.0.0')
self.assertEqual(self.driver._config['bind_port'], 80)
def testSetProtocol(self): def testAddProtocol(self):
""" Test the HAProxy set_protocol() method """ """ Test the HAProxy set_protocol() method """
self.driver.set_protocol('http', None) proto = 'http'
self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') self.driver.add_protocol(proto, None)
self.assertEqual(self.driver._config['bind_port'], 80) self.assertIn(proto, self.driver._config)
self.assertEqual(self.driver._config['mode'], 'http') self.assertEqual(self.driver._config[proto]['bind_address'], '0.0.0.0')
self.assertEqual(self.driver._config[proto]['bind_port'], 80)
self.driver.set_protocol('http', 8080) proto = 'tcp'
self.assertEqual(self.driver._config['bind_address'], '0.0.0.0') self.driver.add_protocol(proto, 443)
self.assertEqual(self.driver._config['bind_port'], 8080) self.assertIn(proto, self.driver._config)
self.assertEqual(self.driver._config['mode'], 'http') self.assertEqual(self.driver._config[proto]['bind_address'], '0.0.0.0')
self.assertEqual(self.driver._config[proto]['bind_port'], 443)
self.driver.set_protocol('tcp', 443)
self.assertEqual(self.driver._config['bind_address'], '0.0.0.0')
self.assertEqual(self.driver._config['bind_port'], 443)
self.assertEqual(self.driver._config['mode'], 'tcp')
def testAddTCPRequiresPort(self):
with self.assertRaises(Exception): with self.assertRaises(Exception):
self.driver.set_protocol('tcp', None) self.driver.add_protocol('tcp', None)
def testAddServer(self): def testAddServer(self):
""" Test the HAProxy add_server() method """ """ Test the HAProxy add_server() method """
self.driver.add_server('1.2.3.4', 7777) proto = 'http'
self.driver.add_server('5.6.7.8', 8888) self.driver.add_protocol(proto, None)
self.assertIn('servers', self.driver._config) self.driver.add_server(proto, '1.2.3.4', 7777)
servers = self.driver._config['servers'] self.driver.add_server(proto, '5.6.7.8', 8888)
self.assertIn(proto, self.driver._config)
self.assertIn('servers', self.driver._config[proto])
servers = self.driver._config[proto]['servers']
self.assertEqual(len(servers), 2) self.assertEqual(len(servers), 2)
self.assertEqual(servers[0], ('1.2.3.4', 7777)) self.assertEqual(servers[0], ('1.2.3.4', 7777))
self.assertEqual(servers[1], ('5.6.7.8', 8888)) self.assertEqual(servers[1], ('5.6.7.8', 8888))
def testSetAlgorithm(self): def testSetAlgorithm(self):
""" Test the HAProxy set_algorithm() method """ """ Test the HAProxy set_algorithm() method """
self.driver.set_algorithm(self.driver.ROUNDROBIN) proto = 'http'
self.assertEqual(self.driver._config['algorithm'], 'roundrobin') self.driver.add_protocol(proto, None)
self.driver.set_algorithm(self.driver.LEASTCONN) self.driver.set_algorithm(proto, self.driver.ROUNDROBIN)
self.assertEqual(self.driver._config['algorithm'], 'leastconn') self.assertIn(proto, self.driver._config)
self.assertIn('algorithm', self.driver._config[proto])
self.assertEqual(self.driver._config[proto]['algorithm'], 'roundrobin')
self.driver.set_algorithm(proto, self.driver.LEASTCONN)
self.assertEqual(self.driver._config[proto]['algorithm'], 'leastconn')
with self.assertRaises(Exception): with self.assertRaises(Exception):
self.driver.set_protocol(99) self.driver.set_algorithm(proto, 99)

View File

@ -1,155 +0,0 @@
import json
import logging
import unittest
import mock
from libra.worker.worker import lbaas_task
from libra.worker.drivers.base import LoadBalancerDriver
class FakeDriver(LoadBalancerDriver):
pass
class FakeJob(object):
def __init__(self, data):
"""
data: JSON object to convert to a string
"""
self.data = data
class FakeWorker(object):
def __init__(self):
self.logger = logging.getLogger('lbaas_worker_test')
self.driver = FakeDriver()
class TestLBaaSTask(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def testLBaaSTask(self):
""" Test the lbaas_task() function """
worker = FakeWorker()
data = {
"hpcs_action": "create",
"name": "a-new-loadbalancer",
"nodes": [
{
"address": "10.1.1.1",
"port": "80"
},
{
"address": "10.1.1.2",
"port": "81"
}
]
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertEqual(r["name"], data["name"])
self.assertEqual(len(r["nodes"]), 2)
self.assertEqual(r["nodes"][0]["address"], data["nodes"][0]["address"])
self.assertEqual(r["nodes"][0]["port"], data["nodes"][0]["port"])
self.assertIn("condition", r["nodes"][0])
self.assertEqual(r["nodes"][1]["address"], data["nodes"][1]["address"])
self.assertEqual(r["nodes"][1]["port"], data["nodes"][1]["port"])
self.assertIn("condition", r["nodes"][1])
def testMissingAction(self):
""" Test invalid messages: missing hpcs_action """
worker = FakeWorker()
data = {
"name": "a-new-loadbalancer",
"nodes": [
{
"address": "10.1.1.1",
"port": "80"
},
{
"address": "10.1.1.2",
"port": "81"
}
]
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertIn("hpcs_response", r)
self.assertEqual("FAIL", r["hpcs_response"])
def testInvalidAction(self):
""" Test invalid messages: invalid hpcs_action """
worker = FakeWorker()
data = {
"action": "invalid",
"name": "a-new-loadbalancer",
"nodes": [
{
"address": "10.1.1.1",
"port": "80"
},
{
"address": "10.1.1.2",
"port": "81"
}
]
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertIn("hpcs_response", r)
self.assertEqual("FAIL", r["hpcs_response"])
def testMissingNodes(self):
""" Test invalid messages: missing nodes """
worker = FakeWorker()
data = {
"hpcs_action": "create",
"name": "a-new-loadbalancer"
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertIn("badRequest", r)
self.assertIn("validationErrors", r["badRequest"])
def testMissingPort(self):
""" Test invalid messages: missing port """
worker = FakeWorker()
data = {
"hpcs_action": "create",
"name": "a-new-loadbalancer",
"nodes": [
{
"address": "10.1.1.1"
}
]
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertIn("badRequest", r)
self.assertIn("validationErrors", r["badRequest"])
def testMissingAddress(self):
""" Test invalid messages: missing address """
worker = FakeWorker()
data = {
"hpcs_action": "create",
"name": "a-new-loadbalancer",
"nodes": [
{
"port": "80"
}
]
}
job = FakeJob(data)
r = lbaas_task(worker, job)
self.assertIn("badRequest", r)
self.assertIn("validationErrors", r["badRequest"])

View File

@ -28,10 +28,15 @@ class TestWorkerController(unittest.TestCase):
def testCreate(self): def testCreate(self):
msg = { msg = {
c.ACTION_FIELD: 'CREATE', c.ACTION_FIELD: 'CREATE',
'nodes': [ 'loadbalancers': [
{ {
'address': '10.0.0.1', 'protocol': 'http',
'port': 80 'nodes': [
{
'address': '10.0.0.1',
'port': 80
}
]
} }
] ]
} }
@ -43,10 +48,15 @@ class TestWorkerController(unittest.TestCase):
def testUpdate(self): def testUpdate(self):
msg = { msg = {
c.ACTION_FIELD: 'CREATE', c.ACTION_FIELD: 'CREATE',
'nodes': [ 'loadbalancers': [
{ {
'address': '10.0.0.1', 'protocol': 'http',
'port': 80 'nodes': [
{
'address': '10.0.0.1',
'port': 80
}
]
} }
] ]
} }
@ -82,7 +92,7 @@ class TestWorkerController(unittest.TestCase):
self.assertIn(c.RESPONSE_FIELD, response) self.assertIn(c.RESPONSE_FIELD, response)
self.assertEquals(response[c.RESPONSE_FIELD], c.RESPONSE_SUCCESS) self.assertEquals(response[c.RESPONSE_FIELD], c.RESPONSE_SUCCESS)
def testCreateMissingNodes(self): def testCreateMissingLBs(self):
msg = { msg = {
c.ACTION_FIELD: 'CREATE' c.ACTION_FIELD: 'CREATE'
} }
@ -90,14 +100,46 @@ class TestWorkerController(unittest.TestCase):
response = controller.run() response = controller.run()
self.assertIn('badRequest', response) self.assertIn('badRequest', response)
def testCreateMissingNodes(self):
msg = {
c.ACTION_FIELD: 'CREATE',
'loadbalancers': [ { 'protocol': 'http' } ]
}
controller = c(self.logger, self.driver, msg)
response = controller.run()
self.assertIn('badRequest', response)
def testCreateMissingProto(self):
msg = {
c.ACTION_FIELD: 'CREATE',
'loadbalancers': [
{
'nodes': [
{
'address': '10.0.0.1',
'port': 80
}
]
}
]
}
controller = c(self.logger, self.driver, msg)
response = controller.run()
self.assertIn('badRequest', response)
def testBadAlgorithm(self): def testBadAlgorithm(self):
msg = { msg = {
c.ACTION_FIELD: 'CREATE', c.ACTION_FIELD: 'CREATE',
'algorithm': 'BOGUS', 'loadbalancers': [
'nodes': [
{ {
'address': '10.0.0.1', 'protocol': 'http',
'port': 80 'algorithm': 'BOGUS',
'nodes': [
{
'address': '10.0.0.1',
'port': 80
}
]
} }
] ]
} }