diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index cd5467a..32ed33f 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -11,7 +11,7 @@ __all__=[ 'help', 'create', 'script', -'script_python_changes', +'make_update_script_for_model', 'commit', 'version', 'source', @@ -22,8 +22,9 @@ __all__=[ 'drop_version_control', 'manage', 'test', -'compare_db', -'db_schema_dump', +'compare_model_to_db', +'create_model', +'update_db_from_model', ] cls_repository = repository.Repository @@ -280,28 +281,28 @@ def manage(file,**opts): """ return repository.manage(file,**opts) -def compare_db(url,model,repository,**opts): - """%prog compare_db URL MODEL REPOSITORY_PATH +def compare_model_to_db(url,model,repository,**opts): + """%prog compare_model_to_db URL MODEL REPOSITORY_PATH Compare the current model (assumed to be a module level variable of type sqlalchemy.MetaData) against the current database. NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label engine=create_engine(url) - print cls_schema.compare_db(engine,model,repository) + print cls_schema.compare_model_to_db(engine,model,repository) -def db_schema_dump(url,repository,**opts): - """%prog db_schema_dump URL REPOSITORY_PATH +def create_model(url,repository,**opts): + """%prog create_model URL REPOSITORY_PATH Dump the current database as a Python model to stdout. NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label engine=create_engine(url) - print cls_schema.db_schema_dump(engine,repository) + print cls_schema.create_model(engine,repository) -def script_python_changes(path,url,model,repository,**opts): - """%prog script_python_changes PATH URL MODEL REPOSITORY_PATH +def make_update_script_for_model(path,url,model,repository,**opts): + """%prog make_update_script_for_model PATH URL MODEL REPOSITORY_PATH Create a script changing the current (old) database to the current (new) Python model. @@ -309,8 +310,18 @@ def script_python_changes(path,url,model,repository,**opts): """ # TODO: get rid of EXPERIMENTAL label engine=create_engine(url) try: - cls_script_python.script_python_changes(path,engine,model,repository,**opts) + cls_script_python.make_update_script_for_model(path,engine,model,repository,**opts) except exceptions.PathFoundError,e: raise exceptions.KnownError("The path %s already exists"%e.args[0]) +def update_db_from_model(url,model,repository,**opts): + """%prog update_db_from_model URL MODEL REPOSITORY_PATH + + Modify the database to match the structure of the current Python model. + + NOTE: This is EXPERIMENTAL. + """ # TODO: get rid of EXPERIMENTAL label + engine=create_engine(url) + cls_schema.update_db_from_model(engine,model,repository) + diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py index 9274c23..bad96ff 100644 --- a/migrate/versioning/genmodel.py +++ b/migrate/versioning/genmodel.py @@ -3,7 +3,7 @@ # Some of this is borrowed heavily from the AutoCode project at: http://code.google.com/p/sqlautocode/ import sys -import sqlalchemy +import migrate, sqlalchemy HEADER = """ @@ -28,11 +28,15 @@ class ModelGenerator(object): kwarg.append('primary_key') if not col.nullable: kwarg.append('nullable') if col.onupdate: kwarg.append('onupdate') - if col.default: kwarg.append('default') + if col.default: + if col.primary_key: + # I found that Postgres automatically creates a default value for the sequence, but let's not show that. + pass + else: + kwarg.append('default') ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg ) name = col.name.encode('utf8') # crs: not sure if this is good idea, but it gets rid of extra u'' - #type = self.colTypeMappings[col.type.__class__]() type = self.colTypeMappings.get(col.type.__class__, None) if type: # Make the column type be an instance of this type. @@ -96,3 +100,60 @@ class ModelGenerator(object): def toDowngradePython(self, indent=' '): return ' pass #TODO DOWNGRADE' + def applyModel(self): + ''' Apply model to current database. ''' + + # Yuck! We have to import from changeset to apply the monkey-patch to allow column adding/dropping. + from migrate.changeset import schema + + def dbCanHandleThisChange(missingInDatabase, missingInModel, diffDecl): + if missingInDatabase and not missingInModel and not diffDecl: + # Even sqlite can handle this. + return True + else: + return not self.diff.conn.url.drivername.startswith('sqlite') + + meta = sqlalchemy.MetaData(self.diff.conn.engine) + + for table in self.diff.tablesMissingInModel: + table = table.tometadata(meta) + table.drop() + for table in self.diff.tablesMissingInDatabase: + table = table.tometadata(meta) + table.create() + for modelTable in self.diff.tablesWithDiff: + modelTable = modelTable.tometadata(meta) + dbTable = self.diff.reflected_model.tables[modelTable.name] + #print 'TODO DEBUG.cols1', [x.name for x in dbTable.columns] + #dbTable = dbTable.tometadata(meta) + #print 'TODO DEBUG.cols2', [x.name for x in dbTable.columns] + tableName = modelTable.name + missingInDatabase, missingInModel, diffDecl = self.diff.colDiffs[tableName] + if dbCanHandleThisChange(missingInDatabase, missingInModel, diffDecl): + for col in missingInDatabase: + modelTable.columns[col.name].create() + for col in missingInModel: + dbTable.columns[col.name].drop() + for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl: + dbTable.columns[databaseCol.name].drop() + modelTable.columns[modelCol.name].create() + else: + # Sqlite doesn't support drop column, so you have to do more: + # create temp table, copy data to it, drop old table, create new table, copy data back. + + tempName = '_temp_%s' % modelTable.name # I wonder if this is guaranteed to be unique? + def getCopyStatement(): + preparer = self.diff.conn.engine.dialect.preparer + commonCols = [] + for modelCol in modelTable.columns: + if dbTable.columns.has_key(modelCol.name): + commonCols.append(modelCol.name) + commonColsStr = ', '.join(commonCols) + return 'INSERT INTO %s (%s) SELECT %s FROM %s' % (tableName, commonColsStr, commonColsStr, tempName) + + self.diff.conn.execute('CREATE TEMPORARY TABLE %s as SELECT * from %s' % (tempName, modelTable.name)) + modelTable.drop() + modelTable.create() + self.diff.conn.execute(getCopyStatement()) + self.diff.conn.execute('DROP TABLE %s' % tempName) + diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py index 4cb0976..df94efa 100644 --- a/migrate/versioning/schema.py +++ b/migrate/versioning/schema.py @@ -93,7 +93,7 @@ class ControlledSchema(object): return table @classmethod - def compare_db(cls,engine,model,repository): + def compare_model_to_db(cls,engine,model,repository): """Compare the current model against the current database.""" if isinstance(repository, basestring): @@ -108,13 +108,28 @@ class ControlledSchema(object): return diff @classmethod - def db_schema_dump(cls,engine,repository): + def create_model(cls,engine,repository): """Dump the current database as a Python model.""" if isinstance(repository, basestring): repository=Repository(repository) diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), engine, excludeTables=[repository.version_table]) return genmodel.ModelGenerator(diff).toPython() + + @classmethod + def update_db_from_model(cls,engine,model,repository): + """Modify the database to match the structure of the current Python model.""" + + if isinstance(repository, basestring): + repository=Repository(repository) + if isinstance(model, basestring): # TODO: centralize this code? + # Assume model is of form "mod1.mod2.varname". + varname = model.split('.')[-1] + modules = '.'.join(model.split('.')[:-1]) + module = __import__(modules, globals(), {}, ['dummy-not-used'], -1) + model = getattr(module, varname) + diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) + return genmodel.ModelGenerator(diff).applyModel() def drop(self): """Remove version control from a database""" diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py index 6b37412..1005986 100644 --- a/migrate/versioning/schemadiff.py +++ b/migrate/versioning/schemadiff.py @@ -71,7 +71,7 @@ class SchemaDiff(object): # Types and nullable are the same. pass else: - self.storeColumnDiff(modelTable, modelDecl, databaseDecl) + self.storeColumnDiff(modelTable, modelCol, databaseCol, modelDecl, databaseDecl) else: self.storeColumnMissingInModel(modelTable, databaseCol) else: @@ -126,9 +126,9 @@ class SchemaDiff(object): missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], [])) missingInModel.append(col) - def storeColumnDiff(self, table, modelDecl, databaseDecl): + def storeColumnDiff(self, table, modelCol, databaseCol, modelDecl, databaseDecl): if table not in self.tablesWithDiff: self.tablesWithDiff.append(table) missingInDatabase, missingInModel, diffDecl = self.colDiffs.setdefault(table.name, ([], [], [])) - diffDecl.append( (modelDecl, databaseDecl) ) + diffDecl.append( (modelCol, databaseCol, modelDecl, databaseDecl) ) diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index f373131..6103827 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -20,7 +20,7 @@ class PythonScript(base.BaseScript): shutil.copy(src,path) @classmethod - def script_python_changes(cls,path,engine,model,repository,**opts): + def make_update_script_for_model(cls,path,engine,model,repository,**opts): """Create a migration script""" cls.require_notfound(path) diff --git a/test/versioning/test_schemadiff.py b/test/versioning/test_schemadiff.py index 7ef864f..53a85df 100644 --- a/test/versioning/test_schemadiff.py +++ b/test/versioning/test_schemadiff.py @@ -1,3 +1,4 @@ +import os import sqlalchemy from sqlalchemy import * from test import fixture @@ -26,20 +27,30 @@ class TestSchemaDiff(fixture.DB): self.meta = MetaData(self.engine) self.table = Table(self.table_name,self.meta, Column('id',Integer(),primary_key=True), + Column('name',UnicodeText()), Column('data',UnicodeText()), ) if self.table.exists(): self.table.drop() - #self.engine.echo = True + WANT_ENGINE_ECHO = os.environ.get('WANT_ENGINE_ECHO', 'F') # to get debugging: set this to T and run py.test with --pdb + if WANT_ENGINE_ECHO == 'T': + self.engine.echo = True def tearDown(self): if self.table.exists(): self.table.drop() fixture.DB.tearDown(self) + def _applyLatestModel(self): + diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine) + genmodel.ModelGenerator(diff).applyModel() + @fixture.usedb() def test_rundiffs(self): + # Yuck! We have to import from changeset to apply the monkey-patch to allow column adding/dropping. + from migrate.changeset import schema + def assertDiff(isDiff, tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff): diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine) self.assertEquals(bool(diff), isDiff) @@ -56,13 +67,14 @@ class TestSchemaDiff(fixture.DB): meta = MetaData(migrate_engine) tmp_schemadiff = Table('tmp_schemadiff',meta, Column('id',Integer(),primary_key=True,nullable=False), + Column('name',UnicodeText(length=None)), Column('data',UnicodeText(length=None)), ) ''') self.assertEqualsIgnoreWhitespace(commands, '''tmp_schemadiff.create()''') # Create table in database, now model should match database. - self.table.create() + self._applyLatestModel() assertDiff(False, [], [], []) # Check Python code gen from database. @@ -72,47 +84,62 @@ class TestSchemaDiff(fixture.DB): self.assertEqualsIgnoreWhitespace(src, ''' tmp_schemadiff = Table('tmp_schemadiff',meta, Column('id',Integer(),primary_key=True,nullable=False), + Column('name',Text(length=None,convert_unicode=False,assert_unicode=None)), Column('data',Text(length=None,convert_unicode=False,assert_unicode=None)), ) ''') - # Modify table in model (by removing it and adding it back to model). + # Add data, later we'll make sure it's still present. + result = self.engine.execute(self.table.insert(), id=1, name=u'mydata') + dataId = result.last_inserted_ids()[0] + + # Modify table in model (by removing it and adding it back to model) -- drop column data and add column data2. self.meta.remove(self.table) self.table = Table(self.table_name,self.meta, Column('id',Integer(),primary_key=True), + Column('name',UnicodeText(length=None)), Column('data2',UnicodeText(),nullable=True), ) assertDiff(True, [], [], [self.table_name]) # Apply latest model changes and find no more diffs. - self.table.drop() - self.table.create() + self._applyLatestModel() assertDiff(False, [], [], []) + # Make sure data is still present. + result = self.engine.execute(self.table.select(), id=dataId) + rows = result.fetchall() + self.assertEquals(len(rows), 1) + self.assertEquals(rows[0].name, 'mydata') + # Change column type in model. self.meta.remove(self.table) self.table = Table(self.table_name,self.meta, Column('id',Integer(),primary_key=True), + Column('name',UnicodeText(length=None)), Column('data2',Integer(),nullable=True), ) assertDiff(True, [], [], [self.table_name]) # TODO test type diff # Apply latest model changes and find no more diffs. - self.table.drop() - self.table.create() + self._applyLatestModel() assertDiff(False, [], [], []) + # Delete data, since we're about to make a required column. + # Not even using sqlalchemy.PassiveDefault helps because we're doing explicit column select. + self.engine.execute(self.table.delete(), id=dataId) + # Change column nullable in model. self.meta.remove(self.table) self.table = Table(self.table_name,self.meta, Column('id',Integer(),primary_key=True), + Column('name',UnicodeText(length=None)), Column('data2',Integer(),nullable=False), ) assertDiff(True, [], [], [self.table_name]) # TODO test nullable diff # Apply latest model changes and find no more diffs. - self.table.drop() - self.table.create() + self._applyLatestModel() assertDiff(False, [], [], []) # Remove table from model.