From bdbd152846f413c631aa89654e73afb127e58476 Mon Sep 17 00:00:00 2001 From: jovoni Date: Tue, 6 Feb 2024 16:33:04 +0100 Subject: [PATCH] Better unloading of data --- inst/pydevil/pydevil/interface.py | 10 +++++++++- inst/pydevil/pydevil/utils_hessian.py | 7 +++---- inst/pydevil/pydevil/utils_input.py | 10 ++++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/inst/pydevil/pydevil/interface.py b/inst/pydevil/pydevil/interface.py index 143c995..a53b696 100644 --- a/inst/pydevil/pydevil/interface.py +++ b/inst/pydevil/pydevil/interface.py @@ -115,6 +115,14 @@ def run_SVDE( #lk = dist.NegativeBinomial(logits = eta - torch.log(overdispersion) , # total_count= torch.clamp(overdispersion, 1e-9,1e9)).log_prob(input_matrix).sum(dim = 0) + ret['input_matrix'] = unload_tensor(ret['input_matrix']) + ret['model_matrix'] = unload_tensor(ret['model_matrix']) + ret['group_matrix'] = unload_tensor(ret['group_matrix']) + ret['sf'] = unload_tensor(ret['sf']) + ret['offset_matrix'] = unload_tensor(ret['offset_matrix']) + ret['beta_estimate_matrix'] = unload_tensor(ret['beta_estimate_matrix']) + ret['dispersion_priors'] = unload_tensor(ret['dispersion_priors']) + ret['cluster'] = unload_tensor(ret['cluster']) input_matrix = unload_tensor(input_matrix) model_matrix = unload_tensor(model_matrix) overdispersion = unload_tensor(overdispersion) @@ -149,7 +157,7 @@ def run_SVDE( "theta" : overdispersion, #"lk" : lk, "beta" : coeff, - "eta" : eta, + #"eta" : eta, "variance" : loc, "size_factors" : UMI }, diff --git a/inst/pydevil/pydevil/utils_hessian.py b/inst/pydevil/pydevil/utils_hessian.py index 28c01e1..8f1b057 100644 --- a/inst/pydevil/pydevil/utils_hessian.py +++ b/inst/pydevil/pydevil/utils_hessian.py @@ -1,5 +1,6 @@ import torch from tqdm import trange +from pydevil.utils_input import unload_tensor def compute_hessian(obs, model_matrix, coeff, overdispersion, size_factors): beta = torch.tensor(coeff) @@ -33,8 +34,7 @@ def compute_hessians(input_matrix, model_matrix, coeff, overdispersion, size_fac t.set_description('Variance estimation: {:.2f} '.format(gene_idx / n_genes)) t.refresh() - if torch.cuda.is_available(): - solved_hessian = solved_hessian.detach().cpu() + solved_hessian = unload_tensor(solved_hessian) if full_cov: loc[gene_idx, :, :] = solved_hessian @@ -94,8 +94,7 @@ def compute_sandwiches(input_matrix, model_matrix, coeff, overdispersion, size_f t.set_description('Clustered variance estimation: {:.2f} '.format(gene_idx / n_genes)) t.refresh() - if torch.cuda.is_available(): - s = s.detach().cpu() + s = unload_tensor(s) loc[gene_idx, :, :] = s del s diff --git a/inst/pydevil/pydevil/utils_input.py b/inst/pydevil/pydevil/utils_input.py index bbb64d5..c1d7acb 100644 --- a/inst/pydevil/pydevil/utils_input.py +++ b/inst/pydevil/pydevil/utils_input.py @@ -24,10 +24,12 @@ def unload_tensor(obj): """ Unload the tensor from the GPU. """ - if obj.get_device() == 0: - return obj.cpu().detach().numpy() - else: - return obj.detach().numpy() + if isinstance(obj, torch.Tensor): + if obj.get_device() == 0: + return obj.cpu().detach().numpy() + else: + return obj.detach().numpy() + return obj def validate_boolean(parameter, parameter_name): """