diff --git a/quark/db/api.py b/quark/db/api.py index 52c364e..56a0328 100644 --- a/quark/db/api.py +++ b/quark/db/api.py @@ -966,6 +966,16 @@ def security_group_rule_create(context, **rule_dict): return new_rule +def security_group_rule_update(context, rule, **kwargs): + '''Updates a security group rule. + + NOTE(alexm) this is non-standard functionality. + ''' + rule.update(kwargs) + context.session.add(rule) + return rule + + def security_group_rule_delete(context, rule): context.session.delete(rule) diff --git a/quark/plugin.py b/quark/plugin.py index 411da1d..90e4f25 100644 --- a/quark/plugin.py +++ b/quark/plugin.py @@ -226,6 +226,11 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2, return security_groups.update_security_group(context, id, security_group) + @sessioned + def update_security_group_rule(self, context, id, security_group_rule): + return security_groups.update_security_group_rule(context, id, + security_group_rule) + @sessioned def create_ip_policy(self, context, ip_policy): self._fix_missing_tenant_id(context, ip_policy, "ip_policy") diff --git a/quark/plugin_modules/security_groups.py b/quark/plugin_modules/security_groups.py index b81242c..c718328 100644 --- a/quark/plugin_modules/security_groups.py +++ b/quark/plugin_modules/security_groups.py @@ -35,6 +35,7 @@ DEFAULT_SG_UUID = "00000000-0000-0000-0000-000000000000" GROUP_NAME_MAX_LENGTH = 255 GROUP_DESCRIPTION_MAX_LENGTH = 255 RULE_CREATE = 'create' +RULE_UPDATE = 'update' RULE_DELETE = 'delete' @@ -76,6 +77,21 @@ def _validate_security_group_rule(context, rule): return rule +def _filter_update_security_group_rule(rule): + '''Only two fields are allowed for modification: + + external_service and external_service_id + ''' + allowed = ['external_service', 'external_service_id'] + filtered = {} + for k, val in rule.iteritems(): + if k in allowed: + if isinstance(val, basestring) and \ + len(val) <= GROUP_NAME_MAX_LENGTH: + filtered[k] = val + return filtered + + def _validate_security_group(security_group): if "name" in security_group: if len(security_group["name"]) > GROUP_NAME_MAX_LENGTH: @@ -201,6 +217,35 @@ def create_security_group_rule(context, security_group_rule): return v._make_security_group_rule_dict(new_rule) +def update_security_group_rule(context, id, security_group_rule): + '''Updates a rule and updates the ports''' + LOG.info("update_security_group_rule for tenant %s" % + (context.tenant_id)) + new_rule = security_group_rule["security_group_rule"] + # Only allow updatable fields + new_rule = _filter_update_security_group_rule(new_rule) + + with context.session.begin(): + rule = db_api.security_group_rule_find(context, id=id, + scope=db_api.ONE) + if not rule: + raise sg_ext.SecurityGroupRuleNotFound(id=id) + + db_rule = db_api.security_group_rule_update(context, rule, **new_rule) + + group_id = db_rule.group_id + group = db_api.security_group_find(context, id=group_id, + scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(id=group_id) + + if group: + _perform_async_update_rule(context, group_id, group, rule.id, + RULE_UPDATE) + + return v._make_security_group_rule_dict(db_rule) + + def get_security_group_rule(context, id, fields=None): LOG.info("get_security_group_rule %s for tenant %s" % (id, context.tenant_id)) diff --git a/quark/tests/plugin_modules/test_security_groups.py b/quark/tests/plugin_modules/test_security_groups.py index 42ea597..eb3bad9 100644 --- a/quark/tests/plugin_modules/test_security_groups.py +++ b/quark/tests/plugin_modules/test_security_groups.py @@ -140,6 +140,44 @@ class TestQuarkGetSecurityGroupRules(test_quark_plugin.TestQuarkPlugin): self.plugin.get_security_group_rule(self.context, 1) +class TestQuarkUpdateSecurityGroupRule(test_quark_plugin.TestQuarkPlugin): + def test_update_security_group_rule(self): + sg_dict = {"id": 0} + sg = models.SecurityGroup() + sg.update(sg_dict) + + rule_dict = {"id": 1, "direction": "ingress", + "port_range_min": 80, "port_range_max": 100, + "remote_ip_prefix": None, + "ethertype": protocols.translate_ethertype("IPv4"), + "tenant_id": "foo", "protocol": "UDP", "group_id": 1, + "external_service_id": None, "external_service": None} + rule = models.SecurityGroupRule() + rule.update(rule_dict) + + updated_part = {'external_service_id': 'aaa', + 'external_service': 'bbb'} + + with mock.patch('quark.db.api.security_group_rule_find') as db_find, \ + mock.patch('quark.db.api.security_group_rule_update') \ + as db_update, \ + mock.patch('quark.db.api.security_group_find') as sg_find: + db_find.return_value = rule + rule_dict.update(updated_part) + new_rule = models.SecurityGroupRule() + new_rule.update(rule_dict) + db_update.return_value = new_rule + sg_find.return_value = sg + + update = dict(security_group_rule=updated_part) + resp = self.plugin.update_security_group_rule(self.context, 1, + update) + self.assertEqual(resp['external_service_id'], + updated_part['external_service_id']) + self.assertEqual(resp['external_service'], + updated_part['external_service']) + + class TestQuarkUpdateSecurityGroup(test_quark_plugin.TestQuarkPlugin): def test_update_security_group(self): v4_ethertype = protocols.ETHERTYPES["IPv4"]