Skip to content

Commit

Permalink
Enhance performance of splitting and fix parallelization issue
Browse files Browse the repository at this point in the history
  • Loading branch information
julianspaeth committed Oct 27, 2019
1 parent ae8cfc1 commit 095d603
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 70 deletions.
71 changes: 39 additions & 32 deletions random_survival_forest/RandomSurvivalForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,28 @@ def __init__(self, timeline, n_estimators=100, min_leaf=3, unique_deaths=3, n_jo
self.trees = []
self.random_states = []


def fit(self, x, y):

"""
Build a forest of trees from the training set (X, y).
:param x: The input samples. Should be a Dataframe with the shape [n_samples, n_features].
:param y: The target values as a Dataframe with the survival time in the first column and the event
in the second with the shape [n_samples, 2]
:return: self: object
"""

if self.n_jobs == -1:
self.n_jobs = multiprocessing.cpu_count()
elif self.n_jobs is None:
self.n_jobs = 1
self.random_states = np.random.RandomState(seed=self.random_state).randint(0, 2**32-1, self.n_estimators)
self.bootstrap_idxs = self.draw_bootstrap_samples(x)

trees = Parallel(n_jobs=self.n_jobs)(delayed(self.create_tree)(x, y, i) for i in range(self.n_estimators))
trees = Parallel(n_jobs=self.n_jobs, backend="multiprocessing")(delayed(self.create_tree)(x, y, i)
for i in range(self.n_estimators))

for i in range(len(trees)):
if trees[i].prediction_possible is True:
if trees[i].prediction_possible:
self.trees.append(trees[i])
self.bootstraps.append(self.bootstrap_idxs[i])

Expand Down Expand Up @@ -79,27 +81,14 @@ def create_tree(self, x, y, i):

return tree

def compute_oob_ensembles(self, x):
def compute_oob_ensembles(self, xs):
"""
Compute OOB ensembles.
:return: List of oob ensemble for each sample.
"""
oob_ensemble_chfs = []
for sample_idx in range(x.shape[0]):
denominator = 0
numerator = 0
for b in range(len(self.trees)):
if sample_idx not in self.bootstraps[b]:
sample = x.iloc[sample_idx].to_list()
chf = self.trees[b].predict(sample)
denominator = denominator + 1
numerator = numerator + 1 * chf

if denominator == 0:
continue
else:
ensemble_chf = numerator/denominator
oob_ensemble_chfs.append(ensemble_chf)
results = [compute_oob_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees,
bootstraps=self.bootstraps) for sample_idx in range(xs.shape[0])]
oob_ensemble_chfs = [i for i in results if not i.empty]
return oob_ensemble_chfs

def compute_oob_score(self, x, y):
Expand All @@ -117,18 +106,8 @@ def predict(self, xs):
:param xs: The input samples
:return: List of the predicted cumulative hazard functions.
"""
ensemble_chfs = []
for sample_idx in range(xs.shape[0]):
denominator = 0
numerator = 0
for b in range(len(self.trees)):
sample = xs.iloc[sample_idx].to_list()
chf = self.trees[b].predict(sample)
denominator = denominator + 1
numerator = numerator + 1 * chf

ensemble_chf = numerator / denominator
ensemble_chfs.append(ensemble_chf)
ensemble_chfs = [compute_ensemble_chf(sample_idx=sample_idx, xs=xs, trees=self.trees)
for sample_idx in range(xs.shape[0])]
return ensemble_chfs

def draw_bootstrap_samples(self, data):
Expand All @@ -149,3 +128,31 @@ def draw_bootstrap_samples(self, data):
bootstrap_idxs.append(bootstrap_idx)

return bootstrap_idxs


def compute_ensemble_chf(sample_idx, xs, trees):
denominator = 0
numerator = 0
for b in range(len(trees)):
sample = xs.iloc[sample_idx].to_list()
chf = trees[b].predict(sample)
denominator = denominator + 1
numerator = numerator + 1 * chf
ensemble_chf = numerator / denominator
return ensemble_chf


def compute_oob_ensemble_chf(sample_idx, xs, trees, bootstraps):
denominator = 0
numerator = 0
for b in range(len(trees)):
if sample_idx not in bootstraps[b]:
sample = xs.iloc[sample_idx].to_list()
chf = trees[b].predict(sample)
denominator = denominator + 1
numerator = numerator + 1 * chf
if denominator != 0:
oob_ensemble_chf = numerator / denominator
else:
oob_ensemble_chf = pd.Series()
return oob_ensemble_chf
8 changes: 4 additions & 4 deletions random_survival_forest/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def concordance_index(y_time, y_pred, y_event):
concordance = 0
permissible = 0
for pair in possible_pairs:
t1 = y_time.iloc[pair[0]]
t2 = y_time.iloc[pair[1]]
e1 = y_event.iloc[pair[0]]
e2 = y_event.iloc[pair[1]]
t1 = y_time.iat[pair[0]]
t2 = y_time.iat[pair[1]]
e1 = y_event.iat[pair[0]]
e2 = y_event.iat[pair[1]]
predicted_outcome_1 = oob_predicted_outcome[pair[0]]
predicted_outcome_2 = oob_predicted_outcome[pair[1]]

Expand Down
54 changes: 21 additions & 33 deletions random_survival_forest/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,10 @@ def find_split(node):
:param node: Node to find best split for.
:return: score of best split, value of best split, variable to split, left indices, right indices.
"""
score_opt = 0
split_val_opt = None
lhs_idxs_opt = None
rhs_idxs_opt = None
split_var_opt = None
for i in node.f_idxs:
score, split_val, lhs_idxs, rhs_idxs = find_best_split_for_variable(node, i)
if score > score_opt:
score_opt = score
split_val_opt = split_val
lhs_idxs_opt = lhs_idxs
rhs_idxs_opt = rhs_idxs
split_var_opt = i

