Issue 34; preview_sql now correctly displays SQL on python and SQL scripts. (tests added, docs still missing)

This commit is contained in:
iElectric 2009-06-06 10:34:22 +00:00
parent 4356e8b582
commit 938bbf9bf3
10 changed files with 225 additions and 132 deletions

View File

@ -1,7 +1,9 @@
0.5.4 0.5.4
- fixed preview_sql parameter for downgrade/upgrade. Now it prints SQL if the step is SQL script
and runs step with mocked engine to only print SQL statements if ORM is used. [Domen Kozar]
- use entrypoints terminology to specify dotted model names (module.model.User) [Domen Kozar] - use entrypoints terminology to specify dotted model names (module.model.User) [Domen Kozar]
- added engine_dict and engine_arg_* parameters to all api functions [Domen Kozar] - added engine_dict and engine_arg_* parameters to all api functions (deprecated echo) [Domen Kozar]
- make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar] - make --echo parameter a bit more forgivable (better Python API support) [Domen Kozar]
- apply patch to refactor cmd line parsing for Issue 54 by Domen Kozar - apply patch to refactor cmd line parsing for Issue 54 by Domen Kozar
0.5.3 0.5.3

View File

@ -16,11 +16,10 @@ import sys
import inspect import inspect
import warnings import warnings
from sqlalchemy import create_engine
from migrate.versioning import (exceptions, repository, schema, version, from migrate.versioning import (exceptions, repository, schema, version,
script as script_) # command name conflict script as script_) # command name conflict
from migrate.versioning.util import asbool, catch_known_errors, guess_obj_type from migrate.versioning.util import catch_known_errors, construct_engine
__all__ = [ __all__ = [
'help', 'help',
@ -46,6 +45,7 @@ Repository = repository.Repository
ControlledSchema = schema.ControlledSchema ControlledSchema = schema.ControlledSchema
VerNum = version.VerNum VerNum = version.VerNum
PythonScript = script_.PythonScript PythonScript = script_.PythonScript
SqlScript = script_.SqlScript
# deprecated # deprecated
@ -117,7 +117,7 @@ def test(repository, url=None, **opts):
bad state. You should therefore better run the test on a copy of bad state. You should therefore better run the test on a copy of
your database. your database.
""" """
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
repos = Repository(repository) repos = Repository(repository)
script = repos.version(None).script() script = repos.version(None).script()
@ -179,7 +179,7 @@ def version_control(url, repository, version=None, **opts):
identical to what it would be if the database were created from identical to what it would be if the database were created from
scratch. scratch.
""" """
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
ControlledSchema.create(engine, repository, version) ControlledSchema.create(engine, repository, version)
@ -192,7 +192,7 @@ def db_version(url, repository, **opts):
The url should be any valid SQLAlchemy connection string. The url should be any valid SQLAlchemy connection string.
""" """
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository) schema = ControlledSchema(engine, repository)
return schema.version return schema.version
@ -236,7 +236,7 @@ def drop_version_control(url, repository, **opts):
Removes version control from a database. Removes version control from a database.
""" """
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository) schema = ControlledSchema(engine, repository)
schema.drop() schema.drop()
@ -268,7 +268,7 @@ def compare_model_to_db(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
print ControlledSchema.compare_model_to_db(engine, model, repository) print ControlledSchema.compare_model_to_db(engine, model, repository)
@ -279,7 +279,7 @@ def create_model(url, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
declarative = opts.get('declarative', False) declarative = opts.get('declarative', False)
print ControlledSchema.create_model(engine, repository, declarative) print ControlledSchema.create_model(engine, repository, declarative)
@ -294,7 +294,7 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
print PythonScript.make_update_script_for_model( print PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts) engine, oldmodel, model, repository, **opts)
@ -308,30 +308,37 @@ def update_db_from_model(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository) schema = ControlledSchema(engine, repository)
schema.update_db_from_model(model) schema.update_db_from_model(model)
def _migrate(url, repository, version, upgrade, err, **opts): def _migrate(url, repository, version, upgrade, err, **opts):
engine = _construct_engine(url, **opts) engine = construct_engine(url, **opts)
schema = ControlledSchema(engine, repository) schema = ControlledSchema(engine, repository)
version = _migrate_version(schema, version, upgrade, err) version = _migrate_version(schema, version, upgrade, err)
changeset = schema.changeset(version) changeset = schema.changeset(version)
for ver, change in changeset: for ver, change in changeset:
nextver = ver + changeset.step nextver = ver + changeset.step
print '%s -> %s... ' % (ver, nextver), print '%s -> %s... ' % (ver, nextver)
if opts.get('preview_sql'): if opts.get('preview_sql'):
print if isinstance(change, PythonScript):
print change.log print change.preview_sql(url, changeset.step, **opts)
elif isinstance(change, SqlScript):
print change.source()
elif opts.get('preview_py'): elif opts.get('preview_py'):
source_ver = max(ver, nextver) source_ver = max(ver, nextver)
module = schema.repository.version(source_ver).script().module module = schema.repository.version(source_ver).script().module
funcname = upgrade and "upgrade" or "downgrade" funcname = upgrade and "upgrade" or "downgrade"
func = getattr(module, funcname) func = getattr(module, funcname)
print if isinstance(change, PythonScript):
print inspect.getsource(module.upgrade) print inspect.getsource(func)
else:
raise UsageError("Python source can be only displayed"
" for python migration files")
else: else:
schema.runchange(ver, change, changeset.step) schema.runchange(ver, change, changeset.step)
print 'done' print 'done'
@ -352,39 +359,3 @@ def _migrate_version(schema, version, upgrade, err):
if not direction: if not direction:
raise exceptions.KnownError(err % (cur, version)) raise exceptions.KnownError(err % (cur, version))
return version return version
def _construct_engine(url, **opts):
"""Constructs and returns SQLAlchemy engine.
Currently, there are 2 ways to pass create_engine options to api functions:
* keyword parameters (starting with `engine_arg_*`)
* python dictionary of options (`engine_dict`)
NOTE: keyword parameters override `engine_dict` values.
.. versionadded:: 0.5.4
"""
# TODO: include docs
# get options for create_engine
if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
kwargs = opts['engine_dict']
else:
kwargs = dict()
# DEPRECATED: handle echo the old way
echo = asbool(opts.get('echo', False))
if echo:
warnings.warn('echo=True parameter is deprecated, pass '
'engine_arg_echo=True or engine_dict={"echo": True}',
DeprecationWarning)
kwargs['echo'] = echo
# parse keyword arguments
for key, value in opts.iteritems():
if key.startswith('engine_arg_'):
kwargs[key[11:]] = guess_obj_type(value)
return create_engine(url, **kwargs)

View File

@ -1,6 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from migrate.versioning.base import log,operations from migrate.versioning.base import log,operations
from migrate.versioning import pathed,exceptions from migrate.versioning import pathed,exceptions
# import migrate.run
class BaseScript(pathed.Pathed): class BaseScript(pathed.Pathed):
"""Base class for other types of scripts """Base class for other types of scripts
@ -17,10 +20,10 @@ class BaseScript(pathed.Pathed):
""" """
def __init__(self,path): def __init__(self,path):
log.info('Loading script %s...'%path) log.info('Loading script %s...' % path)
self.verify(path) self.verify(path)
super(BaseScript,self).__init__(path) super(BaseScript, self).__init__(path)
log.info('Script %s loaded successfully'%path) log.info('Script %s loaded successfully' % path)
@classmethod @classmethod
def verify(cls,path): def verify(cls,path):
@ -33,10 +36,10 @@ class BaseScript(pathed.Pathed):
raise exceptions.InvalidScriptError(path) raise exceptions.InvalidScriptError(path)
def source(self): def source(self):
fd=open(self.path) fd = open(self.path)
ret=fd.read() ret = fd.read()
fd.close() fd.close()
return ret return ret
def run(self,engine): def run(self, engine):
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,15 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import shutil import shutil
#import migrate.run from StringIO import StringIO
import migrate
from migrate.versioning import exceptions, genmodel, schemadiff from migrate.versioning import exceptions, genmodel, schemadiff
from migrate.versioning.base import operations from migrate.versioning.base import operations
from migrate.versioning.template import template from migrate.versioning.template import template
from migrate.versioning.script import base from migrate.versioning.script import base
from migrate.versioning.util import import_path, loadModel from migrate.versioning.util import import_path, loadModel, construct_engine
import migrate
class PythonScript(base.BaseScript): class PythonScript(base.BaseScript):
@classmethod @classmethod
def create(cls,path,**opts): def create(cls, path, **opts):
"""Create an empty migration script""" """Create an empty migration script"""
cls.require_notfound(path) cls.require_notfound(path)
@ -18,29 +23,38 @@ class PythonScript(base.BaseScript):
# different one later. # different one later.
template_file = None template_file = None
src = template.get_script(template_file) src = template.get_script(template_file)
shutil.copy(src,path) shutil.copy(src, path)
@classmethod @classmethod
def make_update_script_for_model(cls,engine,oldmodel,model,repository,**opts): def make_update_script_for_model(cls, engine, oldmodel,
model, repository, **opts):
"""Create a migration script""" """Create a migration script"""
# Compute differences. # Compute differences.
if isinstance(repository, basestring): if isinstance(repository, basestring):
from migrate.versioning.repository import Repository # oh dear, an import cycle! # oh dear, an import cycle!
repository=Repository(repository) from migrate.versioning.repository import Repository
repository = Repository(repository)
oldmodel = loadModel(oldmodel) oldmodel = loadModel(oldmodel)
model = loadModel(model) model = loadModel(model)
diff = schemadiff.getDiffOfModelAgainstModel(oldmodel, model, engine, excludeTables=[repository.version_table]) diff = schemadiff.getDiffOfModelAgainstModel(
decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff).toUpgradeDowngradePython() oldmodel,
model,
engine,
excludeTables=[repository.version_table])
decls, upgradeCommands, downgradeCommands = \
genmodel.ModelGenerator(diff).toUpgradeDowngradePython()
# Store differences into file. # Store differences into file.
template_file = None template_file = None
src = template.get_script(template_file) src = template.get_script(template_file)
contents = open(src).read() contents = open(src).read()
search = 'def upgrade():' search = 'def upgrade():'
contents = contents.replace(search, decls + '\n\n' + search, 1) contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
if upgradeCommands: contents = contents.replace(' pass', upgradeCommands, 1) if upgradeCommands:
if downgradeCommands: contents = contents.replace(' pass', downgradeCommands, 1) contents = contents.replace(' pass', upgradeCommands, 1)
if downgradeCommands:
contents = contents.replace(' pass', downgradeCommands, 1)
return contents return contents
@classmethod @classmethod
@ -54,31 +68,31 @@ class PythonScript(base.BaseScript):
raise raise
try: try:
assert callable(module.upgrade) assert callable(module.upgrade)
except Exception,e: except Exception, e:
raise exceptions.InvalidScriptError(path+': %s'%str(e)) raise exceptions.InvalidScriptError(path + ': %s' % str(e))
return module return module
def _get_module(self): def preview_sql(self, url, step, **args):
if not hasattr(self,'_module'): """Mock engine to store all executable calls in a string \
self._module = self.verify_module(self.path) and execute the step"""
return self._module buf = StringIO()
module = property(_get_module) args['engine_arg_strategy'] = 'mock'
args['engine_arg_executor'] = lambda s, p='': buf.write(s + p)
engine = construct_engine(url, **args)
self.run(engine, step)
def _func(self,funcname): return buf.getvalue()
fn = getattr(self.module, funcname, None)
if not fn:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg%funcname)
return fn
def run(self,engine,step): def run(self, engine, step):
"""Core method of Script file. \
Exectues update() or downgrade() function"""
if step > 0: if step > 0:
op = 'upgrade' op = 'upgrade'
elif step < 0: elif step < 0:
op = 'downgrade' op = 'downgrade'
else: else:
raise exceptions.ScriptError("%d is not a valid step"%step) raise exceptions.ScriptError("%d is not a valid step" % step)
funcname = base.operations[op] funcname = base.operations[op]
migrate.migrate_engine = engine migrate.migrate_engine = engine
@ -87,3 +101,17 @@ class PythonScript(base.BaseScript):
func() func()
migrate.migrate_engine = None migrate.migrate_engine = None
#migrate.run.migrate_engine = migrate.migrate_engine = None #migrate.run.migrate_engine = migrate.migrate_engine = None
def _get_module(self):
if not hasattr(self,'_module'):
self._module = self.verify_module(self.path)
return self._module
module = property(_get_module)
def _func(self, funcname):
fn = getattr(self.module, funcname, None)
if not fn:
msg = "The function %s is not defined in this script"
raise exceptions.ScriptError(msg%funcname)
return fn

