Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.10.0 multidb #1

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tenant_schemas/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.db import connection
from django.db import connections, router


def make_key(key, key_prefix, version):
Expand All @@ -8,6 +8,8 @@ def make_key(key, key_prefix, version):
Constructs the key used by all other methods. Prepends the tenant
`schema_name` and `key_prefix'.
"""
db = router.db_for_read(None)
connection = connections[db]
return '%s:%s:%s:%s' % (connection.schema_name, key_prefix, version, key)


Expand Down
4 changes: 3 additions & 1 deletion tenant_schemas/log.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from django.db import connection
from django.db import connections, router


class TenantContextFilter(logging.Filter):
Expand All @@ -10,6 +10,8 @@ class TenantContextFilter(logging.Filter):
Thanks to @regolith for the snippet on #248
"""
def filter(self, record):
db = router.db_for_read(None)
connection = connections[db]
record.schema_name = connection.tenant.schema_name
record.domain_url = getattr(connection.tenant, 'domain_url', '')
return True
5 changes: 3 additions & 2 deletions tenant_schemas/management/commands/tenant_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.core.management.base import BaseCommand
from django.db import connection
from tenant_schemas.management.commands import InteractiveTenantOption
from tenant_schemas.utils import tenant_context


class Command(InteractiveTenantOption, BaseCommand):
Expand All @@ -11,5 +12,5 @@ def handle(self, command, schema_name, *args, **options):
tenant = self.get_tenant_from_options_or_interactive(
schema_name=schema_name, **options
)
connection.set_tenant(tenant)
call_command(command, *args, **options)
with tenant_context(tenant):
call_command(command, *args, **options)
31 changes: 28 additions & 3 deletions tenant_schemas/middleware.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import django
from django.conf import settings
from django.core.exceptions import DisallowedHost
from django.db import connection
from django.db import connection, connections
from django.http import Http404
from django.core.urlresolvers import set_urlconf

from tenant_schemas.utils import (
get_public_schema_name,
get_tenant_model,
remove_www,
get_db_alias,
)


Expand Down Expand Up @@ -39,10 +42,16 @@ def hostname_from_request(self, request):
"""
return remove_www(request.get_host().split(":")[0]).lower()

def set_connection_to_public(self):
connection.set_schema_to_public()

def set_connection_to_tenant(self, tenant):
connection.set_tenant(tenant)

def process_request(self, request):
# Connection needs first to be at the public schema, as this is where
# the tenant metadata is stored.
connection.set_schema_to_public()
self.set_connection_to_public()

hostname = self.hostname_from_request(request)
TenantModel = get_tenant_model()
Expand All @@ -61,14 +70,16 @@ def process_request(self, request):
)

request.tenant = tenant
connection.set_tenant(request.tenant)

self.set_connection_to_tenant(request.tenant)

# Do we have a public-specific urlconf?
if (
hasattr(settings, "PUBLIC_SCHEMA_URLCONF")
and request.tenant.schema_name == get_public_schema_name()
):
request.urlconf = settings.PUBLIC_SCHEMA_URLCONF
set_urlconf(request.urlconf)


class TenantMiddleware(BaseTenantMiddleware):
Expand Down Expand Up @@ -120,3 +131,17 @@ def get_tenant(self, model, hostname, request):
schema_name = get_public_schema_name()

return model.objects.get(schema_name=schema_name)


class MultiDBTenantMiddleware(SuspiciousTenantMiddleware):
def set_connection_to_public(self):
for db in get_db_alias():
connections[db].set_schema_to_public()

def set_connection_to_tenant(self, tenant):
for db in get_db_alias():
connections[db].set_tenant(tenant)

# TODO remove this - just for local testing
# def get_tenant(self, model, hostname, request):
# return model.objects.get(domain_url=hostname.replace('public.localhost', 'dev.etailpet.com'))
19 changes: 13 additions & 6 deletions tenant_schemas/postgresql_backend/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,20 @@ def get_table_description(self, cursor, table_name):
})
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute('SELECT * FROM %s LIMIT 1' % self.connection.ops.quote_name(table_name))

# To fix issues with django-cms
# https://github.com/divio/django-cms/issues/6666
# https://code.djangoproject.com/ticket/30331
return [
FieldInfo(*(
(force_text(line[0]),) +
line[1:6] +
(field_map[force_text(line[0])][0] == 'YES', field_map[force_text(line[0])][1])
)) for line in cursor.description
FieldInfo(
line.name,
line.type_code,
line.display_size,
line.internal_size,
line.precision,
line.scale,
*field_map[line.name],
)
for line in cursor.description
]

def get_relations(self, cursor, table_name):
Expand Down
1 change: 1 addition & 0 deletions tenant_schemas/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def allow_migrate(self, db, app_label, model_name=None, **hints):
# the imports below need to be done here else django <1.5 goes crazy
# https://code.djangoproject.com/ticket/20704
from django.db import connection
# TODO may be need to get the connection from connections to support multidb
from tenant_schemas.utils import get_public_schema_name, app_labels
from tenant_schemas.postgresql_backend.base import DatabaseWrapper as TenantDbWrapper

Expand Down
46 changes: 45 additions & 1 deletion tenant_schemas/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager

from django.conf import settings
from django.db import connection
from django.db import connection, connections

try:
from django.apps import apps, AppConfig
Expand All @@ -11,6 +11,12 @@
AppConfig = None
from django.core import mail

MULTI_DB_ENABLED = True if len(settings.DATABASES.keys()) > 1 else False


def get_db_alias():
return settings.DATABASES.keys()


@contextmanager
def schema_context(schema_name):
Expand Down Expand Up @@ -38,6 +44,43 @@ def tenant_context(tenant):
connection.set_tenant(previous_tenant)


# Changes to schema_context and tenant_context when multi db enabled
if MULTI_DB_ENABLED:
def get_previous_tenant_dict():
previous_tenant_dict = dict()
for db in get_db_alias():
previous_tenant_dict[db] = connections[db].tenant
return previous_tenant_dict

def apply_previous_tenant_dict(previous_tenant_dict):
if not previous_tenant_dict:
for db in get_db_alias():
connections[db].set_schema_to_public()
else:
for db in get_db_alias():
connections[db].set_tenant(previous_tenant_dict[db])

@contextmanager
def schema_context(schema_name):
previous_tenant_dict = get_previous_tenant_dict()
try:
for db in get_db_alias():
connections[db].set_schema(schema_name)
yield
finally:
apply_previous_tenant_dict(previous_tenant_dict)

@contextmanager
def tenant_context(tenant):
previous_tenant_dict = get_previous_tenant_dict()
try:
for db in get_db_alias():
connections[db].set_tenant(tenant)
yield
finally:
apply_previous_tenant_dict(previous_tenant_dict)


def get_tenant_model():
return get_model(*settings.TENANT_MODEL.split("."))

Expand Down Expand Up @@ -89,6 +132,7 @@ def django_is_in_test_mode():


def schema_exists(schema_name):
# TODO may be need to get the connection based on the default database to support multidb
cursor = connection.cursor()

# check if this schema already exists in the db
Expand Down