Refactor the selinux guard to aid in mocking
1. Adjust the test_util after this mocking to be cleaner
This commit is contained in:
parent
62e027b172
commit
a353240c09
@ -46,19 +46,13 @@ import urlparse
|
||||
|
||||
import yaml
|
||||
|
||||
from cloudinit import importer
|
||||
from cloudinit import log as logging
|
||||
from cloudinit import url_helper as uhelp
|
||||
|
||||
from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG)
|
||||
|
||||
|
||||
try:
|
||||
import selinux
|
||||
HAVE_LIBSELINUX = True
|
||||
except ImportError:
|
||||
HAVE_LIBSELINUX = False
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# Helps cleanup filenames to ensure they aren't FS incompatible
|
||||
@ -126,31 +120,37 @@ class ProcessExecutionError(IOError):
|
||||
|
||||
class SeLinuxGuard(object):
|
||||
def __init__(self, path, recursive=False):
|
||||
# Late import since it might not always
|
||||
# be possible to use this
|
||||
try:
|
||||
self.selinux = importer.import_module('selinux')
|
||||
except ImportError:
|
||||
self.selinux = None
|
||||
self.path = path
|
||||
self.recursive = recursive
|
||||
self.enabled = False
|
||||
if HAVE_LIBSELINUX and selinux.is_selinux_enabled():
|
||||
self.enabled = True
|
||||
|
||||
def __enter__(self):
|
||||
return self.enabled
|
||||
if self.selinux:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def __exit__(self, excp_type, excp_value, excp_traceback):
|
||||
if self.enabled:
|
||||
if self.selinux:
|
||||
path = os.path.realpath(os.path.expanduser(self.path))
|
||||
do_restore = False
|
||||
try:
|
||||
# See if even worth restoring??
|
||||
stats = os.lstat(path)
|
||||
if stat.ST_MODE in stats:
|
||||
selinux.matchpathcon(path, stats[stat.ST_MODE])
|
||||
self.selinux.matchpathcon(path, stats[stat.ST_MODE])
|
||||
do_restore = True
|
||||
except OSError:
|
||||
pass
|
||||
if do_restore:
|
||||
LOG.debug("Restoring selinux mode for %s (recursive=%s)",
|
||||
path, self.recursive)
|
||||
selinux.restorecon(path, recursive=self.recursive)
|
||||
self.selinux.restorecon(path, recursive=self.recursive)
|
||||
|
||||
|
||||
class MountFailedError(Exception):
|
||||
|
@ -5,6 +5,26 @@ from unittest import TestCase
|
||||
from mocker import MockerTestCase
|
||||
|
||||
from cloudinit import util
|
||||
from cloudinit import importer
|
||||
|
||||
|
||||
class FakeSelinux(object):
|
||||
|
||||
def __init__(self, match_what):
|
||||
self.match_what = match_what
|
||||
self.restored = []
|
||||
|
||||
def matchpathcon(self, path, mode):
|
||||
if path == self.match_what:
|
||||
return
|
||||
else:
|
||||
raise OSError("No match!")
|
||||
|
||||
def is_selinux_enabled(self):
|
||||
return True
|
||||
|
||||
def restorecon(self, path, recursive):
|
||||
self.restored.append(path)
|
||||
|
||||
|
||||
class TestMergeDict(MockerTestCase):
|
||||
@ -159,22 +179,16 @@ class TestWriteFile(MockerTestCase):
|
||||
|
||||
def test_restorecon_if_possible_is_called(self):
|
||||
"""Make sure the selinux guard is called correctly."""
|
||||
try:
|
||||
# We can only mock these out if selinux actually
|
||||
# exists, so thats why we catch the import
|
||||
mock_restorecon = self.mocker.replace(
|
||||
"selinux.restorecon", passthrough=False)
|
||||
mock_is_selinux_enabled = self.mocker.replace(
|
||||
"selinux.is_selinux_enabled", passthrough=False)
|
||||
mock_is_selinux_enabled()
|
||||
self.mocker.result(True)
|
||||
mock_restorecon("/etc/hosts", recursive=False)
|
||||
self.mocker.result(True)
|
||||
self.mocker.replay()
|
||||
with util.SeLinuxGuard("/etc/hosts") as is_on:
|
||||
self.assertTrue(is_on)
|
||||
except ImportError:
|
||||
pass
|
||||
import_mock = self.mocker.replace(importer.import_module,
|
||||
passthrough=False)
|
||||
import_mock('selinux')
|
||||
fake_se = FakeSelinux('/etc/hosts')
|
||||
self.mocker.result(fake_se)
|
||||
self.mocker.replay()
|
||||
with util.SeLinuxGuard("/etc/hosts") as is_on:
|
||||
self.assertTrue(is_on)
|
||||
self.assertEqual(1, len(fake_se.restored))
|
||||
self.assertEqual('/etc/hosts', fake_se.restored[0])
|
||||
|
||||
|
||||
class TestDeleteDirContents(MockerTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user