Refactor the selinux guard to aid in mocking

1. Adjust the test_util after this mocking to be cleaner
This commit is contained in:
harlowja 2012-06-29 20:37:46 -07:00
parent 62e027b172
commit a353240c09
2 changed files with 44 additions and 30 deletions

View File

@ -46,19 +46,13 @@ import urlparse
import yaml import yaml
from cloudinit import importer
from cloudinit import log as logging from cloudinit import log as logging
from cloudinit import url_helper as uhelp from cloudinit import url_helper as uhelp
from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG) from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG)
try:
import selinux
HAVE_LIBSELINUX = True
except ImportError:
HAVE_LIBSELINUX = False
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
# Helps cleanup filenames to ensure they aren't FS incompatible # Helps cleanup filenames to ensure they aren't FS incompatible
@ -126,31 +120,37 @@ class ProcessExecutionError(IOError):
class SeLinuxGuard(object): class SeLinuxGuard(object):
def __init__(self, path, recursive=False): def __init__(self, path, recursive=False):
# Late import since it might not always
# be possible to use this
try:
self.selinux = importer.import_module('selinux')
except ImportError:
self.selinux = None
self.path = path self.path = path
self.recursive = recursive self.recursive = recursive
self.enabled = False
if HAVE_LIBSELINUX and selinux.is_selinux_enabled():
self.enabled = True
def __enter__(self): def __enter__(self):
return self.enabled if self.selinux:
return True
else:
return False
def __exit__(self, excp_type, excp_value, excp_traceback): def __exit__(self, excp_type, excp_value, excp_traceback):
if self.enabled: if self.selinux:
path = os.path.realpath(os.path.expanduser(self.path)) path = os.path.realpath(os.path.expanduser(self.path))
do_restore = False do_restore = False
try: try:
# See if even worth restoring?? # See if even worth restoring??
stats = os.lstat(path) stats = os.lstat(path)
if stat.ST_MODE in stats: if stat.ST_MODE in stats:
selinux.matchpathcon(path, stats[stat.ST_MODE]) self.selinux.matchpathcon(path, stats[stat.ST_MODE])
do_restore = True do_restore = True
except OSError: except OSError:
pass pass
if do_restore: if do_restore:
LOG.debug("Restoring selinux mode for %s (recursive=%s)", LOG.debug("Restoring selinux mode for %s (recursive=%s)",
path, self.recursive) path, self.recursive)
selinux.restorecon(path, recursive=self.recursive) self.selinux.restorecon(path, recursive=self.recursive)
class MountFailedError(Exception): class MountFailedError(Exception):

View File

@ -5,6 +5,26 @@ from unittest import TestCase
from mocker import MockerTestCase from mocker import MockerTestCase
from cloudinit import util from cloudinit import util
from cloudinit import importer
class FakeSelinux(object):
def __init__(self, match_what):
self.match_what = match_what
self.restored = []
def matchpathcon(self, path, mode):
if path == self.match_what:
return
else:
raise OSError("No match!")
def is_selinux_enabled(self):
return True
def restorecon(self, path, recursive):
self.restored.append(path)
class TestMergeDict(MockerTestCase): class TestMergeDict(MockerTestCase):
@ -159,22 +179,16 @@ class TestWriteFile(MockerTestCase):
def test_restorecon_if_possible_is_called(self): def test_restorecon_if_possible_is_called(self):
"""Make sure the selinux guard is called correctly.""" """Make sure the selinux guard is called correctly."""
try: import_mock = self.mocker.replace(importer.import_module,
# We can only mock these out if selinux actually passthrough=False)
# exists, so thats why we catch the import import_mock('selinux')
mock_restorecon = self.mocker.replace( fake_se = FakeSelinux('/etc/hosts')
"selinux.restorecon", passthrough=False) self.mocker.result(fake_se)
mock_is_selinux_enabled = self.mocker.replace( self.mocker.replay()
"selinux.is_selinux_enabled", passthrough=False) with util.SeLinuxGuard("/etc/hosts") as is_on:
mock_is_selinux_enabled() self.assertTrue(is_on)
self.mocker.result(True) self.assertEqual(1, len(fake_se.restored))
mock_restorecon("/etc/hosts", recursive=False) self.assertEqual('/etc/hosts', fake_se.restored[0])
self.mocker.result(True)
self.mocker.replay()
with util.SeLinuxGuard("/etc/hosts") as is_on:
self.assertTrue(is_on)
except ImportError:
pass
class TestDeleteDirContents(MockerTestCase): class TestDeleteDirContents(MockerTestCase):