Refactor the selinux guard to aid in mocking

1. Adjust the test_util after this mocking to be cleaner
This commit is contained in:
harlowja 2012-06-29 20:37:46 -07:00
parent 62e027b172
commit a353240c09
2 changed files with 44 additions and 30 deletions

View File

@ -46,19 +46,13 @@ import urlparse
import yaml
from cloudinit import importer
from cloudinit import log as logging
from cloudinit import url_helper as uhelp
from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG)
try:
import selinux
HAVE_LIBSELINUX = True
except ImportError:
HAVE_LIBSELINUX = False
LOG = logging.getLogger(__name__)
# Helps cleanup filenames to ensure they aren't FS incompatible
@ -126,31 +120,37 @@ class ProcessExecutionError(IOError):
class SeLinuxGuard(object):
def __init__(self, path, recursive=False):
# Late import since it might not always
# be possible to use this
try:
self.selinux = importer.import_module('selinux')
except ImportError:
self.selinux = None
self.path = path
self.recursive = recursive
self.enabled = False
if HAVE_LIBSELINUX and selinux.is_selinux_enabled():
self.enabled = True
def __enter__(self):
return self.enabled
if self.selinux:
return True
else:
return False
def __exit__(self, excp_type, excp_value, excp_traceback):
if self.enabled:
if self.selinux:
path = os.path.realpath(os.path.expanduser(self.path))
do_restore = False
try:
# See if even worth restoring??
stats = os.lstat(path)
if stat.ST_MODE in stats:
selinux.matchpathcon(path, stats[stat.ST_MODE])
self.selinux.matchpathcon(path, stats[stat.ST_MODE])
do_restore = True
except OSError:
pass
if do_restore:
LOG.debug("Restoring selinux mode for %s (recursive=%s)",
path, self.recursive)
selinux.restorecon(path, recursive=self.recursive)
self.selinux.restorecon(path, recursive=self.recursive)
class MountFailedError(Exception):

View File

@ -5,6 +5,26 @@ from unittest import TestCase
from mocker import MockerTestCase
from cloudinit import util
from cloudinit import importer
class FakeSelinux(object):
def __init__(self, match_what):
self.match_what = match_what
self.restored = []
def matchpathcon(self, path, mode):
if path == self.match_what:
return
else:
raise OSError("No match!")
def is_selinux_enabled(self):
return True
def restorecon(self, path, recursive):
self.restored.append(path)
class TestMergeDict(MockerTestCase):
@ -159,22 +179,16 @@ class TestWriteFile(MockerTestCase):
def test_restorecon_if_possible_is_called(self):
"""Make sure the selinux guard is called correctly."""
try:
# We can only mock these out if selinux actually
# exists, so thats why we catch the import
mock_restorecon = self.mocker.replace(
"selinux.restorecon", passthrough=False)
mock_is_selinux_enabled = self.mocker.replace(
"selinux.is_selinux_enabled", passthrough=False)
mock_is_selinux_enabled()
self.mocker.result(True)
mock_restorecon("/etc/hosts", recursive=False)
self.mocker.result(True)
self.mocker.replay()
with util.SeLinuxGuard("/etc/hosts") as is_on:
self.assertTrue(is_on)
except ImportError:
pass
import_mock = self.mocker.replace(importer.import_module,
passthrough=False)
import_mock('selinux')
fake_se = FakeSelinux('/etc/hosts')
self.mocker.result(fake_se)
self.mocker.replay()
with util.SeLinuxGuard("/etc/hosts") as is_on:
self.assertTrue(is_on)
self.assertEqual(1, len(fake_se.restored))
self.assertEqual('/etc/hosts', fake_se.restored[0])
class TestDeleteDirContents(MockerTestCase):