diff --git a/inst/pydevil/pydevil/interface.py b/inst/pydevil/pydevil/interface.py index c95fa01..f49868d 100644 --- a/inst/pydevil/pydevil/interface.py +++ b/inst/pydevil/pydevil/interface.py @@ -1,4 +1,5 @@ import numpy as np +import gc import torch import pyro @@ -207,8 +208,16 @@ def run_SVDE( ret["params"]["lengthscale_kernel"] = pyro.param("lengthscale_param").cpu().detach().numpy() else: ret["params"]["lengthscale_kernel"] = pyro.param("lengthscale_param").detach().numpy() - - + + if cuda and torch.cuda.is_available(): + del elbo_list, beta_list, overdisp_list + del overdispersion, lk, coeff, eta, loc, variance + del input_matrix, model_matrix, group_matrix, beta_estimate_matrix, UMI, gene_specific_model_tensor, kernel_input + del input_matrix_batch, model_matrix_batch, group_matrix_batch, UMI_batch, gene_specific_model_tensor_batch, kernel_input_batch + del loss + torch.cuda.empty_cache() + gc.collect() + return ret