From e999218ad36823f507a143b21e71a5d7ffab88f5 Mon Sep 17 00:00:00 2001 From: jovoni Date: Tue, 24 Oct 2023 11:48:58 +0200 Subject: [PATCH] Added cleaning of cuda memory --- inst/pydevil/pydevil/interface.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/inst/pydevil/pydevil/interface.py b/inst/pydevil/pydevil/interface.py index c95fa01..f49868d 100644 --- a/inst/pydevil/pydevil/interface.py +++ b/inst/pydevil/pydevil/interface.py @@ -1,4 +1,5 @@ import numpy as np +import gc import torch import pyro @@ -207,8 +208,16 @@ def run_SVDE( ret["params"]["lengthscale_kernel"] = pyro.param("lengthscale_param").cpu().detach().numpy() else: ret["params"]["lengthscale_kernel"] = pyro.param("lengthscale_param").detach().numpy() - - + + if cuda and torch.cuda.is_available(): + del elbo_list, beta_list, overdisp_list + del overdispersion, lk, coeff, eta, loc, variance + del input_matrix, model_matrix, group_matrix, beta_estimate_matrix, UMI, gene_specific_model_tensor, kernel_input + del input_matrix_batch, model_matrix_batch, group_matrix_batch, UMI_batch, gene_specific_model_tensor_batch, kernel_input_batch + del loss + torch.cuda.empty_cache() + gc.collect() + return ret