From aedf967cf5ec0ed739d8f6a3fe02738b89b5082c Mon Sep 17 00:00:00 2001 From: Doug Hellmann Date: Wed, 5 Dec 2012 10:51:50 -0500 Subject: [PATCH] tighten up validate_value logic and allow string promotion to integers --- wsme/tests/test_types.py | 23 +++++++++---- wsme/types.py | 71 +++++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/wsme/tests/test_types.py b/wsme/tests/test_types.py index 494fb5e..cbfb45d 100644 --- a/wsme/tests/test_types.py +++ b/wsme/tests/test_types.py @@ -249,7 +249,7 @@ class TestTypes(unittest.TestCase): assert value.a is types.Unset def test_validate_dict(self): - types.validate_value({int: str}, {1: '1', 5: '5'}) + assert types.validate_value({int: str}, {1: '1', 5: '5'}) try: types.validate_value({int: str}, []) @@ -257,11 +257,7 @@ class TestTypes(unittest.TestCase): except ValueError: pass - try: - types.validate_value({int: str}, {'1': '1', 5: '5'}) - assert False, "No ValueError raised" - except ValueError: - pass + assert types.validate_value({int: str}, {'1': '1', 5: '5'}) try: types.validate_value({int: str}, {1: 1, 5: '5'}) @@ -273,6 +269,21 @@ class TestTypes(unittest.TestCase): self.assertEqual(types.validate_value(float, 1), 1.0) self.assertEqual(types.validate_value(float, '1'), 1.0) self.assertEqual(types.validate_value(float, 1.1), 1.1) + try: + types.validate_value(float, []) + assert False, "No ValueError raised" + except ValueError: + pass + + def test_validate_int(self): + self.assertEqual(types.validate_value(int, 1), 1) + self.assertEqual(types.validate_value(int, '1'), 1) + self.assertEqual(types.validate_value(int, six.u('1')), 1) + try: + types.validate_value(int, 1.1) + assert False, "No ValueError raised" + except ValueError: + pass def test_register_invalid_array(self): self.assertRaises(ValueError, types.register_type, []) diff --git a/wsme/types.py b/wsme/types.py index 7d0357d..c88acf9 100644 --- a/wsme/types.py +++ b/wsme/types.py @@ -178,6 +178,8 @@ pod_types = six.integer_types + ( dt_types = (datetime.date, datetime.time, datetime.datetime) extra_types = (binary, decimal.Decimal) native_types = pod_types + dt_types + extra_types +# The types for which we allow promotion to certain numbers. +_promotable_types = six.integer_types + (text, bytes) def iscomplex(datatype): @@ -194,40 +196,49 @@ def isdict(datatype): def validate_value(datatype, value): + if value in (Unset, None): + return value + + # Try to promote the data type to one of our complex types. + if isinstance(datatype, list): + datatype = ArrayType(datatype[0]) + elif isinstance(datatype, dict): + datatype = DictType(*list(datatype.items())[0]) + + # If the datatype has its own validator, use that. if hasattr(datatype, 'validate'): return datatype.validate(value) - else: - if value in (Unset, None): - return value - if isinstance(datatype, list): - datatype = ArrayType(datatype[0]) - if isinstance(datatype, dict): - datatype = DictType(*list(datatype.items())[0]) - if isarray(datatype): - datatype.validate(value) - elif isdict(datatype): - datatype.validate(value) - elif datatype in six.integer_types: - if not isinstance(value, six.integer_types): - raise ValueError( - "Wrong type. Expected an integer, got '%s'" % ( - type(value) - )) - elif datatype is text and isinstance(value, bytes): - value = value.decode() - elif datatype is bytes and isinstance(value, text): - value = value.encode() - elif datatype is float and (isinstance(value, int) - or isinstance(value, text) - or isinstance(value, bytes)): + # Do type promotion/conversion and data validation for builtin + # types. + v_type = type(value) + if datatype in six.integer_types: + if v_type in _promotable_types: + try: + # Try to turn the value into an int + value = datatype(value) + except ValueError: + # An error is raised at the end of the function + # when the types don't match. + pass + elif datatype is float and v_type in _promotable_types: + try: value = float(value) - elif not isinstance(value, datatype): - raise ValueError( - "Wrong type. Expected '%s', got '%s'" % ( - datatype, type(value) - )) - return value + except ValueError: + # An error is raised at the end of the function + # when the types don't match. + pass + elif datatype is text and isinstance(value, bytes): + value = value.decode() + elif datatype is bytes and isinstance(value, text): + value = value.encode() + + if not isinstance(value, datatype): + raise ValueError( + "Wrong type. Expected '%s', got '%s'" % ( + datatype, v_type + )) + return value class wsproperty(property):