diff --git a/random_survival_forest/RandomSurvivalForest.py b/random_survival_forest/RandomSurvivalForest.py index 87c2a75..152339a 100644 --- a/random_survival_forest/RandomSurvivalForest.py +++ b/random_survival_forest/RandomSurvivalForest.py @@ -32,7 +32,9 @@ def __init__(self, timeline, n_estimators=100, min_leaf=3, unique_deaths=3, n_jo self.trees = [] self.random_states = [] + def fit(self, x, y): + """ Build a forest of trees from the training set (X, y). :param x: The input samples. Should be a Dataframe with the shape [n_samples, n_features]. @@ -40,7 +42,6 @@ def fit(self, x, y): in the second with the shape [n_samples, 2] :return: self: object """ - if self.n_jobs == -1: self.n_jobs = multiprocessing.cpu_count() elif self.n_jobs is None: @@ -48,10 +49,11 @@ def fit(self, x, y): self.random_states = np.random.RandomState(seed=self.random_state).randint(0, 2**32-1, self.n_estimators) self.bootstrap_idxs = self.draw_bootstrap_samples(x) - trees = Parallel(n_jobs=self.n_jobs)(delayed(self.create_tree)(x, y, i) for i in range(self.n_estimators)) + trees = Parallel(n_jobs=self.n_jobs, backend="multiprocessing")(delayed(self.create_tree)(x, y, i) + for i in range(self.n_estimators)) for i in range(len(trees)): - if trees[i].prediction_possible is True: + if trees[i].prediction_possible: self.trees.append(trees[i]) self.bootstraps.append(self.bootstrap_idxs[i]) @@ -79,27 +81,14 @@ def create_tree(self, x, y, i): return tree - def compute_oob_ensembles(self, x): + def compute_oob_ensembles(self, xs): """ Compute OOB ensembles. :return: List of oob ensemble for each sample. """ - oob_ensemble_chfs = [] - for sample_idx in range(x.shape[0]): - denominator = 0 - numerator = 0 - for b in range(len(self.trees)): - if sample_idx not in self.bootstraps[b]: - sample = x.iloc[sample_idx].to_list() - chf = self.trees[b].predict(sample) - denominator = denominator + 1 - numerator = numerator + 1 * chf - - if denominator == 0: - continue - else: - ensemble_chf = numerator/denominator - oob_ensemble_chfs.append(ensemble_chf) + results = [compute_oob_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees, + bootstraps=self.bootstraps) for sample_idx in range(xs.shape[0])] + oob_ensemble_chfs = [i for i in results if not i.empty] return oob_ensemble_chfs def compute_oob_score(self, x, y): @@ -117,18 +106,8 @@ def predict(self, xs): :param xs: The input samples :return: List of the predicted cumulative hazard functions. """ - ensemble_chfs = [] - for sample_idx in range(xs.shape[0]): - denominator = 0 - numerator = 0 - for b in range(len(self.trees)): - sample = xs.iloc[sample_idx].to_list() - chf = self.trees[b].predict(sample) - denominator = denominator + 1 - numerator = numerator + 1 * chf - - ensemble_chf = numerator / denominator - ensemble_chfs.append(ensemble_chf) + ensemble_chfs = [compute_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees) + for sample_idx in range(xs.shape[0])] return ensemble_chfs def draw_bootstrap_samples(self, data): @@ -149,3 +128,31 @@ def draw_bootstrap_samples(self, data): bootstrap_idxs.append(bootstrap_idx) return bootstrap_idxs + + +def compute_ensemble_chf(sample_idx, xs, trees): + denominator = 0 + numerator = 0 + for b in range(len(trees)): + sample = xs.iloc[sample_idx].to_list() + chf = trees[b].predict(sample) + denominator = denominator + 1 + numerator = numerator + 1 * chf + ensemble_chf = numerator / denominator + return ensemble_chf + + +def compute_oob_ensemble_chf(sample_idx, xs, trees, bootstraps): + denominator = 0 + numerator = 0 + for b in range(len(trees)): + if sample_idx not in bootstraps[b]: + sample = xs.iloc[sample_idx].to_list() + chf = trees[b].predict(sample) + denominator = denominator + 1 + numerator = numerator + 1 * chf + if denominator != 0: + oob_ensemble_chf = numerator / denominator + else: + oob_ensemble_chf = pd.Series() + return oob_ensemble_chf diff --git a/random_survival_forest/scoring.py b/random_survival_forest/scoring.py index 0eda5fc..3f9cd46 100644 --- a/random_survival_forest/scoring.py +++ b/random_survival_forest/scoring.py @@ -14,10 +14,10 @@ def concordance_index(y_time, y_pred, y_event): concordance = 0 permissible = 0 for pair in possible_pairs: - t1 = y_time.iloc[pair[0]] - t2 = y_time.iloc[pair[1]] - e1 = y_event.iloc[pair[0]] - e2 = y_event.iloc[pair[1]] + t1 = y_time.iat[pair[0]] + t2 = y_time.iat[pair[1]] + e1 = y_event.iat[pair[0]] + e2 = y_event.iat[pair[1]] predicted_outcome_1 = oob_predicted_outcome[pair[0]] predicted_outcome_2 = oob_predicted_outcome[pair[1]] diff --git a/random_survival_forest/splitting.py b/random_survival_forest/splitting.py index 6d4d3ea..bc9ed59 100644 --- a/random_survival_forest/splitting.py +++ b/random_survival_forest/splitting.py @@ -7,20 +7,10 @@ def find_split(node): :param node: Node to find best split for. :return: score of best split, value of best split, variable to split, left indices, right indices. """ - score_opt = 0 - split_val_opt = None - lhs_idxs_opt = None - rhs_idxs_opt = None - split_var_opt = None - for i in node.f_idxs: - score, split_val, lhs_idxs, rhs_idxs = find_best_split_for_variable(node, i) - if score > score_opt: - score_opt = score - split_val_opt = split_val - lhs_idxs_opt = lhs_idxs - rhs_idxs_opt = rhs_idxs - split_var_opt = i - + results = [find_best_split_for_variable(node, i) for i in node.f_idxs] + scores = [item[0] for item in results] + max_idx = scores.index(max(scores)) + score_opt, split_val_opt, lhs_idxs_opt, rhs_idxs_opt, split_var_opt = results[max_idx] return score_opt, split_val_opt, split_var_opt, lhs_idxs_opt, rhs_idxs_opt @@ -30,12 +20,12 @@ def find_best_split_for_variable(node, var_idx): statistics. The logrank_test function of the lifelines package is used here. :param node: Node :param var_idx: Index of variable - :return: score, split value, left indices, right indices. + :return: score, split value, left indices, right indices, feature index. """ score, split_val, lhs_idxs, rhs_idxs = logrank_statistics(x=node.x, y=node.y, feature=var_idx, min_leaf=node.min_leaf) - return score, split_val, lhs_idxs, rhs_idxs + return score, split_val, lhs_idxs, rhs_idxs, var_idx def logrank_statistics(x, y, feature, min_leaf): @@ -48,16 +38,21 @@ def logrank_statistics(x, y, feature, min_leaf): :return: best score, best split value, left indices, right indices """ x_feature = x.reset_index(drop=True).iloc[:, feature] - score_opt = 0 - split_val_opt = None - lhs_idxs = None - rhs_idxs = None + sorted_values = x_feature.sort_values(ascending=True, kind="quicksort").unique() + results = [compute_score(x_feature, y, split_val, min_leaf) for split_val in sorted_values] + scores = [item[0] for item in results] + max_idx = scores.index(max(scores)) + score_opt, split_val_opt, lhs_idxs, rhs_idxs = results[max_idx] + + return score_opt, split_val_opt, lhs_idxs, rhs_idxs + - for split_val in x_feature.sort_values(ascending=True, kind="quicksort").unique(): - feature1 = list(x_feature[x_feature <= split_val].index) - feature2 = list(x_feature[x_feature > split_val].index) - if len(feature1) < min_leaf or len(feature2) < min_leaf: - continue +def compute_score(x_feature, y, split_val, min_leaf): + feature1 = list(x_feature[x_feature <= split_val].index) + feature2 = list(x_feature[x_feature > split_val].index) + if len(feature1) < min_leaf or len(feature2) < min_leaf: + score = 0 + else: durations_a = y.iloc[feature1, 0] event_observed_a = y.iloc[feature1, 1] durations_b = y.iloc[feature2, 0] @@ -65,11 +60,4 @@ def logrank_statistics(x, y, feature, min_leaf): results = logrank_test(durations_A=durations_a, durations_B=durations_b, event_observed_A=event_observed_a, event_observed_B=event_observed_b) score = results.test_statistic - - if score > score_opt: - score_opt = round(score, 3) - split_val_opt = round(split_val, 3) - lhs_idxs = feature1 - rhs_idxs = feature2 - - return score_opt, split_val_opt, lhs_idxs, rhs_idxs + return [score, split_val, feature1, feature2] diff --git a/setup.py b/setup.py index d8cb7f3..f366a2c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name='random_survival_forest', # How you named your package folder (MyLib) packages=['random_survival_forest'], # Chose the same as "name" - version='0.7', # Start with a small number and increase it with every change you make + version='0.7.1', # Start with a small number and increase it with every change you make license="MIT License", # Chose a license from here: https://help.github.com/articles/licensing-a-repository long_description=readme, long_description_content_type="text/markdown",