Skip to content

Commit

Permalink
Improve database URI handling
Browse files Browse the repository at this point in the history
This allows using socket files for full_uri (or host), and
generally improves all the ways URIs are used.
  • Loading branch information
bennybp committed Jul 9, 2023
1 parent c7e4a32 commit 1be6db8
Show file tree
Hide file tree
Showing 8 changed files with 447 additions and 236 deletions.
10 changes: 5 additions & 5 deletions qcarchivetesting/qcarchivetesting/testing_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_template(self):

self.sql_command(
f"CREATE DATABASE {self.template_name} TEMPLATE {self.db_name};",
database_name=self.config.existing_db,
use_maintenance_db=True,
returns=False,
)

Expand All @@ -59,7 +59,7 @@ def recreate_database(self):
self.delete_database()
self.sql_command(
f"CREATE DATABASE {self.db_name} TEMPLATE {self.template_name};",
database_name=self.config.existing_db,
use_maintenance_db=True,
returns=False,
)

Expand All @@ -73,17 +73,17 @@ def __init__(self, db_path: str):
self.logger = logging.getLogger(__name__)
self.tmp_pg = TemporaryPostgres(data_dir=db_path)
self.harness = self.tmp_pg._harness
self.logger.debug(f"Using database located at {db_path} with uri {self.harness.database_uri}")
self.logger.debug(f"Using database located at {db_path} with uri {self.harness.config.safe_uri}")

# Postgres process is up, but the database is not created
assert self.harness.is_alive(False) and not self.harness.is_alive(True)
assert self.harness.is_alive() and not self.harness.can_connect()

def get_new_harness(self, db_name: str) -> QCATestingPostgresHarness:
harness_config = deepcopy(self.harness.config.dict())
harness_config["database_name"] = db_name

new_harness = QCATestingPostgresHarness(DatabaseConfig(**harness_config))
new_harness.create_database()
new_harness.create_database(create_tables=True)
return new_harness


Expand Down
4 changes: 0 additions & 4 deletions qcfractal/qcfractal/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ file_template = %%(year)d-%%(month).2d-%%(day).2d-%%(rev)s_%%(slug)s
# are written from script.py.mako
# output_encoding = utf-8

# NOT used from here, will be read from FractalConfig
sqlalchemy.url = 'Overridden by the application'


# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
Expand Down
123 changes: 78 additions & 45 deletions qcfractal/qcfractal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import logging
import os
import re
import secrets
import urllib.parse
from typing import Optional, Dict, Any
from typing import Optional, Dict, Union, Any

import yaml
from psycopg2.extensions import make_dsn, parse_dsn
from pydantic import BaseSettings, Field, validator, root_validator, ValidationError
from pydantic.env_settings import SettingsSourceCallable
from sqlalchemy.engine.url import URL, make_url

from qcfractal.port_util import find_open_port

Expand Down Expand Up @@ -44,6 +44,29 @@ def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Opti
return os.path.abspath(path)


def make_uri_string(
    host: Optional[str],
    port: Optional[Union[int, str]],
    username: Optional[str],
    password: Optional[str],
    dbname: Optional[str],
    query: Optional[Dict[str, str]],
) -> str:
    """
    Build a ``postgresql://`` URI string from its individual components.

    Components passed as None are left out of the result instead of being
    rendered as the literal text "None". Components are inserted verbatim;
    no percent-encoding is applied here.

    Parameters
    ----------
    host
        Hostname/IP address, or an absolute path (starting with "/") to a
        directory containing a socket file. None is treated as an empty host.
    port
        Port number. If None, the ":port" part is omitted.
    username
        Username to place in the URI, if any.
    password
        Password to place in the URI, if any.
    dbname
        Name of the database. None is treated as an empty name.
    query
        Extra key/value pairs appended as URI query parameters.

    Returns
    -------
    :
        The assembled URI.
    """

    host = host if host is not None else ""
    dbname = dbname if dbname is not None else ""
    port_str = f":{port}" if port is not None else ""

    username = username if username is not None else ""
    password = ":" + password if password is not None else ""
    sep = "@" if username != "" or password != "" else ""
    query_str = "&".join(f"{k}={v}" for k, v in query.items()) if query else ""

    # If this is a socket file, move the host to the query params
    # (a path cannot appear in the host position of the URI)
    if host.startswith("/"):
        query_str = "&" + query_str if query_str != "" else ""
        return f"postgresql://{username}{password}{sep}{port_str}/{dbname}?host={host}{query_str}"

    query_str = "?" + query_str if query_str != "" else ""
    return f"postgresql://{username}{password}{sep}{host}{port_str}/{dbname}{query_str}"


