diff --git a/README.md b/README.md index a912598..454da39 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,19 @@ imports = [ ] ``` +This also works with excludes and includes. + +```toml +[tool.paracelsus] +base = "example.base:Base" +imports = [ + "example.models" +] +exclude_tables = [ + "comments" +] +``` + ## Sponsorship This project is developed by [Robert Hafner](https://blog.tedivm.com) If you find this project useful please consider sponsoring me using Github! diff --git a/paracelsus/cli.py b/paracelsus/cli.py index 8c5b3b7..88e891f 100644 --- a/paracelsus/cli.py +++ b/paracelsus/cli.py @@ -67,15 +67,15 @@ def graph( settings = get_pyproject_settings() base_class = get_base_class(base_class_path, settings) - if settings and "imports" in settings: + if "imports" in settings: import_module.extend(settings["imports"]) typer.echo( get_graph_string( base_class_path=base_class, import_module=import_module, - include_tables=set(include_tables), - exclude_tables=set(exclude_tables), + include_tables=set(include_tables + settings.get("include_tables", [])), + exclude_tables=set(exclude_tables + settings.get("exclude_tables", [])), python_dir=python_dir, format=format.value, ) @@ -141,12 +141,16 @@ def inject( ), ] = False, ): + settings = get_pyproject_settings() + if "imports" in settings: + import_module.extend(settings["imports"]) + # Generate Graph graph = get_graph_string( base_class_path=base_class_path, import_module=import_module, - include_tables=set(include_tables), - exclude_tables=set(exclude_tables), + include_tables=set(include_tables + settings.get("include_tables", [])), + exclude_tables=set(exclude_tables + settings.get("exclude_tables", [])), python_dir=python_dir, format=format.value, ) diff --git a/paracelsus/graph.py b/paracelsus/graph.py index 936924f..ed2d048 100644 --- a/paracelsus/graph.py +++ b/paracelsus/graph.py @@ -106,5 +106,11 @@ def filter_metadata( filtered_metadata = MetaData() for tablename, table in metadata.tables.items(): if tablename in include_tables: - table.tometadata(filtered_metadata) + if hasattr(table, "to_metadata"): + # to_metadata is the new way to do this, but it's only available in newer versions of SQLAlchemy. + table = table.to_metadata(filtered_metadata) + else: + # tometadata is deprecated, but we still need to support it for older versions of SQLAlchemy. + table = table.tometadata(filtered_metadata) + return filtered_metadata diff --git a/paracelsus/pyproject.py b/paracelsus/pyproject.py index 02df80d..c28584d 100644 --- a/paracelsus/pyproject.py +++ b/paracelsus/pyproject.py @@ -8,13 +8,13 @@ import toml as tomllib # type: ignore -def get_pyproject_settings(dir: Path = Path(os.getcwd())) -> Dict[str, Any] | None: +def get_pyproject_settings(dir: Path = Path(os.getcwd())) -> Dict[str, Any]: pyproject = dir / "pyproject.toml" if not pyproject.exists(): - return None + return {} with open(pyproject, "rb") as f: data = tomllib.loads(f.read().decode()) - return data.get("tool", {}).get("paracelsus", None) + return data.get("tool", {}).get("paracelsus", {}) diff --git a/tests/test_pyproject.py b/tests/test_pyproject.py index 4a6a223..6122619 100644 --- a/tests/test_pyproject.py +++ b/tests/test_pyproject.py @@ -11,4 +11,4 @@ def test_pyproject(package_path): def test_pyproject_none(): settings = get_pyproject_settings() - assert settings is None + assert settings == {}