Skip to content

Commit

Permalink
fix int(final_cfs_sparse.at[cf_ix, feature]) in do_posthoc_sparsity_e…
Browse files Browse the repository at this point in the history
…nhancement, do_linear_search and do_binary_search (#343)

Signed-off-by: An <[email protected]>

Signed-off-by: An <[email protected]>
Co-authored-by: An <>
  • Loading branch information
azz147 authored Nov 16, 2022
1 parent 1c55f7b commit 8f17cfd
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
for feature in features_sorted:
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]
if(abs(diff) <= quantiles[feature]):
if posthoc_sparsity_algorithm == "linear":
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
Expand Down Expand Up @@ -561,17 +561,16 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and
self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls):

old_val = int(final_cfs_sparse.at[cf_ix, feature])
old_val = final_cfs_sparse.at[cf_ix, feature]
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
old_diff = diff

if not self.is_cf_valid(current_pred):
final_cfs_sparse.at[cf_ix, feature] = old_val
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
return final_cfs_sparse

diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]

count_steps += 1

Expand All @@ -581,7 +580,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
"""Performs a binary search between continuous features of a CF and corresponding values
in query_instance until the prediction class changes."""

old_val = int(final_cfs_sparse.at[cf_ix, feature])
old_val = final_cfs_sparse.at[cf_ix, feature]
final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0]
# Prediction of the query instance
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
Expand All @@ -594,7 +593,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f

# move the CF values towards the query_instance
if diff > 0:
left = int(final_cfs_sparse.at[cf_ix, feature])
left = final_cfs_sparse.at[cf_ix, feature]
right = query_instance[feature].iat[0]

while left <= right:
Expand All @@ -614,7 +613,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f

else:
left = query_instance[feature].iat[0]
right = int(final_cfs_sparse.at[cf_ix, feature])
right = final_cfs_sparse.at[cf_ix, feature]

while right >= left:
current_val = right - ((right - left)/2)
Expand Down

0 comments on commit 8f17cfd

Please sign in to comment.