class ConfigCommon:
case_sensitive = False
extra = "forbid"
Expand Down Expand Up @@ -104,7 +127,7 @@ class DatabaseConfig(ConfigBase):

host: str = Field(
"localhost",
description="The hostname and ip address the database is running on. If own = True, this must be localhost",
description="The hostname or ip address the database is running on. If own = True, this must be localhost. May also be a path to a directory containing the database socket file",
)
port: int = Field(
5432,
Expand All @@ -113,7 +136,7 @@ class DatabaseConfig(ConfigBase):
database_name: str = Field("qcfractal_default", description="The database name to connect to.")
username: Optional[str] = Field(None, description="The database username to connect with")
password: Optional[str] = Field(None, description="The database password to connect with")
query: Optional[str] = Field(None, description="Extra connection query parameters at the end of the URL string")
query: Dict[str, str] = Field({}, description="Extra connection query parameters at the end of the URL string")

own: bool = Field(
True,
Expand All @@ -140,7 +163,7 @@ class DatabaseConfig(ConfigBase):
description="[ADVANCED] set the size of the connection pool to use in SQLAlchemy. Set to zero to disable pooling",
)

existing_db: str = Field(
maintenance_db: str = Field(
"postgres",
description="[ADVANCED] An existing database (not the one you want to use/create). This is used for database management",
)
Expand All @@ -159,54 +182,64 @@ def _check_data_directory(cls, v, values):
def _check_logfile(cls, v, values):
return _make_abs_path(v, values["base_folder"], "qcfractal_database.log")

@root_validator(pre=True)
def _root_validator(cls, values):
@property
def database_uri(self) -> str:
"""
If full uri is specified, decompose it into the other fields
Returns the real database URI as a string
It does not hide the password, so is not suitable for logging
"""
if self.full_uri is not None:
return self.full_uri
else:
return make_uri_string(
host=self.host,
port=self.port,
username=self.username,
password=self.password,
dbname=self.database_name,
query=self.query,
)

full_uri = values.get("full_uri")
if full_uri:
parsed = urllib.parse.urlparse(full_uri)
values["host"] = parsed.hostname
values["port"] = parsed.port
values["username"] = parsed.username
values["password"] = parsed.password
values["query"] = "?" + parsed.query
values["database_name"] = parsed.path.strip("/")
@property
def sqlalchemy_url(self) -> URL:
"""Returns the SQLAlchemy URL for this database"""

return values
url = make_url(self.database_uri)
return url.set(drivername="postgresql+psycopg2")

@property
def uri(self):
if self.full_uri is not None:
return self.full_uri
else:
# Hostname can be a directory (unix sockets). But we need to escape some stuff
host = urllib.parse.quote(self.host, safe="")
username = self.username if self.username is not None else ""
password = f":{self.password}" if self.password is not None else ""
sep = "@" if username != "" or password != "" else ""
query = "" if self.query is None else self.query
return f"postgresql://{username}{password}{sep}{host}:{self.port}/{self.database_name}{query}"
def psycopg2_dsn(self) -> str:
    """
    Builds the connection string to pass to psycopg2

    The full database URI is split into its parts with parse_dsn, and
    those parts are then reassembled with make_dsn.
    """
    return make_dsn(**parse_dsn(self.database_uri))

@property
def safe_uri(self):
if self.full_uri is not None:
parsed = urllib.parse.urlparse(self.full_uri)
if parsed.password is None:
return self.full_uri
def psycopg2_maintenance_dsn(self) -> str:
    """
    Builds a psycopg2 connection string that targets the maintenance database

    Identical to the regular connection parameters, except that the database
    name is replaced with the configured maintenance_db.
    """
    components = {**parse_dsn(self.database_uri), "dbname": self.maintenance_db}
    return make_dsn(**components)

new_netloc = re.sub(":.*@", ":********@", parsed.netloc)
parsed = parsed._replace(netloc=new_netloc)
return parsed.geturl()
else:
host = urllib.parse.quote(self.host, safe="")
username = self.username if self.username is not None else ""
password = ":********" if self.password is not None else ""
sep = "@" if username != "" or password != "" else ""
query = "" if self.query is None else self.query
return f"postgresql://{username}{password}{sep}{host}:{self.port}/{self.database_name}{query}"
@property
def safe_uri(self) -> str:
    """
    Returns a user-readable version of the URI for logging, etc.

    The URI is rebuilt from the parsed connection parameters, with any
    password replaced by asterisks so the result can be written to logs.
    """

    dsn = parse_dsn(self.database_uri)

    # A URI (particularly a user-supplied full_uri) may leave out the host or
    # database name. Fall back to empty strings so that producing a log
    # message never fails with a KeyError.
    host = dsn.pop("host", "")
    port = dsn.pop("port", None)
    user = dsn.pop("user", None)
    password = dsn.pop("password", None)
    dbname = dsn.pop("dbname", "")

    # SQLAlchemy render_string has some problems sometimes, so use our own
    return make_uri_string(
        host=host,
        port=port,
        username=user,
        password="********" if password else None,
        dbname=dbname,
        query=dsn,  # everything left over
    )


class AutoResetConfig(ConfigBase):
Expand Down
32 changes: 21 additions & 11 deletions qcfractal/qcfractal/db_socket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(self, qcf_config: FractalConfig):

# Logging data
self.logger = logging.getLogger("SQLAlchemySocket")
self.uri = qcf_config.database.uri

self.logger.info(f"SQLAlchemy attempt to connect to {qcf_config.database.safe_uri}.")

Expand All @@ -47,10 +46,18 @@ def __init__(self, qcf_config: FractalConfig):
# If pool_size in the config is non-zero, then set the pool class to None (meaning use
# SQLAlchemy default)
if qcf_config.database.pool_size == 0:
self.engine = create_engine(self.uri, echo=qcf_config.database.echo_sql, poolclass=NullPool, future=True)
self.engine = create_engine(
self.qcf_config.database.sqlalchemy_url,
echo=qcf_config.database.echo_sql,
poolclass=NullPool,
future=True,
)
else:
self.engine = create_engine(
self.uri, echo=qcf_config.database.echo_sql, pool_size=qcf_config.database.pool_size, future=True
self.qcf_config.database.sqlalchemy_url,
echo=qcf_config.database.echo_sql,
pool_size=qcf_config.database.pool_size,
future=True,
)

self.logger.info(
Expand Down Expand Up @@ -113,7 +120,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy):
self.auth = AuthSocket(self)

def __str__(self) -> str:
return f"<SQLAlchemySocket: address='{self.uri}`>"
return f"<SQLAlchemySocket: address='{self.qcf_config.database.safe_uri}`>"

def post_fork_cleanup(self):
"""
Expand Down Expand Up @@ -142,15 +149,19 @@ def alembic_commands(db_config: DatabaseConfig) -> List[str]:
Components of an alembic command line as a list of strings
"""

db_uri = db_config.uri

# Find the path to the alembic ini
alembic_ini = os.path.join(qcfractal.qcfractal_topdir, "alembic.ini")
alembic_path = shutil.which("alembic")

if alembic_path is None:
raise RuntimeError("Cannot find the 'alembic' command. Is it installed?")
return [alembic_path, "-c", alembic_ini, "-x", "uri=" + db_uri]
return [
alembic_path,
"-c",
alembic_ini,
"-x",
"uri=" + db_config.database_uri,
]

@staticmethod
def get_alembic_config(db_config: DatabaseConfig):
Expand All @@ -165,7 +176,7 @@ def get_alembic_config(db_config: DatabaseConfig):

# Tell alembic to not set up logging. We already did that
alembic_cfg.set_main_option("skip_logging", "True")
alembic_cfg.set_main_option("sqlalchemy.url", db_config.uri)
alembic_cfg.set_main_option("sqlalchemy.url", db_config.database_uri)

return alembic_cfg

Expand All @@ -183,9 +194,8 @@ def create_database_tables(db_config: DatabaseConfig):
importlib.import_module("qcfractal.components.register_all")

# create the tables via sqlalchemy
uri = db_config.uri
logger.info(f"Creating tables for database: {uri}")
engine = create_engine(uri, echo=False, poolclass=NullPool, future=True)
logger.info(f"Creating tables for database: {db_config.safe_uri}")
engine = create_engine(db_config.sqlalchemy_url, echo=False, poolclass=NullPool, future=True)

from qcfractal.db_socket.base_orm import BaseORM

Expand Down
Loading

0 comments on commit 1be6db8

Please sign in to comment.