diff --git a/wsme/tests/test_types.py b/wsme/tests/test_types.py index 1c28495..a0a3cd1 100644 --- a/wsme/tests/test_types.py +++ b/wsme/tests/test_types.py @@ -523,3 +523,107 @@ Value: 'v3'. Value should be one of: v., v.", return 'from-file' f = types.File(content=six.b('from-content')) assert f.file.read() == six.b('from-content') + + def test_unregister(self): + class TempType(object): + pass + types.registry.register(TempType) + v = types.registry.lookup('TempType') + self.assertIs(v, TempType) + types.registry._unregister(TempType) + after = types.registry.lookup('TempType') + self.assertIs(after, None) + + def test_unregister_twice(self): + class TempType(object): + pass + types.registry.register(TempType) + v = types.registry.lookup('TempType') + self.assertIs(v, TempType) + types.registry._unregister(TempType) + # Second call should not raise an exception + types.registry._unregister(TempType) + after = types.registry.lookup('TempType') + self.assertIs(after, None) + + def test_unregister_array_type(self): + class TempType(object): + pass + t = [TempType] + types.registry.register(t) + self.assertNotEqual(types.registry.array_types, set()) + types.registry._unregister(t) + self.assertEqual(types.registry.array_types, set()) + + def test_unregister_array_type_twice(self): + class TempType(object): + pass + t = [TempType] + types.registry.register(t) + self.assertNotEqual(types.registry.array_types, set()) + types.registry._unregister(t) + # Second call should not raise an exception + types.registry._unregister(t) + self.assertEqual(types.registry.array_types, set()) + + def test_unregister_dict_type(self): + class TempType(object): + pass + t = {str: TempType} + types.registry.register(t) + self.assertNotEqual(types.registry.dict_types, set()) + types.registry._unregister(t) + self.assertEqual(types.registry.dict_types, set()) + + def test_unregister_dict_type_twice(self): + class TempType(object): + pass + t = {str: TempType} + types.registry.register(t) + self.assertNotEqual(types.registry.dict_types, set()) + types.registry._unregister(t) + # Second call should not raise an exception + types.registry._unregister(t) + self.assertEqual(types.registry.dict_types, set()) + + def test_reregister(self): + class TempType(object): + pass + types.registry.register(TempType) + v = types.registry.lookup('TempType') + self.assertIs(v, TempType) + types.registry.reregister(TempType) + after = types.registry.lookup('TempType') + self.assertIs(after, TempType) + + def test_reregister_and_add_attr(self): + class TempType(object): + pass + types.registry.register(TempType) + attrs = types.list_attributes(TempType) + self.assertEqual(attrs, []) + TempType.one = str + types.registry.reregister(TempType) + after = types.list_attributes(TempType) + self.assertNotEqual(after, []) + + def test_dynamicbase_add_attributes(self): + class TempType(types.DynamicBase): + pass + types.registry.register(TempType) + attrs = types.list_attributes(TempType) + self.assertEqual(attrs, []) + TempType.add_attributes(one=str) + after = types.list_attributes(TempType) + self.assertEqual(len(after), 1) + + def test_dynamicbase_add_attributes_second(self): + class TempType(types.DynamicBase): + pass + types.registry.register(TempType) + attrs = types.list_attributes(TempType) + self.assertEqual(attrs, []) + TempType.add_attributes(one=str) + TempType.add_attributes(two=int) + after = types.list_attributes(TempType) + self.assertEqual(len(after), 2) diff --git a/wsme/types.py b/wsme/types.py index 0902ce3..1193ee9 100644 --- a/wsme/types.py +++ b/wsme/types.py @@ -670,6 +670,41 @@ class Registry(object): self._complex_types.append(weakref.ref(class_)) return class_ + def reregister(self, class_): + """Register a type which may already have been registered. + """ + self._unregister(class_) + return self.register(class_) + + def _unregister(self, class_): + """Remove a previously registered type. + """ + # Clear the existing attribute reference so it is rebuilt if + # the class is registered again later. + if hasattr(class_, '_wsme_attributes'): + del class_._wsme_attributes + # FIXME(dhellmann): This method does not recurse through the + # types like register() does. Should it? + if isinstance(class_, list): + at = ArrayType(class_[0]) + try: + self.array_types.remove(at) + except KeyError: + pass + elif isinstance(class_, dict): + key_type, value_type = list(class_.items())[0] + self.dict_types = set( + dt for dt in self.dict_types + if (dt.key_type, dt.value_type) != (key_type, value_type) + ) + # We can't use remove() here because the items in + # _complex_types are weakref objects pointing to the classes, + # so we can't compare with them directly. + self._complex_types = [ + ct for ct in self._complex_types + if ct() is not class_ + ] + def lookup(self, typename): log.debug('Lookup %s' % typename) modname = None @@ -772,3 +807,26 @@ class File(Base): if self._file is None and self._content: self._file = six.BytesIO(self._content) return self._file + + +class DynamicBase(Base): + """Base type for complex types for which all attributes are not + defined when the class is constructed. + + This class is meant to be used as a base for types that have + properties added after the main class is created, such as by + loading plugins. + + """ + + @classmethod + def add_attributes(cls, **attrs): + """Add more attributes + + The arguments should be valid Python attribute names + associated with a type for the new attribute. + + """ + for n, t in attrs.items(): + setattr(cls, n, t) + cls.__registry__.reregister(cls)