View File

@ -5,6 +5,8 @@ import warnings
from decorator import decorator from decorator import decorator
from pkg_resources import EntryPoint from pkg_resources import EntryPoint
from sqlalchemy import create_engine
from migrate.versioning import exceptions from migrate.versioning import exceptions
from migrate.versioning.util.keyedinstance import KeyedInstance from migrate.versioning.util.keyedinstance import KeyedInstance
from migrate.versioning.util.importpath import import_path from migrate.versioning.util.importpath import import_path
@ -33,7 +35,10 @@ def asbool(obj):
return False return False
else: else:
raise ValueError("String is not true/false: %r" % obj) raise ValueError("String is not true/false: %r" % obj)
return bool(obj) if obj in (True, False):
return bool(obj)
else:
raise ValueError("String is not true/false: %r" % obj)
def guess_obj_type(obj): def guess_obj_type(obj):
"""Do everything to guess object type from string""" """Do everything to guess object type from string"""
@ -63,3 +68,38 @@ def catch_known_errors(f, *a, **kw):
f(*a, **kw) f(*a, **kw)
except exceptions.PathFoundError, e: except exceptions.PathFoundError, e:
raise exceptions.KnownError("The path %s already exists" % e.args[0]) raise exceptions.KnownError("The path %s already exists" % e.args[0])
def construct_engine(url, **opts):
"""Constructs and returns SQLAlchemy engine.
Currently, there are 2 ways to pass create_engine options to api functions:
* keyword parameters (starting with `engine_arg_*`)
* python dictionary of options (`engine_dict`)
NOTE: keyword parameters override `engine_dict` values.
.. versionadded:: 0.5.4
"""
# TODO: include docs
# get options for create_engine
if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
kwargs = opts['engine_dict']
else:
kwargs = dict()
# DEPRECATED: handle echo the old way
echo = asbool(opts.get('echo', False))
if echo:
warnings.warn('echo=True parameter is deprecated, pass '
'engine_arg_echo=True or engine_dict={"echo": True}',
DeprecationWarning)
kwargs['echo'] = echo
# parse keyword arguments
for key, value in opts.iteritems():
if key.startswith('engine_arg_'):
kwargs[key[11:]] = guess_obj_type(value)
return create_engine(url, **kwargs)

