Fixes for SQLAlchemy version 2.x (reported upstream)

Summary of the changes in this patch, per file:
* cli/dbtools.py: import text from sqlalchemy and wrap the raw "SET sql_mode=..." statement in text() before passing it to DB.session.execute().
* orm/__init__.py: add "import sqlalchemy"; import declarative_base from sqlalchemy.orm instead of sqlalchemy.ext.declarative; wrap the raw SQL strings passed to self.session.execute() ("SELECT 1 FROM DUAL" and the schemaversion query) in text(); when sqlalchemy.__version__ compares >= 1.4, build the fallback column property with select(text(default)).scalar_subquery(), otherwise keep the previous select([text(default)]).as_scalar() form.
* orm/domains.py: in the activeUsers, inactiveUsers and virtualUsers column properties, pass func.count(Users.ID) to select() directly instead of wrapped in a list.
* orm/misc.py: in the "users" and "domains" hybrid-property expressions, use select(func.count(...)) ... .scalar_subquery() when sqlalchemy.__version__ compares >= 1.4, otherwise keep the previous select([...]) ... .as_scalar() form.

NOTE(review): the orm/misc.py hunks reference sqlalchemy.__version__, but this patch adds no "import sqlalchemy" to that file -- confirm the module already imports it.
NOTE(review): "from sqlalchemy.orm import declarative_base" presumably requires SQLAlchemy >= 1.4; if so, the "else" branch added in orm/__init__.py can never run -- confirm the intended minimum version.
NOTE(review): the version checks compare lists of strings (split(".") >= ["1", "4"]), i.e. lexicographically per component -- same idiom as the existing check in orm/domains.py.

--- a/cli/dbtools.py +++ b/cli/dbtools.py @@ -10,6 +10,7 @@ import string from datetime import datetime from getpass import getpass +from sqlalchemy import text passwordChars = string.ascii_letters+string.digits+'!"#$%&()*+-/;<=>?[]_{|}~' defaultPassLength = 16 @@ -87,7 +88,7 @@ user = Users.query.filter(Users.ID == 0).first() if user is None: cli.print("System admin user not found, creating...") - DB.session.execute("SET sql_mode='NO_AUTO_VALUE_ON_ZERO';") + DB.session.execute(text("SET sql_mode='NO_AUTO_VALUE_ON_ZERO';")) user, _ = createAdmin() DB.session.add(user) cli.print("Setting password for user '{}'".format(user.username)) --- a/orm/__init__.py +++ b/orm/__init__.py @@ -4,10 +4,10 @@ __all__ = ["domains", "misc", "users", "ext"] +import sqlalchemy from sqlalchemy import create_engine, event, select, text from sqlalchemy.exc import OperationalError -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker, class_mapper, Query, column_property +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker, class_mapper, Query, column_property from tools.config import Config @@ -44,7 +44,7 @@ def testConnection(self, verbose=False): try: - self.session.execute("SELECT 1 FROM DUAL") + self.session.execute(text("SELECT 1 FROM DUAL")) except OperationalError as err: self.session.remove() return "Database connection failed with error {}: {}".format(err.orig.args[0], err.orig.args[1]) @@ -64,7 +64,7 @@ Version number or None on failure """ try: - version = int(self.session.execute("SELECT `value` FROM `options` WHERE `key` = 'schemaversion'").fetchone()[0]) + version = int(self.session.execute(text("SELECT `value` FROM `options` WHERE `key` = 'schemaversion'")).fetchone()[0]) if verbose: logger.info("Detected database schema version n"+str(version)) return version @@ -216,7 +216,10 @@ column : Any Column definition to return if version check 
passes """ - return column if DB.minVersion(version) else column_property(select([text(default)]).as_scalar()) + if sqlalchemy.__version__.split(".") >= ["1", "4"]: + return column if DB.minVersion(version) else column_property(select(text(default)).scalar_subquery()) + else: + return column if DB.minVersion(version) else column_property(select([text(default)]).as_scalar()) class NotifyTable: --- a/orm/domains.py +++ b/orm/domains.py @@ -418,18 +418,18 @@ if sqlalchemy.__version__.split(".") >= ["1", "4"]: inspect(Domains).add_property("activeUsers", - column_property(select([func.count(Users.ID)]) + column_property(select(func.count(Users.ID)) .where(Users.domainID == Domains.ID, Users.addressStatus == Users.NORMAL, Users.maildir != "") .scalar_subquery())) inspect(Domains).add_property("inactiveUsers", - column_property(select([func.count(Users.ID)]) + column_property(select(func.count(Users.ID)) .where(Users.domainID == Domains.ID, Users.addressStatus.not_in((Users.NORMAL, Users.SHARED)), Users.maildir != "") .scalar_subquery())) inspect(Domains).add_property("virtualUsers", - column_property(select([func.count(Users.ID)]) + column_property(select(func.count(Users.ID)) .where(Users.domainID == Domains.ID, (Users.addressStatus == Users.SHARED) | (Users.maildir == "")) .scalar_subquery())) --- a/orm/misc.py +++ b/orm/misc.py @@ -164,7 +164,10 @@ @users.expression def users(cls): from .users import Users - return select([func.count(Users.ID)]).where(Users.homeserverID == cls.ID).as_scalar() + if sqlalchemy.__version__.split(".") >= ["1", "4"]: + return select(func.count(Users.ID)).where(Users.homeserverID == cls.ID).scalar_subquery() + else: + return select([func.count(Users.ID)]).where(Users.homeserverID == cls.ID).as_scalar() @hybrid_property def domains(self): @@ -174,7 +177,10 @@ @domains.expression def domains(cls): from .domains import Domains - return select([func.count(Domains.ID)]).where(Domains.homeserverID == cls.ID).as_scalar() + if 
sqlalchemy.__version__.split(".") >= ["1", "4"]: + return select(func.count(Domains.ID)).where(Domains.homeserverID == cls.ID).scalar_subquery() + else: + return select([func.count(Domains.ID)]).where(Domains.homeserverID == cls.ID).as_scalar() @staticmethod def _getServer(objID, serverID=None, domain=False):