Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Lgbm linear reader #151

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/omlt/linear_tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
LinearTreeGDPFormulation,
LinearTreeHybridBigMFormulation,
)
from omlt.linear_tree.gblt_model import EnsembleDefinition
356 changes: 356 additions & 0 deletions src/omlt/linear_tree/gblt_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,356 @@
import lightgbm as lgb
import numpy as np


class EnsembleDefinition:
def __init__(
self,
gblt_model,
scaling_object=None,
unscaled_input_bounds=None,
scaled_input_bounds=None
):
"""
Create a network definition object used to create the gradient-boosted trees
formulation in Pyomo

Args:
onnx_model : ONNX Model
An ONNX model that is generated by the ONNX convert function for
lightgbm.
scaling_object : ScalingInterface or None
A scaling object to specify the scaling parameters for the
tree ensemble inputs and outputs. If None, then no
scaling is performed.
scaled_input_bounds : dict or None
A dict that contains the bounds on the scaled variables (the
direct inputs to the tree ensemble). If None, then no bounds
are specified or they are generated using unscaled bounds.
"""
self.__model = gblt_model
self.__scaling_object = scaling_object

# Process input bounds to insure scaled input bounds exist for formulations
if scaled_input_bounds is None:
if unscaled_input_bounds is not None and scaling_object is not None:
lbs = scaling_object.get_scaled_input_expressions(
{k: t[0] for k, t in unscaled_input_bounds.items()}
)
ubs = scaling_object.get_scaled_input_expressions(
{k: t[1] for k, t in unscaled_input_bounds.items()}
)

scaled_input_bounds = {
k: (lbs[k], ubs[k]) for k in unscaled_input_bounds.keys()
}

# If unscaled input bounds provided and no scaler provided, scaled
# input bounds = unscaled input bounds
elif unscaled_input_bounds is not None and scaling_object is None:
scaled_input_bounds = unscaled_input_bounds
elif unscaled_input_bounds is None:
raise ValueError(
"Input Bounds needed to represent linear trees as MIPs"
)

self.__unscaled_input_bounds = unscaled_input_bounds
self.__scaled_input_bounds = scaled_input_bounds

n_inputs = _find_n_inputs(gblt_model)
self.__n_inputs = n_inputs
self.__n_outputs = 1
self.__splits, self.__leaves, self.__thresholds =\
_parse_model(gblt_model, scaled_input_bounds, n_inputs)

@property
def scaling_object(self):
"""Returns scaling object"""
return self.__scaling_object

@property
def scaled_input_bounds(self):
"""Returns dict containing scaled input bounds"""
return self.__scaled_input_bounds

@property
def splits(self):
"""Returns dict containing split information"""
return self.__splits

@property
def leaves(self):
"""Returns dict containing leaf information"""
return self.__leaves

@property
def thresholds(self):
"""Returns dict containing threshold information"""
return self.__thresholds

@property
def n_inputs(self):
"""Returns number of inputs to the linear tree"""
return self.__n_inputs

@property
def n_outputs(self):
"""Returns number of outputs to the linear tree"""
return self.__n_outputs


def _find_all_children_splits(split, splits_dict):
"""
This helper function finds all multigeneration children splits for an
argument split.

Arguments:
split --The split for which you are trying to find children splits
splits_dict -- A dictionary of all the splits in the tree

Returns:
A list containing the Node IDs of all children splits
"""
all_splits = []

# Check if the immediate left child of the argument split is also a split.
# If so append to the list then use recursion to generate the remainder
left_child = splits_dict[split]['children'][0]
if left_child in splits_dict:
all_splits.append(left_child)
all_splits.extend(_find_all_children_splits(left_child, splits_dict))

# Same as above but with right child
right_child = splits_dict[split]['children'][1]
if right_child in splits_dict:
all_splits.append(right_child)
all_splits.extend(_find_all_children_splits(right_child, splits_dict))

return all_splits

def _reassign_none_bounds(leaves, input_bounds, n_inputs):
"""
This helper function reassigns bounds that are None to the bounds
input by the user

Arguments:
leaves -- The dictionary of leaf information. Attribute of the
LinearTreeDefinition object
input_bounds -- The nested dictionary

Returns:
The modified leaves dict without any bounds that are listed as None
"""
leaf_indices = np.array(list(leaves.keys()))
features = np.arange(0, n_inputs)

for leaf in leaf_indices:
slopes = leaves[leaf]["slope"]
if len(slopes) == 0:
leaves[leaf]["slope"] = list(np.zeros(n_inputs))
for feat in features:
if leaves[leaf]["bounds"][feat][0] is None:
leaves[leaf]["bounds"][feat][0] = input_bounds[feat][0]
if leaves[leaf]["bounds"][feat][1] is None:
leaves[leaf]["bounds"][feat][1] = input_bounds[feat][1]

return leaves

def _find_all_children_leaves(split, splits_dict, leaves_dict):
"""
This helper function finds all multigeneration children leaves for an
argument split.

Arguments:
split -- The split for which you are trying to find children leaves
splits_dict -- A dictionary of all the split info in the tree
leaves_dict -- A dictionary of all the leaf info in the tree

Returns:
A list containing all the Node IDs of all children leaves
"""
all_leaves = []

# Find all the splits that are children of the relevant split
all_splits = _find_all_children_splits(split, splits_dict)

# Ensure the current split is included
if split not in all_splits:
all_splits.append(split)

# For each leaf, check if the parents appear in the list of children
# splits (all_splits). If so, it must be a leaf of the argument split
for leaf in leaves_dict:
if leaves_dict[leaf]['parent'] in all_splits:
all_leaves.append(leaf)