View File

@ -21,7 +21,7 @@ class Base(unittest.TestCase):
def createLines(s): def createLines(s):
s = s.replace(' ', '') s = s.replace(' ', '')
lines = s.split('\n') lines = s.split('\n')
return [line for line in lines if line] return filter(None, lines)
lines1 = createLines(v1) lines1 = createLines(v1)
lines2 = createLines(v2) lines2 = createLines(v2)
self.assertEquals(len(lines1), len(lines2)) self.assertEquals(len(lines1), len(lines2))

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
from decorator import decorator
from sqlalchemy import create_engine, Table, MetaData from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.orm import create_session from sqlalchemy.orm import create_session
@ -74,6 +75,7 @@ def usedb(supported=None, not_supported=None):
yield func, self yield func, self
self._teardown() self._teardown()
entangle.__name__ = func.__name__ entangle.__name__ = func.__name__
entangle.__doc__ = func.__doc__
return entangle return entangle
return dec return dec

View File

@ -1,49 +1,53 @@
import os,shutil,tempfile #!/usr/bin/env python
import base # -*- coding: utf-8 -*-
import os
import shutil
import tempfile
from test.fixture import base
class Pathed(base.Base): class Pathed(base.Base):
# Temporary files # Temporary files
#repos='/tmp/test_repos_091x10'
#config=repos+'/migrate.cfg'
#script='/tmp/test_migration_script.py'
_tmpdir=tempfile.mkdtemp() _tmpdir = tempfile.mkdtemp()
@classmethod @classmethod
def _tmp(cls,prefix='',suffix=''): def _tmp(cls, prefix='', suffix=''):
"""Generate a temporary file name that doesn't exist """Generate a temporary file name that doesn't exist
All filenames are generated inside a temporary directory created by All filenames are generated inside a temporary directory created by
tempfile.mkdtemp(); only the creating user has access to this directory. tempfile.mkdtemp(); only the creating user has access to this directory.
It should be secure to return a nonexistant temp filename in this It should be secure to return a nonexistant temp filename in this
directory, unless the user is messing with their own files. directory, unless the user is messing with their own files.
""" """
file,ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir) file, ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir)
os.close(file) os.close(file)
os.remove(ret) os.remove(ret)
return ret return ret
@classmethod @classmethod
def tmp(cls,*p,**k): def tmp(cls, *p, **k):
return cls._tmp(*p,**k) return cls._tmp(*p, **k)
@classmethod @classmethod
def tmp_py(cls,*p,**k): def tmp_py(cls, *p, **k):
return cls._tmp(suffix='.py',*p,**k) return cls._tmp(suffix='.py', *p, **k)
@classmethod @classmethod
def tmp_sql(cls,*p,**k): def tmp_sql(cls, *p, **k):
return cls._tmp(suffix='.sql',*p,**k) return cls._tmp(suffix='.sql', *p, **k)
@classmethod @classmethod
def tmp_named(cls,name): def tmp_named(cls, name):
return os.path.join(cls._tmpdir,name) return os.path.join(cls._tmpdir, name)
@classmethod @classmethod
def tmp_repos(cls,*p,**k): def tmp_repos(cls, *p, **k):
return cls._tmp(*p,**k) return cls._tmp(*p, **k)
@classmethod @classmethod
def purge(cls,path): def purge(cls, path):
"""Removes this path if it exists, in preparation for tests """Removes this path if it exists, in preparation for tests
Careful - all tests should take place in /tmp. Careful - all tests should take place in /tmp.
We don't want to accidentally wipe stuff out... We don't want to accidentally wipe stuff out...
@ -54,6 +58,6 @@ class Pathed(base.Base):
else: else:
os.remove(path) os.remove(path)
if path.endswith('.py'): if path.endswith('.py'):
pyc = path+'c' pyc = path + 'c'
if os.path.exists(pyc): if os.path.exists(pyc):
os.remove(pyc) os.remove(pyc)

