aports/testing/grommunio-admin-api/0006-enable-sqlalchemy2.patch

116 lines
5.4 KiB
Diff

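Fix compatibility with SQLAlchemy 2.x (reported upstream): wrap raw SQL strings passed to session.execute() in text(), import declarative_base from sqlalchemy.orm instead of sqlalchemy.ext.declarative, and use select(...).scalar_subquery() instead of select([...]).as_scalar() on SQLAlchemy 1.4 and later.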
--- a/cli/dbtools.py
+++ b/cli/dbtools.py
@@ -10,6 +10,7 @@
import string
from datetime import datetime
from getpass import getpass
+from sqlalchemy import text
passwordChars = string.ascii_letters+string.digits+'!"#$%&()*+-/;<=>?[]_{|}~'
defaultPassLength = 16
@@ -87,7 +88,7 @@
user = Users.query.filter(Users.ID == 0).first()
if user is None:
cli.print("System admin user not found, creating...")
- DB.session.execute("SET sql_mode='NO_AUTO_VALUE_ON_ZERO';")
+ DB.session.execute(text("SET sql_mode='NO_AUTO_VALUE_ON_ZERO';"))
user, _ = createAdmin()
DB.session.add(user)
cli.print("Setting password for user '{}'".format(user.username))
--- a/orm/__init__.py
+++ b/orm/__init__.py
@@ -4,10 +4,10 @@
__all__ = ["domains", "misc", "users", "ext"]
+import sqlalchemy
from sqlalchemy import create_engine, event, select, text
from sqlalchemy.exc import OperationalError
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import scoped_session, sessionmaker, class_mapper, Query, column_property
+from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker, class_mapper, Query, column_property
from tools.config import Config
@@ -44,7 +44,7 @@
def testConnection(self, verbose=False):
try:
- self.session.execute("SELECT 1 FROM DUAL")
+ self.session.execute(text("SELECT 1 FROM DUAL"))
except OperationalError as err:
self.session.remove()
return "Database connection failed with error {}: {}".format(err.orig.args[0], err.orig.args[1])
@@ -64,7 +64,7 @@
Version number or None on failure
"""
try:
- version = int(self.session.execute("SELECT `value` FROM `options` WHERE `key` = 'schemaversion'").fetchone()[0])
+ version = int(self.session.execute(text("SELECT `value` FROM `options` WHERE `key` = 'schemaversion'")).fetchone()[0])
if verbose:
logger.info("Detected database schema version n"+str(version))
return version
@@ -216,7 +216,10 @@
column : Any
Column definition to return if version check passes
"""
- return column if DB.minVersion(version) else column_property(select([text(default)]).as_scalar())
+ if sqlalchemy.__version__.split(".") >= ["1", "4"]:
+ return column if DB.minVersion(version) else column_property(select(text(default)).scalar_subquery())
+ else:
+ return column if DB.minVersion(version) else column_property(select([text(default)]).as_scalar())
class NotifyTable:
--- a/orm/domains.py
+++ b/orm/domains.py
@@ -418,18 +418,18 @@
if sqlalchemy.__version__.split(".") >= ["1", "4"]:
inspect(Domains).add_property("activeUsers",
- column_property(select([func.count(Users.ID)])
+ column_property(select(func.count(Users.ID))
.where(Users.domainID == Domains.ID, Users.addressStatus == Users.NORMAL,
Users.maildir != "")
.scalar_subquery()))
inspect(Domains).add_property("inactiveUsers",
- column_property(select([func.count(Users.ID)])
+ column_property(select(func.count(Users.ID))
.where(Users.domainID == Domains.ID,
Users.addressStatus.not_in((Users.NORMAL, Users.SHARED)),
Users.maildir != "")
.scalar_subquery()))
inspect(Domains).add_property("virtualUsers",
- column_property(select([func.count(Users.ID)])
+ column_property(select(func.count(Users.ID))
.where(Users.domainID == Domains.ID, (Users.addressStatus == Users.SHARED) |
(Users.maildir == ""))
.scalar_subquery()))
--- a/orm/misc.py
+++ b/orm/misc.py
@@ -164,7 +164,10 @@
@users.expression
def users(cls):
from .users import Users
- return select([func.count(Users.ID)]).where(Users.homeserverID == cls.ID).as_scalar()
+ if sqlalchemy.__version__.split(".") >= ["1", "4"]:
+ return select(func.count(Users.ID)).where(Users.homeserverID == cls.ID).scalar_subquery()
+ else:
+ return select([func.count(Users.ID)]).where(Users.homeserverID == cls.ID).as_scalar()
@hybrid_property
def domains(self):
@@ -174,7 +177,10 @@
@domains.expression
def domains(cls):
from .domains import Domains
- return select([func.count(Domains.ID)]).where(Domains.homeserverID == cls.ID).as_scalar()
+ if sqlalchemy.__version__.split(".") >= ["1", "4"]:
+ return select(func.count(Domains.ID)).where(Domains.homeserverID == cls.ID).scalar_subquery()
+ else:
+ return select([func.count(Domains.ID)]).where(Domains.homeserverID == cls.ID).as_scalar()
@staticmethod
def _getServer(objID, serverID=None, domain=False):