Skip to content

Commit

Permalink
Merge pull request #415 from srigas/master
Browse files Browse the repository at this point in the history
Updated auto_symbolic to include a configurable threshold & simplicity weight
  • Loading branch information
KindXiaoming authored Aug 23, 2024
2 parents 5b2af5e + 3895043 commit 91fc24e
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions kan/MultKAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -2160,7 +2160,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No

return best_name, best_fun, best_r2, best_c;

def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
'''
automatic symbolic regression for all edges
Expand All @@ -2174,7 +2174,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
library of candidate symbolic functions
verbose : int
larger verbosity => more verbosity
weight_simple : float
a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
r2_threshold : float
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
Returns:
--------
None
Expand All @@ -2191,17 +2194,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
for l in range(len(self.width_in) - 1):
for i in range(self.width_in[l]):
for j in range(self.width_out[l + 1]):
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
print(f'skipping ({l},{i},{j}) since already symbolic')
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
print(f'fixing ({l},{i},{j}) with 0')
else:
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
if r2 >= r2_threshold:
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
else:
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')

self.log_history('auto_symbolic')

Expand Down

0 comments on commit 91fc24e

Please sign in to comment.