View File

@ -1,13 +1,19 @@
from test import fixture #!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
from migrate.versioning.script import * from migrate.versioning.script import *
from migrate.versioning import exceptions, version from migrate.versioning import exceptions, version
import os,shutil from test import fixture
class TestPyScript(fixture.Pathed): class TestPyScript(fixture.Pathed):
cls = PythonScript cls = PythonScript
def test_create(self): def test_create(self):
"""We can create a migration script""" """We can create a migration script"""
path=self.tmp_py() path = self.tmp_py()
# Creating a file that doesn't exist should succeed # Creating a file that doesn't exist should succeed
self.cls.create(path) self.cls.create(path)
self.assert_(os.path.exists(path)) self.assert_(os.path.exists(path))
@ -18,8 +24,8 @@ class TestPyScript(fixture.Pathed):
def test_verify_notfound(self): def test_verify_notfound(self):
"""Correctly verify a python migration script: nonexistant file""" """Correctly verify a python migration script: nonexistant file"""
path=self.tmp_py() path = self.tmp_py()
self.assert_(not os.path.exists(path)) self.assertFalse(os.path.exists(path))
# Fails on empty path # Fails on empty path
self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path) self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path)
self.assertRaises(exceptions.InvalidScriptError,self.cls,path) self.assertRaises(exceptions.InvalidScriptError,self.cls,path)
@ -38,19 +44,52 @@ class TestPyScript(fixture.Pathed):
def test_verify_nofuncs(self): def test_verify_nofuncs(self):
"""Correctly verify a python migration script: valid python file; no upgrade func""" """Correctly verify a python migration script: valid python file; no upgrade func"""
path=self.tmp_py() path = self.tmp_py()
# Create empty file # Create empty file
f=open(path,'w') f = open(path, 'w')
f.write("def zergling():\n\tprint 'rush'") f.write("def zergling():\n\tprint 'rush'")
f.close() f.close()
self.assertRaises(exceptions.InvalidScriptError,self.cls.verify_module,path) self.assertRaises(exceptions.InvalidScriptError, self.cls.verify_module, path)
# script isn't verified on creation, but on module reference # script isn't verified on creation, but on module reference
py = self.cls(path) py = self.cls(path)
self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py) self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py)
@fixture.usedb(supported='sqlite')
def test_preview_sql(self):
"""Preview SQL abstract from ORM layer (sqlite)"""
path = self.tmp_py()
f = open(path, 'w')
content = """
from migrate import *
from sqlalchemy import *
metadata = MetaData(migrate_engine)
UserGroup = Table('Link', metadata,
Column('link1ID', Integer),
Column('link2ID', Integer),
UniqueConstraint('link1ID', 'link2ID'))
def upgrade():
metadata.create_all()
"""
f.write(content)
f.close()
pyscript = self.cls(path)
SQL = pyscript.preview_sql(self.url, 1)
self.assertEqualsIgnoreWhitespace("""
CREATE TABLE "Link"
("link1ID" INTEGER,
"link2ID" INTEGER,
UNIQUE ("link1ID", "link2ID"))
""", SQL)
# TODO: test: No SQL should be executed!
def test_verify_success(self): def test_verify_success(self):
"""Correctly verify a python migration script: success""" """Correctly verify a python migration script: success"""
path=self.tmp_py() path = self.tmp_py()
# Succeeds after creating # Succeeds after creating
self.cls.create(path) self.cls.create(path)
self.cls.verify(path) self.cls.verify(path)
@ -66,8 +105,8 @@ class TestSqlScript(fixture.Pathed):
# Create files -- files must be present or you'll get an exception later. # Create files -- files must be present or you'll get an exception later.
sqlite_upgrade_file = '001_sqlite_upgrade.sql' sqlite_upgrade_file = '001_sqlite_upgrade.sql'
default_upgrade_file = '001_default_upgrade.sql' default_upgrade_file = '001_default_upgrade.sql'
for file in [sqlite_upgrade_file, default_upgrade_file]: for file_ in [sqlite_upgrade_file, default_upgrade_file]:
filepath = '%s/%s' % (path, file) filepath = '%s/%s' % (path, file_)
open(filepath, 'w').close() open(filepath, 'w').close()
ver = version.Version(1, path, [sqlite_upgrade_file]) ver = version.Version(1, path, [sqlite_upgrade_file])

