import importlib
import logging
import os
from unittest import mock, TestCase

from orm.common.orm_common.injector import injector

logger = logging.getLogger(__name__)


class TestInjector(TestCase):
    """Unit tests for the ``injector`` module.

    Tests that need an environment variable set (or unset) do so inside
    ``mock.patch.dict(os.environ, ...)`` so that the process environment is
    restored when the test finishes and nothing leaks into other tests.
    """

    @mock.patch.object(injector, '_import_file_by_name')
    def test_register_providers(self, mock_import_file_by_name):
        """Call register_providers with the env variable set to 'test'.

        The test passes as long as the call does not raise.
        """
        with mock.patch.dict(os.environ, {'CMS_ENV': 'test'}):
            injector.register_providers('CMS_ENV', 'a/b/c', logger)

    @mock.patch.object(injector, '_import_file_by_name')
    def test_register_providers_env_not_exist(self,
                                              mock_import_file_by_name):
        """Call register_providers naming an env variable that is not set.

        The test passes as long as the call does not raise.
        """
        with mock.patch.dict(os.environ):
            # Guarantee the variable is absent regardless of the ambient
            # environment; patch.dict puts back any removed value on exit.
            os.environ.pop('CMS_ENV1', None)
            injector.register_providers('CMS_ENV1', 'a/b/c', logger)

    @mock.patch.object(injector, '_import_file_by_name')
    def test_register_providers_env_test(self, mock_import_file_by_name):
        """Call register_providers with the env variable set to '__TEST__'.

        The test passes as long as the call does not raise.
        """
        with mock.patch.dict(os.environ, {'CMS_ENV2': '__TEST__'}):
            injector.register_providers('CMS_ENV2', 'a/b/c', logger)

    @mock.patch.object(injector, '_import_file_by_name')
    def test_register_providers_with_existing_provider(
            self, mock_import_file_by_name):
        """Call register_providers when the imported module has providers.

        ``_import_file_by_name`` is stubbed to return an object exposing a
        ``providers`` attribute with two entries. The test passes as long as
        the call does not raise.
        """
        stub_module_type = type('module', (object,),
                                {'providers': ['a1', 'b2']})
        mock_import_file_by_name.return_value = stub_module_type()

        with mock.patch.dict(os.environ, {'c3': 'test'}):
            injector.register_providers('c3', 'a/b/c', logger)

    def test_get_di(self):
        """get_di can be called without raising."""
        injector.get_di()

    def test_import_file_by_name_file_not_found_error(self):
        """_import_file_by_name('', '.') raises FileNotFoundError."""
        # Calling it with ('', '.') should raise a FileNotFoundError
        # (no such file or directory)
        self.assertRaises(FileNotFoundError,
                          injector._import_file_by_name, '', '.')

    # Decorators are applied bottom-up, so the mock arguments arrive in the
    # order: spec_from_file_location, module_from_spec, os.path.join.
    @mock.patch.object(injector.os.path, 'join')
    @mock.patch.object(injector.importlib.util, 'module_from_spec')
    @mock.patch.object(injector.importlib.util, 'spec_from_file_location')
    def test_import_file_by_name_sanity(self, mock_spec, mock_module,
                                        mock_os):
        """_import_file_by_name returns an object of the stubbed module type.

        Path joining and both importlib.util helpers are mocked, so no real
        file is touched.
        """
        stub_module_type = type('mock_providers', (object,),
                                {'providers': ['a1']})
        mock_os.return_value = ['mock_providers.py']
        mock_spec.return_value = mock.Mock()
        mock_module.return_value = stub_module_type()

        mock_provider = injector._import_file_by_name('mock', '')

        # module_from_spec is still patched here, so this call yields the
        # stub configured above whatever argument it is given.
        test_module = importlib.util.module_from_spec(stub_module_type())
        self.assertIsInstance(mock_provider, type(test_module))

    @mock.patch.object(injector._di.providers, 'register_instance')
    def test_override_injected_dependency(self, mock_di):
        """override_injected_dependency calls providers.register_instance."""
        injector.override_injected_dependency((1, 2,))
        self.assertTrue(mock_di.called)