return all_leaves


def _find_n_inputs(model):
if str(type(model)) == "<class 'lightgbm.basic.Booster'>":
n_inputs = model.num_feature()
else:
n_inputs = len(model["feature_names"])

return n_inputs

def _parse_model(model, input_bounds, n_inputs):
if str(type(model)) == "<class 'lightgbm.basic.Booster'>":
whole_model = model.dump_model()
else:
whole_model=model

tree = {}
for i in range(whole_model['tree_info'][-1]['tree_index']+1):

node = whole_model['tree_info'][i]["tree_structure"]

queue = [node]
splits = {}

# the very first node
splits["split"+str(queue[0]["split_index"])] = {'th': queue[0]["threshold"],
'col': queue[0]["split_feature"] }

# flow though the tree
while queue:

# left child
if "left_child" in queue[0].keys():
queue.append(queue[0]["left_child"])
# child is a split
if "split_index" in queue[0]["left_child"].keys():
splits["split"+str(queue[0]["left_child"]["split_index"])] = {'parent': "split"+str(queue[0]["split_index"]),
'direction': 'left',
'th': queue[0]["left_child"]["threshold"],
'col': queue[0]["left_child"]["split_feature"]}
# child is a leaf
else:
splits["leaf"+str(queue[0]["left_child"]["leaf_index"])] = {'parent': "split"+str(queue[0]["split_index"]),
'direction': 'left',
'intercept': queue[0]["left_child"]["leaf_const"],
'slope': list(np.zeros(n_inputs))}

for idx, val in zip(queue[0]["left_child"]["leaf_features"], queue[0]["left_child"]["leaf_coeff"]):
splits["leaf"+str(queue[0]["left_child"]["leaf_index"])]["slope"][idx] = val

# right child
if "right_child" in queue[0].keys():
queue.append(queue[0]["right_child"])
# child is a split
if "split_index" in queue[0]["right_child"].keys():
splits["split"+str(queue[0]["right_child"]["split_index"])] = {'parent': "split"+str(queue[0]["split_index"]),
'direction': 'right',
'th': queue[0]["right_child"]["threshold"],
'col': queue[0]["right_child"]["split_feature"]}
# child is a leaf
else:
splits["leaf"+str(queue[0]["right_child"]["leaf_index"])] = {'parent': "split"+str(queue[0]["split_index"]),
'direction': 'right',
'intercept': queue[0]["right_child"]["leaf_const"],
'slope': list(np.zeros(n_inputs))}

for idx, val in zip(queue[0]["right_child"]["leaf_features"], queue[0]["right_child"]["leaf_coeff"]):
splits["leaf"+str(queue[0]["right_child"]["leaf_index"])]["slope"][idx] = val
# delet the first node
queue.pop(0)

tree['tree'+str(i)] = splits

nested_splits = {}
nested_leaves = {}
nested_thresholds = {}

for index in tree:

splits = tree[index]
for i in splits:
# print(i)
if 'parent' in splits[i].keys():
splits[splits[i]['parent']]['children'] = []

for i in splits:
# print(i)
if 'parent' in splits[i].keys():
if splits[i]['direction'] == 'left':
splits[splits[i]['parent']]['children'].insert(0,i)
if splits[i]['direction'] == 'right':
splits[splits[i]['parent']]['children'].insert(11,i)

leaves = {}
for i in splits.keys():
if i[0] == 'l':
leaves[i] = splits[i]

for leaf in leaves:
del splits[leaf]

for split in splits:
# print("split:" + str(split))
left_child = splits[split]['children'][0]
right_child = splits[split]['children'][1]

if left_child in splits:
# means left_child is split
splits[split]['left_leaves'] = _find_all_children_leaves(
left_child, splits, leaves
)
else:
# means left_child is leaf
splits[split]['left_leaves'] = [left_child]
# print("left_child" + str(left_child))

if right_child in splits:
splits[split]['right_leaves'] = _find_all_children_leaves(
right_child, splits, leaves
)
else:
splits[split]['right_leaves'] = [right_child]
# print("right_child" + str(right_child))

splitting_thresholds = {}
for split in splits:
var = splits[split]['col']
splitting_thresholds[var] = {}
for split in splits:
var = splits[split]['col']
splitting_thresholds[var][split] = splits[split]['th']

for var in splitting_thresholds:
splitting_thresholds[var] = dict(sorted(splitting_thresholds[var].items(), key=lambda x: x[1]))

for split in splits:
var = splits[split]['col']
splits[split]['y_index'] = []
splits[split]['y_index'].append(var)
splits[split]['y_index'].append(
list(splitting_thresholds[var]).index(split)
)

features = np.arange(0,n_inputs)

for leaf in leaves:
leaves[leaf]['bounds'] = {}
for th in features:
for leaf in leaves:
leaves[leaf]['bounds'][th] = [None, None]

# import pprint
# pp = pprint.PrettyPrinter(indent=4)
# pp.pprint(splits)
# pp.pprint(leaves)
for split in splits:
var = splits[split]['col']
for leaf in splits[split]['left_leaves']:
leaves[leaf]['bounds'][var][1] = splits[split]['th']

for leaf in splits[split]['right_leaves']:
leaves[leaf]['bounds'][var][0] = splits[split]['th']

leaves = _reassign_none_bounds(leaves, input_bounds, n_inputs)

nested_splits[str(index)] = splits
nested_leaves[str(index)] = leaves
nested_thresholds[str(index)] = splitting_thresholds

return nested_splits, nested_leaves, nested_thresholds
Loading