Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out ProblemConfig #46

Merged
merged 10 commits into from
Jul 4, 2024
34 changes: 15 additions & 19 deletions entmoot/constraints.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable
from typing import Callable

import pyomo.environ as pyo

from entmoot.problem_config import FeatureType
from entmoot.problem_config import AnyFeatureT
from entmoot.typing.optimizer_stubs import PyomoModelT

if TYPE_CHECKING:
from problem_config import FeatureType

ConstraintFunctionType = Callable[[pyo.ConcreteModel, int], pyo.Expression]
ConstraintFunctionType = Callable[[PyomoModelT, int], pyo.Expression]


class Constraint(ABC):
Expand All @@ -23,7 +21,7 @@ def __init__(self, feature_keys: list[str]):
self.feature_keys = feature_keys

def _get_feature_vars(
self, model: pyo.ConcreteModel, feat_list: list["FeatureType"]
self, model: PyomoModelT, feat_list: list[AnyFeatureT]
) -> list[pyo.Var]:
"""Return a list of all the pyo.Vars, in the order of the constraint definition"""
all_keys = [feat.name for feat in feat_list]
Expand All @@ -33,7 +31,7 @@ def _get_feature_vars(

@abstractmethod
def as_pyomo_constraint(
self, model: pyo.ConcreteModel, feat_list: list["FeatureType"]
self, model: PyomoModelT, feat_list: list[AnyFeatureT]
) -> pyo.Constraint:
"""Convert to a pyomo.Constraint object.

Expand All @@ -53,8 +51,8 @@ def add(self, constraint: Constraint):

def apply_pyomo_constraints(
self,
model: pyo.ConcreteModel,
feat_list: list[FeatureType],
model: PyomoModelT,
feat_list: list[AnyFeatureT],
pyo_constraint_list: pyo.ConstraintList,
) -> None:
"""Add constraints to a pyo.ConstraintList object.
Expand All @@ -79,14 +77,14 @@ class ExpressionConstraint(Constraint):
For constraints that can be simply defined by an expression of variables.
"""

def as_pyomo_constraint(
self, model: pyo.ConcreteModel, feat_list: list["FeatureType"]
) -> pyo.Constraint:
def as_pyomo_constraint(self, model, feat_list):
features = self._get_feature_vars(model, feat_list)
return pyo.Constraint(expr=self._get_expr(model, features))

@abstractmethod
def _get_expr(self, model, features) -> pyo.Expression:
def _get_expr(
self, model: PyomoModelT, features: list[AnyFeatureT]
) -> pyo.Expression:
pass


Expand All @@ -96,15 +94,13 @@ class FunctionalConstraint(Constraint):
For constraints that require creating intermediate variables and access to the model.
"""

def as_pyomo_constraint(
self, model: pyo.ConcreteModel, feat_list: list["FeatureType"]
) -> pyo.Constraint:
def as_pyomo_constraint(self, model, feat_list):
features = self._get_feature_vars(model, feat_list)
return pyo.Constraint(rule=self._get_function(model, features))

@abstractmethod
def _get_function(
self, model: pyo.ConcreteModel, features: list["FeatureType"]
self, model: PyomoModelT, features: list[AnyFeatureT]
) -> ConstraintFunctionType:
pass

Expand All @@ -118,7 +114,7 @@ def __init__(self, feature_keys: list[str], coefficients: list[float], rhs: floa
self.rhs = rhs
super().__init__(feature_keys)

def _get_lhs(self, features: pyo.ConcreteModel) -> pyo.Expression:
def _get_lhs(self, features) -> pyo.Expression:
"""Get the left-hand side of the linear constraint"""
return sum(f * c for f, c in zip(features, self.coefficients))

Expand Down
4 changes: 2 additions & 2 deletions entmoot/models/enting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
problem_config=problem_config, params=unc_params
)

def fit(self, X: np.ndarray, y: np.ndarray) -> None:
def fit(self, X: list | np.ndarray, y: np.ndarray) -> None:
"""
Performs the training of you tree model using training data and labels
"""
Expand Down Expand Up @@ -113,7 +113,7 @@ def leaf_bnd_predict(self, obj_name, leaf_enc):
bnds = self._problem_config.get_enc_bnd()
return self.mean_model.meta_tree_dict[obj_name].prune_var_bnds(leaf_enc, bnds)

def predict(self, X: np.ndarray, is_enc=False) -> list:
def predict(self, X: list | np.ndarray, is_enc=False) -> list:
"""
Computes prediction value of tree model for X.
"""
Expand Down
3 changes: 2 additions & 1 deletion entmoot/optimizers/gurobi_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from entmoot.models.enting import Enting
from entmoot.problem_config import Categorical, ProblemConfig
from entmoot.typing.optimizer_stubs import GurobiModelT
from entmoot.utils import OptResult

ActiveLeavesT = list[list[tuple[int, str]]]
Expand Down Expand Up @@ -72,7 +73,7 @@ def get_active_leaf_sol(self) -> ActiveLeavesT:
def solve(
self,
tree_model: Enting,
model_core: Optional[gur.Model] = None,
model_core: Optional[GurobiModelT] = None,
weights: Optional[tuple[float, ...]] = None,
use_env: bool = False,
) -> OptResult:
Expand Down
Loading