Skip to content

Commit

Permalink
FIX NeuralNetBinaryClassifier with torch.compile (#1058)
Browse files Browse the repository at this point in the history
* FIX NeuralNetBinaryClassifier with torch.compile

Fixes #1057

NeuralNetBinaryClassifier was not working with torch.compile because the
non-linearity was not correctly inferred. This inference depends on the
instance type of the criterion. However, when using torch.compile, the
criterion is wrapped, resulting in the isinstance check to miss. Now, we
unwrap the criterion before checking the instance type.

* Add entry to CHANGES.md
  • Loading branch information
BenjaminBossan authored May 30, 2024
1 parent 2e8f052 commit 346d705
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
### Fixed

- Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058)

## [1.0.0] - 2024-05-27

The 1.0.0 release of skorch is here. We think that skorch is at a very stable point, which is why a 1.0.0 release is appropriate. There are no plans to add any breaking changes or major revisions in the future. Instead, our focus now is to keep skorch up-to-date with the latest versions of PyTorch and scikit-learn, and to fix any bugs that may arise.
Expand Down
33 changes: 33 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4159,3 +4159,36 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
# compiled, we rely here on torch keeping this public attribute
assert hasattr(net.module_, 'dynamo_ctx')
assert hasattr(net.criterion_, 'dynamo_ctx')

def test_binary_classifier_with_compile(self, data):
# issue 1057 the problem was that compile would wrap the optimizer,
# resulting in _infer_predict_nonlinearity to return the wrong result
# because of a failing isinstance check
from skorch import NeuralNetBinaryClassifier

X, y = data[0], data[1].astype(np.float32)

class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.linear = nn.Linear(20, 10)
self.output = nn.Linear(10, 1)

def forward(self, input):
out = self.linear(input)
out = nn.functional.relu(out)
out = self.output(out)
return out.squeeze(-1)

net = NeuralNetBinaryClassifier(
MyNet,
max_epochs=3,
compile=True,
)
# check that no error is raised
net.fit(X, y)

y_proba = net.predict_proba(X)
y_pred = net.predict(X)
assert y_proba.shape == (X.shape[0], 2)
assert y_pred.shape == (X.shape[0],)
2 changes: 2 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,8 @@ def _infer_predict_nonlinearity(net):
return _identity

criterion = getattr(net, net._criteria[0] + '_')
# unwrap optimizer in case of torch.compile being used
criterion = getattr(criterion, '_orig_mod', criterion)

if isinstance(criterion, CrossEntropyLoss):
return partial(torch.softmax, dim=-1)
Expand Down

0 comments on commit 346d705

Please sign in to comment.