diff --git a/ostack_validator/schema.py b/ostack_validator/schema.py index c2e9207..c62ec2c 100644 --- a/ostack_validator/schema.py +++ b/ostack_validator/schema.py @@ -165,9 +165,13 @@ class TypeValidatorRegistry: return self.__validators[name] +class SchemaError(Issue): + def __init__(self, message): + super(SchemaError, self).__init__(Issue.ERROR, message) + class InvalidValueError(MarkedIssue): def __init__(self, message, mark=Mark('', 1, 1)): - super(InvalidValueError, self).__init__(Issue.ERROR, message, mark) + super(InvalidValueError, self).__init__(Issue.ERROR, 'Invalid value: '+message, mark) class TypeValidator(object): def __init__(self, f): @@ -187,6 +191,9 @@ def type_validator(name, **kwargs): return fn return wrap +def isissue(o): + return isinstance(o, Issue) + @type_validator('boolean') def validate_boolean(s): s = s.lower() @@ -195,42 +202,163 @@ def validate_boolean(s): elif s == 'false': return False else: - return InvalidValueError('Invalid value: value should be "true" or "false"') + return InvalidValueError('Value should be "true" or "false"') def validate_enum(s, values=[]): if s in values: return None if len(values) == 0: - message = 'there should be no value' + message = 'There should be no value' elif len(values) == 1: - message = 'the only valid value is %s' % values[0] + message = 'The only valid value is %s' % values[0] else: - message = 'valid values are %s and %s' % (', '.join(values[:-1]), values[-1]) - return InvalidValueError('Invalid value: %s' % message) + message = 'Valid values are %s and %s' % (', '.join(values[:-1]), values[-1]) + return InvalidValueError('%s' % message) + +def validate_ipv4_address(s): + s = s.strip() + parts = s.split('.') + if len(parts) == 4: + if all([all([c.isdigit() for c in part]) for part in parts]): + parts = [int(part) for part in parts] + if all([part < 256 for part in parts]): + return '.'.join([str(part) for part in parts]) + + return InvalidValueError('Value should be ipv4 address') + +def validate_ipv4_network(s): + s = s.strip() + parts = s.split('/') + if len(parts) != 2: + return InvalidValueError('Should have "/" character separating address and prefix length') + + address, prefix = parts + prefix = prefix.strip() + + if prefix.strip() == '': + return InvalidValueError('Prefix length is required') + + address = validate_ipv4_address(address) + if isissue(address): + return address + + if not all([c.isdigit() for c in prefix]): + return InvalidValueError('Prefix length should be an integer') + + prefix = int(prefix) + if prefix > 32: + return InvalidValueError('Prefix length should be less than or equal to 32') + + return '%s/%d' % (address, prefix) + +@type_validator('host_address') +def validate_host_address(s): + return validate_ipv4_address(s) + +@type_validator('network_address') +def validate_network_address(s): + return validate_ipv4_network(s) + +@type_validator('host_and_port') +def validate_host_and_port(s, default_port=None): + parts = s.strip().split(':', 2) + + host_address = validate_host_address(parts[0]) + if isissue(host_address): + return host_address + + if len(parts) == 2: + port = validate_port(parts[1]) + if isissue(port): + return port + elif default_port: + port = default_port + else: + return InvalidValueError('No port specified') + + return (host_address, port) -@type_validator('host') @type_validator('string') -@type_validator('stringlist') def validate_string(s): return s @type_validator('integer') -@type_validator('port', min=1, max=65535) def validate_integer(s, min=None, max=None): leading_whitespace_len = 0 - while s[leading_whitespace_len].isspace(): leading_whitespace_len += 1 + while leading_whitespace_len < len(s) and s[leading_whitespace_len].isspace(): leading_whitespace_len += 1 s = s.strip() + if s == '': + return InvalidValueError('Should not be empty') + for i, c in enumerate(s): if not c.isdigit() and not ((c == '-') and (i == 0)): - return InvalidValueError('Invalid value: only digits are allowed, but found char "%s"' % c, Mark('', 1, i+1+leading_whitespace_len)) + return InvalidValueError('Only digits are allowed, but found char "%s"' % c, Mark('', 1, i+1+leading_whitespace_len)) v = int(s) if min and v < min: - return InvalidValueError('Invalid value: should be greater than or equal to %d' % min, Mark('', 1, leading_whitespace_len)) + return InvalidValueError('Should be greater than or equal to %d' % min, Mark('', 1, leading_whitespace_len)) if max and v > max: - return InvalidValueError('Invalid value: should be less than or equal to %d' % max, Mark('', 1, leading_whitespace_len)) + return InvalidValueError('Should be less than or equal to %d' % max, Mark('', 1, leading_whitespace_len)) return v +@type_validator('port') +def validate_port(s, min=1, max=65535): + return validate_integer(s, min=min, max=max) + +@type_validator('string_list') +def validate_list(s, element_type='string'): + element_type_validator = TypeValidatorRegistry.get_validator(element_type) + if not element_type_validator: + return SchemaError('Invalid element type "%s"' % element_type) + + result = [] + s = s.strip() + + if s == '': + return result + + values = s.split(',') + for value in values: + validated_value = element_type_validator.validate(value.strip()) + if isinstance(validated_value, Issue): + # TODO: provide better position reporting + return validated_value + result.append(validated_value) + + return result + +@type_validator('string_dict') +def validate_dict(s, element_type='string'): + element_type_validator = TypeValidatorRegistry.get_validator(element_type) + if not element_type_validator: + return SchemaError('Invalid element type "%s"' % element_type) + + result = {} + s = s.strip() + + if s == '': + return result + + pairs = s.split(',') + for pair in pairs: + key_value = pair.split(':', 2) + if len(key_value) < 2: + return InvalidValueError('Value should be NAME:VALUE pairs separated by ","') + + key, value = key_value + key = key.strip() + value = value.strip() + + if key == '': + # TODO: provide better position reporting + return InvalidValueError('Key name should not be empty') + + validated_value = element_type_validator.validate(value) + if isinstance(validated_value, Issue): + # TODO: provide better position reporting + return validated_value + result[key] = validated_value + return result diff --git a/ostack_validator/test_type_validators.py b/ostack_validator/test_type_validators.py index 7552203..417b655 100644 --- a/ostack_validator/test_type_validators.py +++ b/ostack_validator/test_type_validators.py @@ -76,9 +76,57 @@ class IntegerTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): v = self.validator.validate('123') self.assertEqual(123, v) +class HostAddressTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): + type_name = 'host_address' + + def test_ipv4_address(self): + self.assertValid('127.0.0.1') + + def test_returns_address(self): + s = '10.0.0.1' + v = self.validator.validate(s) + self.assertEqual(s, v) + + def test_value_with_less_than_4_numbers_separated_by_dots(self): + self.assertInvalid('10.0.0') + + def test_ipv4_like_string_with_numbers_greater_than_255(self): + self.assertInvalid('10.0.256.1') + + +class NetworkAddressTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): + type_name = 'network_address' + + def test_ipv4_network(self): + self.assertValid('127.0.0.1/24') + + def test_returns_address(self): + s = '10.0.0.1/32' + v = self.validator.validate(s) + self.assertEqual(s, v) + + def test_value_with_less_than_4_numbers_separated_by_dots(self): + self.assertInvalid('10.0.0/24') + + def test_ipv4_like_string_with_numbers_greater_than_255(self): + self.assertInvalid('10.0.256.1/24') + + def test_no_prefix_length(self): + self.assertInvalid('10.0.0.0') + self.assertInvalid('10.0.0.0/') + + def test_non_integer_prefix_length(self): + self.assertInvalid('10.0.0.0/1a') + + def test_prefix_greater_than_32(self): + self.assertInvalid('10.0.0.0/33') + class PortTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): type_name = 'port' + def test_empty(self): + self.assertInvalid('') + def test_positive_integer(self): self.assertValid('123') @@ -109,6 +157,78 @@ class PortTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): v = self.validator.validate('123') self.assertEqual(123, v) +class HostAndPortTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): + type_name = 'host_and_port' + + def test_ipv4_address(self): + self.assertValid('127.0.0.1:80') + + def test_returns_address(self): + s = '10.0.0.1:80' + v = self.validator.validate(s) + self.assertEqual(('10.0.0.1', 80), v) + + def test_value_with_less_than_4_numbers_separated_by_dots(self): + self.assertInvalid('10.0.0:1234') + + def test_ipv4_like_string_with_numbers_greater_than_255(self): + self.assertInvalid('10.0.256.1:1234') + + def test_no_port(self): + self.assertInvalid('10.0.0.1') + self.assertInvalid('10.0.0.1:') + + def test_port_is_not_an_integer(self): + self.assertInvalid('10.0.0.1:abc') + + def test_port_is_greater_than_65535(self): + self.assertInvalid('10.0.0.1:65536') + +class StringListTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): + type_name = 'string_list' + + def test_empty_value(self): + v = self.validator.validate('') + self.assertEqual([], v) + + def test_single_value(self): + v = self.validator.validate(' foo bar ') + + self.assertIsInstance(v, list) + self.assertEqual('foo bar', v[0]) + self.assertEqual(1, len(v)) + + def test_list_of_values(self): + v = self.validator.validate(' foo bar, baz ') + + self.assertIsInstance(v, list) + self.assertEqual('foo bar', v[0]) + self.assertEqual('baz', v[1]) + self.assertEqual(2, len(v)) + +class StringDictTypeValidatorTests(TypeValidatorTestHelper, unittest.TestCase): + type_name = 'string_dict' + + def test_empty_value(self): + v = self.validator.validate('') + self.assertEqual({}, v) + + def test_single_value(self): + v = self.validator.validate(' foo: bar ') + + self.assertIsInstance(v, dict) + self.assertEqual('bar', v['foo']) + self.assertEqual(1, len(v)) + + def test_list_of_values(self): + v = self.validator.validate(' foo: bar, baz: 123 ') + + self.assertIsInstance(v, dict) + self.assertEqual('bar', v['foo']) + self.assertEqual('123', v['baz']) + self.assertEqual(2, len(v)) + + if __name__ == '__main__': unittest.main()