-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add boundary test to better understand get_columns method for cross d…
…atabase refs
- Loading branch information
1 parent
8bbdbf2
commit a9af12d
Showing
4 changed files
with
144 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from datetime import datetime | ||
import os | ||
import random | ||
from typing import Any, Dict | ||
|
||
import pytest | ||
import redshift_connector | ||
|
||
|
||
@pytest.fixture | ||
def connection(connection_config) -> redshift_connector.Connection: | ||
return redshift_connector.connect(**connection_config) | ||
|
||
|
||
@pytest.fixture | ||
def connection_alt(connection_config) -> redshift_connector.Connection: | ||
config = connection_config.copy() | ||
config.update(database=os.getenv("REDSHIFT_TEST_DBNAME_ALT")) | ||
return redshift_connector.connect(**config) | ||
|
||
|
||
@pytest.fixture | ||
def connection_config() -> Dict[str, Any]: | ||
return { | ||
"user": os.getenv("REDSHIFT_TEST_USER"), | ||
"password": os.getenv("REDSHIFT_TEST_PASS"), | ||
"host": os.getenv("REDSHIFT_TEST_HOST"), | ||
"port": int(os.getenv("REDSHIFT_TEST_PORT")), | ||
"database": os.getenv("REDSHIFT_TEST_DBNAME"), | ||
"region": os.getenv("REDSHIFT_TEST_REGION"), | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def schema_name(request) -> str: | ||
runtime = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0) | ||
runtime_s = int(runtime.total_seconds()) | ||
runtime_ms = runtime.microseconds | ||
random_int = random.randint(0, 9999) | ||
file_name = request.module.__name__.split(".")[-1] | ||
return f"test_{runtime_s}{runtime_ms}{random_int:04}_{file_name}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def setup(connection, connection_alt, schema_name) -> str: | ||
# create the same table in two different databases | ||
with connection.cursor() as cursor: | ||
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") | ||
cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id") | ||
with connection_alt.cursor() as cursor: | ||
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") | ||
cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id") | ||
|
||
yield schema_name | ||
|
||
# drop both test schemas | ||
with connection_alt.cursor() as cursor: | ||
cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") | ||
with connection.cursor() as cursor: | ||
cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") | ||
|
||
|
||
def test_columns_in_relation(connection, schema_name): | ||
# we're specifically running this query from the default database | ||
# we're expecting to get both tables, the one in the default database and the one in the alt database | ||
with connection.cursor() as cursor: | ||
columns = cursor.get_columns(schema_pattern=schema_name, tablename_pattern="cross_db") | ||
|
||
# we should have the same table in both databases | ||
assert len(columns) == 2 | ||
|
||
databases = set() | ||
for column in columns: | ||
( | ||
database, | ||
schema, | ||
table, | ||
name, | ||
type_code, | ||
type_name, | ||
precision, | ||
_, | ||
scale, | ||
*_, | ||
) = column | ||
databases.add(database) | ||
assert schema_name == schema_name | ||
assert table == "cross_db" | ||
assert name == "id" | ||
assert type_code == 2 | ||
assert type_name == "numeric" | ||
assert precision == 3 | ||
assert scale == 2 | ||
|
||
# only the databases are different | ||
assert databases == { | ||
os.getenv("REDSHIFT_TEST_DBNAME"), | ||
os.getenv("REDSHIFT_TEST_DBNAME_ALT"), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from dbt.tests.util import get_connection, run_dbt | ||
import pytest | ||
|
||
|
||
MY_CROSS_DB_SOURCES = """ | ||
version: 2 | ||
sources: | ||
- name: ci | ||
schema: adapter | ||
tables: | ||
- name: cross_db | ||
- name: ci_alt | ||
database: ci_alt | ||
schema: adapter | ||
tables: | ||
- name: cross_db | ||
""" | ||
|
||
|
||
class TestCrossDatabase: | ||
""" | ||
This addresses https://github.com/dbt-labs/dbt-redshift/issues/736 | ||
""" | ||
|
||
@pytest.fixture(scope="class") | ||
def models(self): | ||
my_model = """ | ||
select '{{ adapter.get_columns_in_relation(source('ci', 'cross_db')) }}' as columns | ||
union all | ||
select '{{ adapter.get_columns_in_relation(source('ci_alt', 'cross_db')) }}' as columns | ||
""" | ||
return { | ||
"sources.yml": MY_CROSS_DB_SOURCES, | ||
"my_model.sql": my_model, | ||
} | ||
|
||
def test_columns_in_relation(self, project): | ||
run_dbt(["run"]) | ||
with get_connection(project.adapter, "_test"): | ||
records = project.run_sql(f"select * from {project.test_schema}.my_model", fetch=True) | ||
assert len(records) == 2 |