results = [find_best_split_for_variable(node, i) for i in node.f_idxs]
scores = [item[0] for item in results]
max_idx = scores.index(max(scores))
score_opt, split_val_opt, lhs_idxs_opt, rhs_idxs_opt, split_var_opt = results[max_idx]
return score_opt, split_val_opt, split_var_opt, lhs_idxs_opt, rhs_idxs_opt


Expand All @@ -30,12 +20,12 @@ def find_best_split_for_variable(node, var_idx):
statistics. The logrank_test function of the lifelines package is used here.
:param node: Node
:param var_idx: Index of variable
:return: score, split value, left indices, right indices.
:return: score, split value, left indices, right indices, feature index.
"""
score, split_val, lhs_idxs, rhs_idxs = logrank_statistics(x=node.x, y=node.y,
feature=var_idx,
min_leaf=node.min_leaf)
return score, split_val, lhs_idxs, rhs_idxs
return score, split_val, lhs_idxs, rhs_idxs, var_idx


def logrank_statistics(x, y, feature, min_leaf):
Expand All @@ -48,28 +38,26 @@ def logrank_statistics(x, y, feature, min_leaf):
:return: best score, best split value, left indices, right indices
"""
x_feature = x.reset_index(drop=True).iloc[:, feature]
score_opt = 0
split_val_opt = None
lhs_idxs = None
rhs_idxs = None
sorted_values = x_feature.sort_values(ascending=True, kind="quicksort").unique()
results = [compute_score(x_feature, y, split_val, min_leaf) for split_val in sorted_values]
scores = [item[0] for item in results]
max_idx = scores.index(max(scores))
score_opt, split_val_opt, lhs_idxs, rhs_idxs = results[max_idx]

return score_opt, split_val_opt, lhs_idxs, rhs_idxs


for split_val in x_feature.sort_values(ascending=True, kind="quicksort").unique():
feature1 = list(x_feature[x_feature <= split_val].index)
feature2 = list(x_feature[x_feature > split_val].index)
if len(feature1) < min_leaf or len(feature2) < min_leaf:
continue
def compute_score(x_feature, y, split_val, min_leaf):
feature1 = list(x_feature[x_feature <= split_val].index)
feature2 = list(x_feature[x_feature > split_val].index)
if len(feature1) < min_leaf or len(feature2) < min_leaf:
score = 0
else:
durations_a = y.iloc[feature1, 0]
event_observed_a = y.iloc[feature1, 1]
durations_b = y.iloc[feature2, 0]
event_observed_b = y.iloc[feature2, 1]
results = logrank_test(durations_A=durations_a, durations_B=durations_b,
event_observed_A=event_observed_a, event_observed_B=event_observed_b)
score = results.test_statistic

if score > score_opt:
score_opt = round(score, 3)
split_val_opt = round(split_val, 3)
lhs_idxs = feature1
rhs_idxs = feature2

return score_opt, split_val_opt, lhs_idxs, rhs_idxs
return [score, split_val, feature1, feature2]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name='random_survival_forest', # How you named your package folder (MyLib)
packages=['random_survival_forest'], # Chose the same as "name"
version='0.7', # Start with a small number and increase it with every change you make
version='0.7.1', # Start with a small number and increase it with every change you make
license="MIT License", # Chose a license from here: https://help.github.com/articles/licensing-a-repository
long_description=readme,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 095d603

Please sign in to comment.