adding unit tests
benc-db committed Oct 21, 2024
1 parent 96d0033 commit d608847
Showing 8 changed files with 826 additions and 299 deletions.
11 changes: 4 additions & 7 deletions dbt/adapters/databricks/python_models/python_config.py
@@ -10,17 +10,14 @@ class PythonJobConfig(BaseModel):
"""Pydantic model for config found in python_job_config."""

name: Optional[str] = None
email_notifications: Optional[Dict[str, Any]] = None
webhook_notifications: Optional[Dict[str, Any]] = None
notification_settings: Optional[Dict[str, Any]] = None
timeout_seconds: Optional[int] = Field(None, gt=0)
health: Optional[Dict[str, Any]] = None
environments: Optional[List[Dict[str, Any]]] = None
grants: Dict[str, List[Dict[str, str]]] = Field(exclude=True, default_factory=dict)
existing_job_id: str = Field("", exclude=True)
post_hook_tasks: List[Dict[str, Any]] = Field(exclude=True, default_factory=list)
additional_task_settings: Dict[str, Any] = Field(exclude=True, default_factory=dict)

class Config:
extra = "allow"

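For orientation, a quick sketch of the serialization behavior this slimmed-down model relies on, mirroring the new unit tests further down (field values are invented):

    from dbt.adapters.databricks.python_models.python_config import PythonJobConfig

    job_config = PythonJobConfig(
        name="nightly_run",
        grants={"view": [{"user": "someone@example.com"}]},  # Field(exclude=True)
        custom_setting="bar",  # unknown key, kept because Config.extra = "allow"
    )

    # Fields marked exclude=True are dropped from the serialized payload;
    # extra keys pass through untouched.
    assert job_config.dict() == {"name": "nightly_run", "custom_setting": "bar"}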

class PythonModelConfig(BaseModel):
"""
@@ -54,7 +51,7 @@ class ParsedPythonModel(BaseModel):

@property
def run_name(self) -> str:
return f"{self.catalog}-{self.schema_}-" f"{self.identifier}-{uuid.uuid4()}"
return f"{self.catalog}-{self.schema_}-{self.identifier}-{uuid.uuid4()}"

class Config:
allow_population_by_field_name = True
105 changes: 62 additions & 43 deletions dbt/adapters/databricks/python_models/python_submissions.py
@@ -96,10 +96,10 @@ class PythonNotebookUploader:
"""Uploads a compiled Python model as a notebook to the Databricks workspace."""

def __init__(
self, api_client: DatabricksApiClient, database: str, schema: str, identifier: str
self, api_client: DatabricksApiClient, catalog: str, schema: str, identifier: str
) -> None:
self.api_client = api_client
self.database = database
self.catalog = catalog
self.schema = schema
self.identifier = identifier

@@ -116,7 +116,7 @@ def create(

def upload(self, compiled_code: str) -> str:
"""Upload the compiled code to the Databricks workspace."""
workdir = self.api_client.workspace.create_python_model_dir(self.database, self.schema)
workdir = self.api_client.workspace.create_python_model_dir(self.catalog, self.schema)
file_path = f"{workdir}{self.identifier}"
self.api_client.workspace.upload_notebook(file_path, compiled_code)
return file_path
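
The upload path is simply the model directory plus the identifier, which a unit test can now verify with a mocked client; a minimal sketch (the directory value is invented):

    from unittest.mock import Mock

    from dbt.adapters.databricks.python_models.python_submissions import PythonNotebookUploader

    api_client = Mock()
    api_client.workspace.create_python_model_dir.return_value = "/dbt_python_models/catalog/schema/"

    uploader = PythonNotebookUploader(api_client, catalog="catalog", schema="schema", identifier="my_model")
    path = uploader.upload("# compiled model code")

    # The notebook is uploaded to the computed path, which is also returned.
    api_client.workspace.upload_notebook.assert_called_once_with(path, "# compiled model code")
    assert path == "/dbt_python_models/catalog/schema/my_model"
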
Expand Down Expand Up @@ -202,7 +202,32 @@ def build_job_permissions(self) -> List[Dict[str, Any]]:
)
access_control_list.append(acl_grant)

return access_control_list + (self.acls or [])
return access_control_list + self.acls


class PythonLibraryConfigurer:
"""Configures the libraries component for a Python job."""

@staticmethod
def get_library_config(
packages: List[str],
index_url: Optional[str],
additional_libraries: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Update the job configuration with the required libraries."""

libraries = []

for package in packages:
if index_url:
libraries.append({"pypi": {"package": package, "repo": index_url}})
else:
libraries.append({"pypi": {"package": package}})

for library in additional_libraries:
libraries.append(library)

return {"libraries": libraries}


class PythonJobConfigCompiler:
@@ -212,38 +237,42 @@ def __init__(
self,
api_client: DatabricksApiClient,
permission_builder: PythonPermissionBuilder,
model: ParsedPythonModel,
run_name: str,
cluster_spec: Dict[str, Any],
additional_job_settings: Dict[str, Any],
) -> None:
self.api_client = api_client
self.permission_builder = permission_builder
self.run_name = model.run_name
self.packages = model.config.packages
self.index_url = model.config.index_url
self.additional_libraries = model.config.additional_libs
if model.config.python_job_config:
self.additional_job_settings = model.config.python_job_config.dict()
else:
self.additional_job_settings = {}
self.run_name = run_name
self.cluster_spec = cluster_spec
self.additional_job_settings = additional_job_settings

def _update_with_libraries(self, job_spec: Dict[str, Any]) -> Dict[str, Any]:
"""Update the job configuration with the required libraries."""

local = job_spec.copy()
libraries = []

for package in self.packages:
if self.index_url:
libraries.append({"pypi": {"package": package, "repo": self.index_url}})
else:
libraries.append({"pypi": {"package": package}})

for library in self.additional_libraries:
libraries.append(library)
@staticmethod
def create(
api_client: DatabricksApiClient,
parsed_model: ParsedPythonModel,
cluster_spec: Dict[str, Any],
) -> "PythonJobConfigCompiler":
permission_builder = PythonPermissionBuilder.create(api_client, parsed_model)
packages = parsed_model.config.packages
index_url = parsed_model.config.index_url
additional_libraries = parsed_model.config.additional_libs
library_config = PythonLibraryConfigurer.get_library_config(
packages, index_url, additional_libraries
)
cluster_spec.update(library_config)
if parsed_model.config.python_job_config:
additional_job_settings = parsed_model.config.python_job_config.dict()
else:
additional_job_settings = {}

local.update({"libraries": libraries})
return local
return PythonJobConfigCompiler(
api_client,
permission_builder,
parsed_model.run_name,
cluster_spec,
additional_job_settings,
)
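
Moving the wiring into create leaves a plain constructor that unit tests can drive with stubs; a minimal sketch (all values invented):

    from unittest.mock import Mock

    from dbt.adapters.databricks.python_models.python_submissions import PythonJobConfigCompiler

    # Tests can inject a stub permission builder instead of a real API client.
    compiler = PythonJobConfigCompiler(
        api_client=Mock(),
        permission_builder=Mock(),
        run_name="catalog-schema-my_model-1234",
        cluster_spec={"existing_cluster_id": "abc-123", "libraries": []},
        additional_job_settings={},
    )

    assert compiler.run_name == "catalog-schema-my_model-1234"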

def compile(self, path: str) -> PythonJobDetails:

@@ -255,8 +284,6 @@ def compile(self, path: str) -> PythonJobDetails:
}
job_spec.update(self.cluster_spec) # updates 'new_cluster' config

job_spec = self._update_with_libraries(job_spec)

additional_job_config = self.additional_job_settings
access_control_list = self.permission_builder.build_job_permissions()
if access_control_list:
Expand Down Expand Up @@ -288,9 +315,8 @@ def create(
cluster_spec: Dict[str, Any],
) -> "PythonNotebookSubmitter":
notebook_uploader = PythonNotebookUploader.create(api_client, parsed_model)
config_compiler = PythonJobConfigCompiler(
config_compiler = PythonJobConfigCompiler.create(
api_client,
PythonPermissionBuilder.create(api_client, parsed_model),
parsed_model,
cluster_spec,
)
@@ -375,12 +401,12 @@ class PythonWorkflowConfigCompiler:

def __init__(
self,
workflow_settings: Dict[str, Any],
task_settings: Dict[str, Any],
workflow_spec: Dict[str, Any],
existing_job_id: str,
post_hook_tasks: List[Dict[str, Any]],
) -> None:
self.workflow_settings = workflow_settings
self.task_settings = task_settings
self.existing_job_id = existing_job_id
self.workflow_spec = workflow_spec
self.post_hook_tasks = post_hook_tasks
@@ -431,7 +457,7 @@ def compile(self, path: str) -> Tuple[Dict[str, Any], str]:
"source": "WORKSPACE",
},
}
notebook_task.update(self.workflow_settings)
notebook_task.update(self.task_settings)

self.workflow_spec["tasks"] = [notebook_task] + self.post_hook_tasks
return self.workflow_spec, self.existing_job_id
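
The rename from workflow_settings to task_settings matches what the dict actually holds: per-task overrides merged into the notebook task. A sketch of compile's output shape (settings values invented):

    from dbt.adapters.databricks.python_models.python_submissions import PythonWorkflowConfigCompiler

    compiler = PythonWorkflowConfigCompiler(
        task_settings={"max_retries": 2},
        workflow_spec={"name": "my_workflow"},
        existing_job_id="123",
        post_hook_tasks=[{"task_key": "post_hook"}],
    )

    workflow_spec, job_id = compiler.compile("/path/to/notebook")

    assert job_id == "123"
    assert workflow_spec["tasks"][0]["max_retries"] == 2  # merged task settings
    assert workflow_spec["tasks"][-1] == {"task_key": "post_hook"}
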
Expand Down Expand Up @@ -524,13 +550,6 @@ def submit(self, compiled_code: str) -> None:
class WorkflowPythonJobHelper(BaseDatabricksHelper):
"""Top level helper for Python models using workflow jobs on Databricks."""

@override
def validate_config(self) -> None:
if not self.parsed_model.config.python_job_config:
raise ValueError(
"python_job_config is required for the `python_job_config` submission method."
)

@override
def build_submitter(self) -> PythonSubmitter:
return PythonNotebookWorkflowSubmitter.create(
131 changes: 131 additions & 0 deletions tests/unit/python/test_python_config.py
@@ -0,0 +1,131 @@
from pydantic import ValidationError
import pytest
from dbt.adapters.databricks.python_models.python_config import (
ParsedPythonModel,
PythonJobConfig,
PythonModelConfig,
)


class TestParsedPythonModel:
def test_parsed_model__default_database_schema(self):
parsed_model = {
"alias": "test",
"config": {},
}

model = ParsedPythonModel(**parsed_model)
assert model.catalog == "hive_metastore"
assert model.schema_ == "default"
assert model.identifier == "test"

def test_parsed_model__empty_model_config(self):
parsed_model = {
"database": "database",
"schema": "schema",
"alias": "test",
"config": {},
}

model = ParsedPythonModel(**parsed_model)
assert model.catalog == "database"
assert model.schema_ == "schema"
assert model.identifier == "test"
config = model.config
assert config.user_folder_for_python is False
assert config.timeout == 86400
assert config.job_cluster_config == {}
assert config.access_control_list == []
assert config.packages == []
assert config.index_url is None
assert config.additional_libs == []
assert config.python_job_config is None
assert config.cluster_id is None
assert config.http_path is None
assert config.create_notebook is False

def test_parsed_model__valid_model_config(self):
parsed_model = {
"alias": "test",
"config": {
"user_folder_for_python": True,
"timeout": 100,
"job_cluster_config": {"key": "value"},
"access_control_list": [{"key": "value"}],
"packages": ["package"],
"index_url": "index_url",
"additional_libs": [{"key": "value"}],
"python_job_config": {"name": "name"},
"cluster_id": "cluster_id",
"http_path": "http_path",
"create_notebook": True,
},
}

model = ParsedPythonModel(**parsed_model)
config = model.config
assert config.user_folder_for_python is True
assert config.timeout == 100
assert config.job_cluster_config == {"key": "value"}
assert config.access_control_list == [{"key": "value"}]
assert config.packages == ["package"]
assert config.index_url == "index_url"
assert config.additional_libs == [{"key": "value"}]
assert config.python_job_config.name == "name"
assert config.python_job_config.grants == {}
assert config.python_job_config.existing_job_id == ""
assert config.python_job_config.post_hook_tasks == []
assert config.python_job_config.additional_task_settings == {}
assert config.cluster_id == "cluster_id"
assert config.http_path == "http_path"
assert config.create_notebook is True

def test_parsed_model__extra_model_config(self):
parsed_model = {
"alias": "test",
"config": {
"python_job_config": {"foo": "bar"},
},
}

model = ParsedPythonModel(**parsed_model)
assert model.config.python_job_config.foo == "bar"

def test_parsed_model__run_name(self):
parsed_model = {"alias": "test", "config": {}}
model = ParsedPythonModel(**parsed_model)
assert model.run_name.startswith("hive_metastore-default-test-")

def test_parsed_model__invalid_config(self):
parsed_model = {"alias": "test", "config": []}
with pytest.raises(ValidationError):
ParsedPythonModel(**parsed_model)


class TestPythonModelConfig:
def test_python_model_config__invalid_timeout(self):
config = {"timeout": -1}
with pytest.raises(ValidationError):
PythonModelConfig(**config)


class TestPythonJobConfig:
def test_python_job_config__dict_excludes_expected_fields(self):
config = {
"name": "name",
"grants": {"view": [{"user": "user"}]},
"existing_job_id": "existing_job_id",
"post_hook_tasks": [{"task": "task"}],
"additional_task_settings": {"key": "value"},
}
job_config = PythonJobConfig(**config).dict()
assert job_config == {"name": "name"}

def test_python_job_config__extra_values(self):
config = {
"name": "name",
"existing_job_id": "existing_job_id",
"foo": "bar",
}
job_config = PythonJobConfig(**config).dict()
assert job_config == {"name": "name", "foo": "bar"}
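
These tests need no Databricks connection, so they run anywhere pytest does; for example, from the repository root:

    import pytest

    pytest.main(["tests/unit/python/test_python_config.py", "-v"])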
