diff --git a/solar/core/resource/repository.py b/solar/core/resource/repository.py index 8cac3c29..d1475d38 100644 --- a/solar/core/resource/repository.py +++ b/solar/core/resource/repository.py @@ -14,15 +14,15 @@ # under the License. from collections import defaultdict +import tempfile +from enum import Enum import errno import os import semantic_version import shutil import yaml - -from enum import Enum from solar import utils @@ -63,6 +63,7 @@ class Repository(object): db_obj = None _REPOS_LOCATION = '/var/lib/solar/repositories' + _TMP_DIRNAME = '.tmp' def __init__(self, name): self.name = name @@ -146,15 +147,23 @@ class Repository(object): os.mkdir(self.fpath) return if not link_only: + if os.path.isdir(self.fpath): + raise RepositoryExists("Repository %s " + "already exists" % self.name) + + if not os.path.isdir(self.tmp_dir): + os.makedirs(self.tmp_dir) + + old_fpath = self.fpath + self.fpath = tempfile.mkdtemp(dir=self.tmp_dir) try: - os.mkdir(self.fpath) - except OSError as e: - if e.errno == errno.EEXIST: - raise RepositoryExists("Repository %s " - "already exists" % self.name) - else: - raise - self._add_contents(source) + self._add_contents(source) + os.rename(self.fpath, old_fpath) + except Exception as e: + shutil.rmtree(self.fpath) + raise + finally: + self.fpath = old_fpath else: try: os.symlink(source, self.fpath) @@ -244,7 +253,7 @@ class Repository(object): return if resource_name is None: - for single in os.listdir(self.fpath): + for single in self.list_repos(): for gen in _single(single): yield gen else: @@ -335,6 +344,10 @@ class Repository(object): spec = self._parse_spec(spec) return self._make_version_path(spec) + @property + def tmp_dir(self): + return os.path.join(self._REPOS_LOCATION, self._TMP_DIRNAME) + @classmethod def get_metadata(cls, spec): spec = cls._parse_spec(spec) @@ -365,10 +378,11 @@ class Repository(object): @classmethod def list_repos(cls): - return filter(lambda x: - os.path.isdir(os.path.join(cls._REPOS_LOCATION, - x)), - os.listdir(cls._REPOS_LOCATION)) + return filter( + lambda x: (os.path.isdir(os.path.join(cls._REPOS_LOCATION, x)) + and x != cls._TMP_DIRNAME), + os.listdir(cls._REPOS_LOCATION) + ) @classmethod def parse(cls, spec): diff --git a/solar/test/conftest.py b/solar/test/conftest.py index 7b8394da..cb8ea358 100644 --- a/solar/test/conftest.py +++ b/solar/test/conftest.py @@ -42,9 +42,8 @@ def resources(): @pytest.fixture(scope='session', autouse=True) def repos_path(tmpdir_factory): Repository._REPOS_LOCATION = str(tmpdir_factory.mktemp('repositories')) - path = Repository._REPOS_LOCATION repo = Repository('resources') - repo.create(path) + repo.create() def plan_from_fixture(name): diff --git a/solar/test/test_resource_repository.py b/solar/test/test_resource_repository.py index bf3e103c..39b14e15 100644 --- a/solar/test/test_resource_repository.py +++ b/solar/test/test_resource_repository.py @@ -14,9 +14,12 @@ # under the License. import itertools +import mock import os -import pytest import shutil + +import pytest + from solar.core.resource.repository import Repository from solar.core.resource.repository import RES_TYPE @@ -243,3 +246,36 @@ def test_create_empty(): repo = Repository('empty') repo.create() assert 'empty' in Repository.list_repos() + + +@mock.patch('solar.core.resource.repository.Repository._add_contents') +@mock.patch('tempfile.mkdtemp') +@mock.patch('shutil.rmtree') +def test_create_from_src_failed(mock_rmtree, mock_mkdtemp, mock_add_contents): + tmp_dir = '/tmp/dir' + mock_mkdtemp.return_value = tmp_dir + + mock_add_contents.side_effect = Exception() + repo = Repository('fail_create') + real_path = repo.fpath + with pytest.raises(Exception): + repo.create(source='source') + + mock_rmtree.assert_called_with(tmp_dir) + assert repo.fpath == real_path + + +@mock.patch('solar.core.resource.repository.Repository._add_contents') +@mock.patch('os.rename') +@mock.patch('tempfile.mkdtemp') +def test_create_from_src(mock_mkdtemp, mock_rename, _): + tmp_dir = '/tmp/dir' + mock_mkdtemp.return_value = tmp_dir + + repo = Repository('create_from_src') + real_path = repo.fpath + + repo.create(source='source') + + mock_rename.assert_called_with(tmp_dir, real_path) + assert repo.fpath == real_path