Current Path : /usr/share/pyshared/landscape/ |
Current File : //usr/share/pyshared/landscape/patch.py |
import logging


class UpgraderConflict(Exception):
    """Two upgraders with the same version have been registered."""


class UpgradeManagerBase(object):
    """A simple upgrade system.

    Upgraders are callables registered against an orderable version number.
    Subclasses decide where the record of already-applied versions lives.
    """

    def __init__(self):
        # Maps version -> upgrader callable.
        self._upgraders = {}

    def register_upgrader(self, version, function):
        """Register C{function} as the upgrader for C{version}.

        @param version: The version number that this upgrader is upgrading
            the database to. This defines the order that upgraders are run.
        @param function: The function to call when applying upgraders. It
            must take a single object, the database that is being upgraded.
        @raise UpgraderConflict: If an upgrader is already registered for
            C{version}.
        """
        if version in self._upgraders:
            raise UpgraderConflict("%s is already registered as %s; "
                                   "not adding %s" % (
                                       version, self._upgraders[version],
                                       function))
        self._upgraders[version] = function

    def get_version(self):
        """
        Get the 'current' version of any database that this UpgradeManager
        will be applied to.

        @return: The highest registered version, or 0 when no upgraders
            have been registered.
        """
        if not self._upgraders:
            return 0
        return max(self._upgraders)

    def upgrader(self, version):
        """
        A decorator for specifying that a function is an upgrader for this
        upgrade manager.

        @param version: The version number that the function will be
            upgrading to.
        @return: A decorator that registers the function and returns it
            unchanged.
        """
        def inner(function):
            self.register_upgrader(version, function)
            return function
        return inner


class UpgradeManager(UpgradeManagerBase):
    """An upgrade manager whose state lives in a persist object.

    The most recently applied version is stored under the
    C{"system-version"} key.
    """

    def apply(self, persist):
        """Bring the database up-to-date.

        Runs, in ascending version order, every upgrader whose version is
        greater than the stored C{"system-version"}.

        @param persist: The database to upgrade. It will be passed to all
            upgrade functions.
        """
        if not persist.has("system-version"):
            persist.set("system-version", 0)

        for version, upgrader in sorted(self._upgraders.items()):
            if version > persist.get("system-version"):
                # NOTE(review): the new version is recorded *before* the
                # upgrader runs, so if the upgrader raises, persist already
                # claims this version. Confirm this ordering is intended.
                persist.set("system-version", version)
                upgrader(persist)
                # Pass the argument lazily instead of pre-formatting it.
                logging.info("Successfully applied patch %s", version)

    def initialize(self, persist):
        """
        Mark the database as being up-to-date; use this when initializing a
        new database.

        @param persist: The database whose C{"system-version"} is set to the
            highest registered version (0 if none are registered).
        """
        persist.set("system-version", self.get_version())


class SQLiteUpgradeManager(UpgradeManagerBase):
    """An upgrade manager backed by sqlite.

    Applied versions are recorded as rows of a C{patch} table that has a
    single C{version} column.
    """

    def get_database_versions(self, cursor):
        """Return the set of versions recorded in the C{patch} table.

        @param cursor: A DB-API cursor on the database being upgraded.
        """
        cursor.execute("SELECT version FROM patch")
        return set(row[0] for row in cursor.fetchall())

    def get_database_version(self, cursor):
        """Return the highest recorded version, or 0 if none is recorded.

        @param cursor: A DB-API cursor on the database being upgraded.
        """
        cursor.execute("SELECT MAX(version) FROM patch")
        # MAX() over an empty table yields NULL, which arrives as None.
        return cursor.fetchone()[0] or 0

    def apply(self, cursor):
        """Bring the database up-to-date.

        Every registered upgrader whose version is not yet recorded in the
        C{patch} table is applied, in ascending version order.

        @param cursor: A DB-API cursor on the database being upgraded.
        """
        applied = self.get_database_versions(cursor)
        for version in sorted(self._upgraders):
            if version not in applied:
                self.apply_one(version, cursor)

    def apply_one(self, version, cursor):
        """Run the upgrader for C{version}, then record it as applied.

        The C{patch} row is only inserted after the upgrader returns, so an
        upgrader that raises is not recorded.

        @param version: A registered version number.
        @param cursor: A DB-API cursor on the database being upgraded.
        @raise KeyError: If no upgrader is registered for C{version}.
        """
        upgrader = self._upgraders[version]
        upgrader(cursor)
        cursor.execute("INSERT INTO patch VALUES (?)", (version,))

    def initialize(self, cursor):
        """
        Mark the database as being up-to-date; use this when initializing a
        new SQLite database.

        Creates the C{patch} table and records every registered version
        without running any upgrader.

        @param cursor: A DB-API cursor on the new database.
        """
        cursor.execute("CREATE TABLE patch (version INTEGER)")
        for version in sorted(self._upgraders):
            cursor.execute("INSERT INTO patch VALUES (?)", (version,))