
Gpu batch #35

Open · wants to merge 8 commits into master
Conversation

Jeanselme (Contributor)

Allow putting only the batch on GPU, to limit memory use.
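The batch-only GPU transfer this PR proposes can be sketched generically. This is a minimal illustration, not the PR's actual code: the function name `train_batched` and the `loss_fn(model, xb, tb, eb)` signature are hypothetical, chosen only to show the pattern of keeping the full dataset on CPU and moving one minibatch at a time to the device.

```python
import torch
from torch import nn

def train_batched(model, x_train, t_train, e_train, loss_fn,
                  batch_size=128, device=None):
    """Move only the current minibatch to `device`, keeping the full
    dataset on CPU so GPU memory use is bounded by the batch size."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    for i in range(0, len(x_train), batch_size):
        # Only this slice is transferred to the GPU; the slices are
        # freed once the iteration's tensors go out of scope.
        xb = x_train[i:i + batch_size].to(device)
        tb = t_train[i:i + batch_size].to(device)
        eb = e_train[i:i + batch_size].to(device)
        optimizer.zero_grad()
        loss = loss_fn(model, xb, tb, eb)
        loss.backward()
        optimizer.step()
    return model
```

The trade-off is extra host-to-device transfers per epoch in exchange for a GPU footprint independent of dataset size.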

@Jeanselme (Contributor, Author)

This could benefit from two levels of CUDA support: one fully on GPU (all train and test data) and one batched?

@chiragnagpal (Collaborator)

@Jeanselme I think this still requires the model to be moved to CPU before predict_survival can be used. Can you also change https://github.com/autonlab/DeepSurvivalMachines/blob/e88c88556bc603ac58ff83abdbe606e1c29c839b/dsm/losses.py#L307 and https://github.com/autonlab/DeepSurvivalMachines/blob/e88c88556bc603ac58ff83abdbe606e1c29c839b/dsm/losses.py#L341 to move to the same device as x? That way one can run the predict_risk or predict_survival function without any hiccups.
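The device-alignment fix being requested can be sketched as a standalone helper. This is an illustrative sketch, not the repository's API: `predict_on_device` is a hypothetical name, and the pattern is simply "move inputs to the model's device, run inference without grad tracking, return results on CPU".

```python
import torch

def predict_on_device(model, x):
    # Find the device the model's parameters live on; inputs must match it.
    device = next(model.parameters()).device
    x = x.to(device)
    with torch.no_grad():
        out = model(x)
    # Return on CPU so callers can hand the result straight to numpy,
    # regardless of whether the model ran on GPU or CPU.
    return out.cpu().numpy()
```

With this pattern, predict_risk/predict_survival-style functions work without first moving the whole model back to CPU.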

         loss += float(losses.conditional_loss(self.torch_model,
                                               x_val, t_val, e_val, elbo=False,
-                                              risk=str(r+1)).detach().numpy())
+                                              risk=str(r+1)).detach().cpu().numpy())
chiragnagpal (Collaborator)

I don't think this needs to be detached and moved to CPU; torch doesn't track this variable.

One has to be careful of the test set's size, as it will not easily fit on GPU (either batch it or put the model on CPU)
More elegant way of obtaining the value in a unit-length tensor
@codecov-io commented Feb 9, 2021

Codecov Report

Merging #35 (1c9850a) into master (e88c885) will decrease coverage by 0.42%.
The diff coverage is 41.37%.

@@            Coverage Diff             @@
##           master      #35      +/-   ##
==========================================
- Coverage   53.18%   52.76%   -0.43%     
==========================================
  Files           7        7              
  Lines         831      851      +20     
==========================================
+ Hits          442      449       +7     
- Misses        389      402      +13     
Impacted Files Coverage Δ
dsm/losses.py 33.95% <33.33%> (ø)
dsm/dsm_api.py 53.12% <35.00%> (-2.12%) ⬇️
dsm/utilities.py 82.45% <66.66%> (-1.33%) ⬇️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e88c885...1c9850a. Read the comment docs.

@@ -179,8 +186,7 @@ def train_dsm(model,
                                          elbo=False,
                                          risk=str(r+1))

-            valid_loss = valid_loss.detach().cpu().numpy()
-            costs.append(float(valid_loss))
+            costs.append(valid_loss.item())
chiragnagpal (Collaborator)

Let's use float()

Jeanselme (Contributor, Author)

.item() automatically moves the value to CPU if necessary and casts it to a Python float.
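The claim about .item() can be checked with a tiny example (a sketch, not code from the PR): for a one-element tensor, .item() returns a Python float directly, even when the tensor tracks gradients, and is equivalent to the longer detach/cpu/numpy route the diff replaces.

```python
import torch

# A scalar loss that requires grad, as in the training loop.
loss = (torch.tensor([1.5, 2.5], requires_grad=True) ** 2).sum()

# .item() extracts the Python scalar directly, moving to CPU if needed;
# no explicit detach() is required even though the tensor tracks gradients.
a = loss.item()

# The longer, equivalent route the diff replaces:
b = float(loss.detach().cpu().numpy())

assert a == b == 8.5  # 1.5**2 + 2.5**2
```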

chiragnagpal (Collaborator) left a comment

I am not sure why you use .item()?

@@ -74,9 +74,9 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
     valid_loss = 0
     for r in range(model.risks):
         valid_loss += unconditional_loss(premodel, t_valid, e_valid, str(r+1))
-    valid_loss = valid_loss.detach().cpu().numpy()
+    valid_loss = valid_loss.item()
chiragnagpal (Collaborator)

Let's use float