View File

@ -163,21 +163,21 @@ class TestShellCommands(Shell):
"""Construct engine the smart way""" """Construct engine the smart way"""
url = 'sqlite://' url = 'sqlite://'
engine = api._construct_engine(url) engine = api.construct_engine(url)
self.assert_(engine.name == 'sqlite') self.assert_(engine.name == 'sqlite')
# keyword arg # keyword arg
engine = api._construct_engine(url, engine_arg_assert_unicode=True) engine = api.construct_engine(url, engine_arg_assert_unicode=True)
self.assert_(engine.dialect.assert_unicode) self.assertTrue(engine.dialect.assert_unicode)
# dict # dict
engine = api._construct_engine(url, engine_dict={'assert_unicode': True}) engine = api.construct_engine(url, engine_dict={'assert_unicode': True})
self.assert_(engine.dialect.assert_unicode) self.assertTrue(engine.dialect.assert_unicode)
# test precedance # test precedance
engine = api._construct_engine(url, engine_dict={'assert_unicode': False}, engine = api.construct_engine(url, engine_dict={'assert_unicode': False},
engine_arg_assert_unicode=True) engine_arg_assert_unicode=True)
self.assert_(engine.dialect.assert_unicode) self.assertTrue(engine.dialect.assert_unicode)
def test_manage(self): def test_manage(self):
"""Create a project management script""" """Create a project management script"""
@ -327,8 +327,12 @@ class TestShellDatabase(Shell, fixture.DB):
# Add a script to the repository; upgrade the db # Add a script to the repository; upgrade the db
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
self.assertEquals(self.cmd_version(repos_path), 1) self.assertEquals(self.cmd_version(repos_path), 1)
self.assertEquals(self.cmd_db_version(self.url, repos_path), 0) self.assertEquals(self.cmd_db_version(self.url, repos_path), 0)
# Test preview
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_sql"))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path, 0, "--preview_py"))
self.assertSuccess(self.cmd('upgrade', self.url, repos_path)) self.assertSuccess(self.cmd('upgrade', self.url, repos_path))
self.assertEquals(self.cmd_db_version(self.url, repos_path), 1) self.assertEquals(self.cmd_db_version(self.url, repos_path), 1)