code reorg: create new utility method loadModel for db diffing
This commit is contained in:
parent
bc7f96cbb1
commit
1eadc362f6
@ -1,6 +1,7 @@
|
|||||||
from sqlalchemy import Table,Column,MetaData,String,Integer,create_engine
|
from sqlalchemy import Table,Column,MetaData,String,Integer,create_engine
|
||||||
from sqlalchemy import exceptions as sa_exceptions
|
from sqlalchemy import exceptions as sa_exceptions
|
||||||
from migrate.versioning.repository import Repository
|
from migrate.versioning.repository import Repository
|
||||||
|
from migrate.versioning.util import loadModel
|
||||||
from migrate.versioning.version import VerNum
|
from migrate.versioning.version import VerNum
|
||||||
from migrate.versioning import exceptions, genmodel, schemadiff
|
from migrate.versioning import exceptions, genmodel, schemadiff
|
||||||
|
|
||||||
@ -98,12 +99,7 @@ class ControlledSchema(object):
|
|||||||
|
|
||||||
if isinstance(repository, basestring):
|
if isinstance(repository, basestring):
|
||||||
repository=Repository(repository)
|
repository=Repository(repository)
|
||||||
if isinstance(model, basestring): # TODO: centralize this code?
|
model = loadModel(model)
|
||||||
# Assume model is of form "mod1.mod2.varname".
|
|
||||||
varname = model.split('.')[-1]
|
|
||||||
modules = '.'.join(model.split('.')[:-1])
|
|
||||||
module = __import__(modules, globals(), {}, ['dummy-not-used'], -1)
|
|
||||||
model = getattr(module, varname)
|
|
||||||
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
||||||
return diff
|
return diff
|
||||||
|
|
||||||
@ -122,12 +118,7 @@ class ControlledSchema(object):
|
|||||||
|
|
||||||
if isinstance(repository, basestring):
|
if isinstance(repository, basestring):
|
||||||
repository=Repository(repository)
|
repository=Repository(repository)
|
||||||
if isinstance(model, basestring): # TODO: centralize this code?
|
model = loadModel(model)
|
||||||
# Assume model is of form "mod1.mod2.varname".
|
|
||||||
varname = model.split('.')[-1]
|
|
||||||
modules = '.'.join(model.split('.')[:-1])
|
|
||||||
module = __import__(modules, globals(), {}, ['dummy-not-used'], -1)
|
|
||||||
model = getattr(module, varname)
|
|
||||||
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
||||||
return genmodel.ModelGenerator(diff).applyModel()
|
return genmodel.ModelGenerator(diff).applyModel()
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from migrate.versioning import exceptions, genmodel, schemadiff
|
|||||||
from migrate.versioning.base import operations
|
from migrate.versioning.base import operations
|
||||||
from migrate.versioning.template import template
|
from migrate.versioning.template import template
|
||||||
from migrate.versioning.script import base
|
from migrate.versioning.script import base
|
||||||
from migrate.versioning.util import import_path
|
from migrate.versioning.util import import_path, loadModel
|
||||||
|
|
||||||
class PythonScript(base.BaseScript):
|
class PythonScript(base.BaseScript):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -28,12 +28,7 @@ class PythonScript(base.BaseScript):
|
|||||||
if isinstance(repository, basestring):
|
if isinstance(repository, basestring):
|
||||||
from migrate.versioning.repository import Repository # oh dear, an import cycle!
|
from migrate.versioning.repository import Repository # oh dear, an import cycle!
|
||||||
repository=Repository(repository)
|
repository=Repository(repository)
|
||||||
if isinstance(model, basestring): # TODO: centralize this code?
|
model = loadModel(model)
|
||||||
# Assume model is of form "mod1.mod2.varname".
|
|
||||||
varname = model.split('.')[-1]
|
|
||||||
modules = '.'.join(model.split('.')[:-1])
|
|
||||||
module = __import__(modules, globals(), {}, ['dummy-not-used'], -1)
|
|
||||||
model = getattr(module, varname)
|
|
||||||
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table])
|
||||||
upgradeDecls, upgradeCommands = genmodel.ModelGenerator(diff).toUpgradePython()
|
upgradeDecls, upgradeCommands = genmodel.ModelGenerator(diff).toUpgradePython()
|
||||||
#downgradeCommands = genmodel.ModelGenerator(diff).toDowngradePython()
|
#downgradeCommands = genmodel.ModelGenerator(diff).toDowngradePython()
|
||||||
|
@ -1,3 +1,14 @@
|
|||||||
from keyedinstance import KeyedInstance
|
from keyedinstance import KeyedInstance
|
||||||
from importpath import import_path
|
from importpath import import_path
|
||||||
|
|
||||||
|
def loadModel(model):
|
||||||
|
''' Import module and use module-level variable -- assume model is of form "mod1.mod2.varname". '''
|
||||||
|
if isinstance(model, basestring):
|
||||||
|
varname = model.split('.')[-1]
|
||||||
|
modules = '.'.join(model.split('.')[:-1])
|
||||||
|
module = __import__(modules, globals(), {}, ['dummy-not-used'], -1)
|
||||||
|
return getattr(module, varname)
|
||||||
|
else:
|
||||||
|
# Assume it's already loaded.
|
||||||
|
return model
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user