Skip to content

Commit

Permalink
Fix baml validation imports (#1038)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Fix import and definition of `BamlValidationError` by moving it to
`internal_monkeypatch.py` and updating related imports and Rust code.
> 
>   - **Imports and Definitions**:
> - Move `BamlValidationError` definition to `internal_monkeypatch.py`.
>     - Update import of `BamlValidationError` in `errors.py`.
>   - **File Deletions**:
>     - Remove `errors.pyi` as it is no longer needed.
>   - **Rust Code Adjustments**:
> - Modify `raise_baml_validation_error` in `errors.rs` to import
`BamlValidationError` from `internal_monkeypatch.py`.
> - Remove manual definition of `BamlValidationError` in `errors.rs`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 17fd482. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
aaronvg authored Oct 14, 2024
1 parent da9d182 commit 1c14e8a
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,27 @@ def __ctx(self) -> RuntimeContextManager:
ctx[thread_id] = self.rt.create_context_manager()
return ctx[thread_id]


def allow_reset(self) -> bool:
ctx = self.ctx.get()

if len(ctx) > 1:
print("Too many ctxs!")
return False


thread_id = current_thread_id()
if thread_id not in ctx:
print("Thread not in ctx!")
return False
return False

for c in ctx.values():
if c.context_depth() > 0:
print("Context depth is greater than 0!")
return False


return True

def reset(self) -> None:
self.ctx.set({current_thread_id(): self.rt.create_context_manager()})


def upsert_tags(self, **tags: str) -> None:
mngr = self.__ctx()
Expand Down
4 changes: 1 addition & 3 deletions engine/language_client_python/python_src/baml_py/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
BamlClientHttpError,
BamlInvalidArgumentError,
)

# hack to get the BamlValidationError class which is a custom error
from .baml_py.errors import BamlValidationError
from .internal_monkeypatch import BamlValidationError


__all__ = [
Expand Down
26 changes: 0 additions & 26 deletions engine/language_client_python/python_src/baml_py/errors.pyi

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .baml_py import BamlError


# Define the BamlValidationError exception with additional fields
# note on custom exceptions https://github.com/PyO3/pyo3/issues/295
# can't use extends=PyException yet https://github.com/PyO3/pyo3/discussions/3838
class BamlValidationError(BamlError):
def __init__(self, prompt: str, message: str, raw_output: str):
super().__init__(message)
self.prompt = prompt
self.message = message
self.raw_output = raw_output

def __str__(self):
return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})"

def __repr__(self):
return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})"
66 changes: 6 additions & 60 deletions engine/language_client_python/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,76 +17,22 @@ create_exception!(baml_py, BamlClientHttpError, BamlClientError);

// Define the BamlValidationError exception with additional fields
// can't use extends=PyException yet https://github.com/PyO3/pyo3/discussions/3838
#[pyfunction]

#[allow(non_snake_case)]
fn raise_baml_validation_error(prompt: String, message: String, raw_output: String) -> PyErr {
Python::with_gil(|py| {
// Import the current module to access the BamlValidationError class
let module = PyModule::import(py, "baml_py.errors").unwrap();
let exception = module.getattr("BamlValidationError").unwrap();
let internal_monkeypatch = py.import("baml_py.internal_monkeypatch").unwrap();
let exception = internal_monkeypatch.getattr("BamlValidationError").unwrap();
let args = (prompt, message, raw_output);
let instance = exception.call1(args).unwrap();
PyErr::from_value(instance.into())
let inst = exception.call1(args).unwrap();
PyErr::from_value(inst)
})
}

/// Defines the errors module with the BamlValidationError exception.
/// IIRC the name of this function is the name of the module that pyo3 generates (errors.py)
#[pymodule]
pub fn errors(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
// Define the BamlValidationError Python exception class in a hacky way first, manually into that errors module.
let errors_module = PyModule::from_code_bound(
parent_module.py(),
r#"
class BamlValidationError(Exception):
def __init__(self, prompt, message, raw_output):
super().__init__(message)
self.prompt = prompt
self.message = message
self.raw_output = raw_output
def __str__(self):
return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})"
def __repr__(self):
return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})"
"#,
"errors.py",
"errors",
)?;

// Add the raise_baml_validation_error function to the module
parent_module.add_wrapped(wrap_pyfunction!(raise_baml_validation_error))?;

// add the other exceptions in
errors_module.add(
"BamlError",
errors_module.py().get_type_bound::<BamlError>(),
)?;
errors_module.add(
"BamlInvalidArgumentError",
errors_module
.py()
.get_type_bound::<BamlInvalidArgumentError>(),
)?;
errors_module.add(
"BamlClientError",
errors_module.py().get_type_bound::<BamlClientError>(),
)?;
errors_module.add(
"BamlClientHttpError",
errors_module.py().get_type_bound::<BamlClientHttpError>(),
)?;

parent_module.add_submodule(&errors_module)?;

// we have to do this hack or python will complain the baml_py.errors is not a package.
parent_module
.py()
.import("sys")?
.getattr("modules")?
.set_item("baml_py.errors", errors_module.clone())?;

// now add the other errors again to the parent module
parent_module.add(
"BamlError",
parent_module.py().get_type_bound::<BamlError>(),
Expand Down

0 comments on commit 1c14e8a

Please sign in to comment.