Skip to content

Commit

Permalink
Better unloading of data
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Feb 6, 2024
1 parent 83b9cda commit bdbd152
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
10 changes: 9 additions & 1 deletion inst/pydevil/pydevil/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def run_SVDE(
#lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) ,
# total_count= torch.clamp(overdispersion, 1e-9,1e9)).log_prob(input_matrix).sum(dim = 0)

ret['input_matrix'] = unload_tensor(ret['input_matrix'])
ret['model_matrix'] = unload_tensor(ret['model_matrix'])
ret['group_matrix'] = unload_tensor(ret['group_matrix'])
ret['sf'] = unload_tensor(ret['sf'])
ret['offset_matrix'] = unload_tensor(ret['offset_matrix'])
ret['beta_estimate_matrix'] = unload_tensor(ret['beta_estimate_matrix'])
ret['dispersion_priors'] = unload_tensor(ret['dispersion_priors'])
ret['cluster'] = unload_tensor(ret['cluster'])
input_matrix = unload_tensor(input_matrix)
model_matrix = unload_tensor(model_matrix)
overdispersion = unload_tensor(overdispersion)
Expand Down Expand Up @@ -149,7 +157,7 @@ def run_SVDE(
"theta" : overdispersion,
#"lk" : lk,
"beta" : coeff,
"eta" : eta,
#"eta" : eta,
"variance" : loc,
"size_factors" : UMI
},
Expand Down
7 changes: 3 additions & 4 deletions inst/pydevil/pydevil/utils_hessian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from tqdm import trange
from pydevil.utils_input import unload_tensor

def compute_hessian(obs, model_matrix, coeff, overdispersion, size_factors):
beta = torch.tensor(coeff)
Expand Down Expand Up @@ -33,8 +34,7 @@ def compute_hessians(input_matrix, model_matrix, coeff, overdispersion, size_fac

t.set_description('Variance estimation: {:.2f} '.format(gene_idx / n_genes))
t.refresh()
if torch.cuda.is_available():
solved_hessian = solved_hessian.detach().cpu()
solved_hessian = unload_tensor(solved_hessian)

if full_cov:
loc[gene_idx, :, :] = solved_hessian
Expand Down Expand Up @@ -94,8 +94,7 @@ def compute_sandwiches(input_matrix, model_matrix, coeff, overdispersion, size_f
t.set_description('Clustered variance estimation: {:.2f} '.format(gene_idx / n_genes))
t.refresh()

if torch.cuda.is_available():
s = s.detach().cpu()
s = unload_tensor(s)

loc[gene_idx, :, :] = s
del s
Expand Down
10 changes: 6 additions & 4 deletions inst/pydevil/pydevil/utils_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def unload_tensor(obj):
"""
Unload the tensor from the GPU.
"""
if obj.get_device() == 0:
return obj.cpu().detach().numpy()
else:
return obj.detach().numpy()
if isinstance(obj, torch.Tensor):
if obj.get_device() == 0:
return obj.cpu().detach().numpy()
else:
return obj.detach().numpy()
return obj

def validate_boolean(parameter, parameter_name):
"""
Expand Down

0 comments on commit bdbd152

Please sign in to comment.