diff --git a/cloudbaseinit/exception.py b/cloudbaseinit/exception.py index f954b451..79edfb84 100644 --- a/cloudbaseinit/exception.py +++ b/cloudbaseinit/exception.py @@ -71,3 +71,14 @@ class WindowsCloudbaseInitException(CloudbaseInitException): except TypeError: formatted_msg = msg super(WindowsCloudbaseInitException, self).__init__(formatted_msg) + + +class LoadUserProfileCloudbaseInitException(WindowsCloudbaseInitException): + """Windows cannot load the newly created user profile. + + The load user profile can fail if the Windows subsystems responsible for + the action are not ready. This usually happens on laggy systems and should + be retried. + """ + + pass diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index 290c2e5a..5630d212 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -642,6 +642,9 @@ class WindowsUtils(base.BaseOSUtils): # User not found pass + @retry_decorator.retry_decorator( + max_retry_count=3, + exceptions=exception.LoadUserProfileCloudbaseInitException) def create_user_logon_session(self, username, password, domain='.', load_profile=True, logon_type=LOGON32_LOGON_INTERACTIVE): @@ -666,7 +669,7 @@ class WindowsUtils(base.BaseOSUtils): ret_val = userenv.LoadUserProfileW(token, ctypes.byref(pi)) if not ret_val: kernel32.CloseHandle(token) - raise exception.WindowsCloudbaseInitException( + raise exception.LoadUserProfileCloudbaseInitException( "Cannot load user profile: %r") return token diff --git a/cloudbaseinit/tests/osutils/test_windows.py b/cloudbaseinit/tests/osutils/test_windows.py index ce3a0cf7..a9eec533 100644 --- a/cloudbaseinit/tests/osutils/test_windows.py +++ b/cloudbaseinit/tests/osutils/test_windows.py @@ -452,7 +452,9 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): self._test_create_user(fail=True) @mock.patch('cloudbaseinit.osutils.windows.Win32_PROFILEINFO') - def _test_create_user_logon_session(self, mock_Win32_PROFILEINFO, logon, + @mock.patch('time.sleep') + def _test_create_user_logon_session(self, mock_time_sleep, + mock_Win32_PROFILEINFO, logon, loaduser, load_profile=True, last_error=None): self._wintypes_mock.HANDLE = mock.MagicMock() @@ -474,7 +476,9 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): userenv.LoadUserProfileW.return_value = None kernel32.CloseHandle.return_value = None with self.assert_raises_windows_message( - "Cannot load user profile: %r", last_error): + "Cannot load user profile: %r", last_error, + get_last_error_called_times=4, + format_error_called_times=4): self._winutils.create_user_logon_session( self._USERNAME, self._PASSWORD, domain='.', load_profile=load_profile) diff --git a/cloudbaseinit/tests/testutils.py b/cloudbaseinit/tests/testutils.py index a3e3303b..fdea01c3 100644 --- a/cloudbaseinit/tests/testutils.py +++ b/cloudbaseinit/tests/testutils.py @@ -159,7 +159,9 @@ class CloudbaseInitTestBase(unittest.TestCase): @contextlib.contextmanager def assert_raises_windows_message( self, expected_msg, error_code, - exc=exception.WindowsCloudbaseInitException): + exc=exception.WindowsCloudbaseInitException, + get_last_error_called_times=1, + format_error_called_times=1): """Helper method for testing raised error messages This assert method is similar to :meth:`~assertRaises`, but @@ -188,11 +190,16 @@ class CloudbaseInitTestBase(unittest.TestCase): # This can be called when the error code is not given, # but we don't have control over that, so test that # it's actually called only once. - mock_get_last_error.assert_called_once_with() - mock_format_error.assert_called_once_with( + mock_get_last_error.assert_called() + self.assertEqual(mock_get_last_error.call_count, + get_last_error_called_times) + mock_format_error.assert_called_with( mock_get_last_error.return_value) else: - mock_format_error.assert_called_once_with(error_code) + mock_format_error.assert_called_with(error_code) + + self.assertEqual(mock_format_error.call_count, + format_error_called_times) expected_msg = expected_msg % mock_format_error.return_value self.assertEqual(expected_msg, cm.exception.args[0])