Skip to content

Commit

Permalink
Add initialize function to fsapi (#1501)
Browse files Browse the repository at this point in the history
* add do

* add initialize function to fsapi

* add *args

* run action when action is initialize; export values when action is iniitalize

* black reformat

* update comment for initialize

* add try/except around run initialize
  • Loading branch information
MichaelPesce authored Oct 24, 2024
1 parent 74c6518 commit 8c651de
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions watertap/ui/fsapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ class Actions(str, Enum):
solve = "solve"
export = "_export"
diagram = "diagram"
initialize = "initialize"


class FlowsheetCategory(str, Enum):
Expand Down Expand Up @@ -607,6 +608,7 @@ def __init__(
do_build: Callable = None,
do_export: Callable = None,
do_solve: Callable = None,
do_initialize: Callable = None,
get_diagram: Callable = None,
category: FlowsheetCategory = None,
custom_do_param_sweep_kwargs: Dict = None,
Expand Down Expand Up @@ -639,11 +641,14 @@ def __init__(
(do_export, "export"),
(do_build, "build"),
(do_solve, "solve"),
(do_initialize, "initialize"),
):
if arg:
if not callable(arg):
raise TypeError(f"'do_{name}' argument must be callable")
self.add_action(getattr(Actions, name), arg)
elif name == "initialize":
self.add_action(getattr(Actions, name), arg)
else:
raise ValueError(f"'do_{name}' argument is required")
if callable(get_diagram):
Expand Down Expand Up @@ -703,6 +708,24 @@ def get_diagram(self, **kwargs):
else:
return None

def initialize(self, *args, **kwargs):
"""Run initialize function.
Args:
**kwargs: User-defined values
Returns:
Return value of the underlying initialization function. Otherwise, return none
"""
if self.get_action(Actions.initialize) is not None:
try:
result = self.run_action(Actions.initialize, *args, **kwargs)
except Exception as err:
raise RuntimeError(f"Initializing flowsheet: {err}") from err
return result
else:
return None

def dict(self) -> Dict:
"""Serialize.
Expand Down Expand Up @@ -857,6 +880,9 @@ def action_wrapper(*args, **kwargs):
elif action_name == Actions.diagram:
self._actions[action_name] = action_func
return
elif action_name == Actions.initialize:
_log.debug(f"initializing")
result = action_func(self.fs_exp.m)
elif self.fs_exp.obj is None:
raise RuntimeError(
f"Cannot run any flowsheet action (except "
Expand All @@ -872,7 +898,7 @@ def action_wrapper(*args, **kwargs):
if not pyo.check_optimal_termination(result):
raise RuntimeError(f"Solve failed: {result}")
# Sync model with exported values
if action_name in (Actions.build, Actions.solve):
if action_name in (Actions.build, Actions.solve, Actions.initialize):
self.export_values()
return result

Expand All @@ -892,15 +918,15 @@ def get_action(self, name: str) -> Union[Callable, None]:
"""
return self._actions[name]

def run_action(self, name, **kwargs):
def run_action(self, name, *args, **kwargs):
"""Run the named action."""
func = self.get_action(name)
if name.startswith("_"):
raise ValueError(
f"Refusing to call '{name}' action directly since its "
f"name begins with an underscore"
)
return func(**kwargs)
return func(*args, **kwargs)

def export_values(self):
"""Copy current values in underlying Pyomo model into exported model.
Expand Down

0 comments on commit 8c651de

Please sign in to comment.