diff --git a/lale/lib/aif360/_mystic_util.py b/lale/lib/aif360/_mystic_util.py index 6a1e3131c..af10152b9 100644 --- a/lale/lib/aif360/_mystic_util.py +++ b/lale/lib/aif360/_mystic_util.py @@ -1,9 +1,36 @@ +# Copyright 2023 IBM Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, Set import numpy as np -from mystic.coupler import and_ -from mystic.penalty import quadratic_equality -from mystic.solvers import diffev2 + +try: + from mystic.coupler import and_ + from mystic.penalty import quadratic_equality + from mystic.solvers import diffev2 + + mystic_installed = True +except ModuleNotFoundError: + mystic_installed = False + + +def _assert_mystic_installed(): + assert mystic_installed, """Your Python environment does not have mystic installed. You can install it with + pip install mystic +or with + pip install 'lale[fairness]'""" def parse_solver_soln(n_flat, group_mapping): @@ -122,6 +149,8 @@ def obtain_solver_info( def construct_ci_penalty(A, C, n_ci, i): + _assert_mystic_installed() + def condition(x): reshape_list = [] for _ in range(A): @@ -153,6 +182,8 @@ def create_ci_penalties(n_ci, n_di): def construct_di_penalty(A, C, n_di, F, i): + _assert_mystic_installed() + def condition(x): reshape_list = [] for _ in range(A): @@ -201,6 +232,7 @@ def create_di_penalties(n_ci, n_di, F): def calc_oversample_soln(o_flat, F, n_ci, n_di): + _assert_mystic_installed() # integer constraint ints = np.round @@ -235,6 +267,7 @@ def cost(x): def calc_undersample_soln(o_flat, F, n_ci, n_di): + _assert_mystic_installed() # integer constraint ints = np.round @@ -269,6 +302,7 @@ def cost(x): def calc_mixedsample_soln(o_flat, F, n_ci, n_di): + _assert_mystic_installed() # integer constraint ints = np.round diff --git a/setup.py b/setup.py index af66feed4..32f76c578 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ "lightgbm<4.0.0", "snapml>=1.7.0rc3,<1.12.0", "liac-arff>=2.4.0", - "tensorflow>=2.4.0", + "tensorflow>=2.4.0,<=2.13.0", "smac<=0.10.0", "numba", "aif360>=0.4.0",