Skip to content

Commit

Permalink
Merge pull request #10 from jettify/fix-regression-printer
Browse files Browse the repository at this point in the history
Fix regression printer
  • Loading branch information
jettify authored Oct 28, 2019
2 parents 8202332 + 1fb112d commit 5272e47
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Simple example
Please check full Titanic example here: https://github.com/jettify/ibreakdown/blob/master/examples/titanic.py

.. code::
+------------------------------------+-----------------+--------------------+--------------------+
| Feature Name | Feature Value | Contrib:Deceased | Contrib:Survived |
+------------------------------------+-----------------+--------------------+--------------------|
Expand Down
6 changes: 3 additions & 3 deletions ibreakdown/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from .explanation import ClassificationExplanation, RegressionExplanation
from .utils import (
feature_group_values,
features_groups,
magnituge,
normalize_array,
to_matrix,
magnituge,
features_groups,
feature_group_values,
)


Expand Down
4 changes: 3 additions & 1 deletion ibreakdown/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def _build_df(self):
]
feature_names = ['intercept'] + feature_names + ['PREDICTION']
feature_values = [None] + self.feature_values + [None]
contrib = [self.intercept] + self.contributions + [self.pred_value]
contrib = (
[self.intercept] + self.contributions.tolist() + [self.pred_value]
)

data = {
'Feature Name': feature_names,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_explainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import io

import pytest
import numpy as np

from sklearn.datasets import load_boston
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
Expand Down Expand Up @@ -30,6 +33,9 @@ def test_regression(seed):
assert np.sum(exp.contributions) + exp.intercept == pytest.approx(
pred[0]
)
with io.StringIO() as buf:
exp.print(file=buf, flush=True)
assert len(buf.getvalue()) > 0


def test_multiclass(seed):
Expand All @@ -53,3 +59,7 @@ def test_multiclass(seed):
# check invariant
invariant = np.sum(exp.contributions, axis=0) + exp.intercept
assert invariant == pytest.approx(pred[0])

with io.StringIO() as buf:
exp.print(file=buf, flush=True)
assert len(buf.getvalue()) > 0

0 comments on commit 5272e47

Please sign in to comment.