From c906c28d36b19c2c641be7d2030bdb733b8fb988 Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Thu, 5 Sep 2024 20:18:51 -0500 Subject: [PATCH 1/8] Init sort option workflow --- src/ART/distributed/modules/DDVFA.jl | 16 +++++++++++++++- src/ART/single/modules/FuzzyART.jl | 8 ++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index ac6f3d7..6f7a9b5 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -91,6 +91,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end # ----------------------------------------------------------------------------- @@ -228,6 +236,7 @@ function DDVFA(opts::opts_DDVFA) activation=opts.activation, match=opts.match, update=opts.update, + sort=opts.sort, ) # Construct the DDVFA module @@ -292,7 +301,12 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal end # Compute the match for each category in the order of greatest activation - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + else + # index = 1:art.n_categories + end + # index = sortperm(art.T, rev=true) accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Best matching unit diff --git a/src/ART/single/modules/FuzzyART.jl b/src/ART/single/modules/FuzzyART.jl index 1526c20..ea1b797 100644 --- a/src/ART/single/modules/FuzzyART.jl +++ b/src/ART/single/modules/FuzzyART.jl @@ -86,6 +86,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ From 7c1f794eb0a6183d84687bc0ca084a9f7dba84b7 Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Thu, 5 Sep 2024 20:43:59 -0500 Subject: [PATCH 2/8] Dev flag in ddvfa and fuzzyart --- src/ART/distributed/modules/DDVFA.jl | 18 ++++++++++++++---- src/ART/single/modules/FuzzyART.jl | 26 ++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index 6f7a9b5..9520894 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -304,13 +304,16 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal if art.opts.sort index = sortperm(art.T, rev=true) else - # index = 1:art.n_categories + top_bmu = argmax(art.T) end - # index = sortperm(art.T, rev=true) accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Best matching unit - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end # Compute the match with the similarity linkage method art.M[bmu] = similarity(art.opts.similarity, art.F2[bmu], sample, false) # If we got a match, then learn (update the category) @@ -326,13 +329,20 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we triggered a mismatch if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + if art.opts.sort + bmu = index[1] + else + bmu = top_bmu + end # Get the correct label y_hat = supervised ? y : art.n_categories + 1 # Create a new category diff --git a/src/ART/single/modules/FuzzyART.jl b/src/ART/single/modules/FuzzyART.jl index ea1b797..370758e 100644 --- a/src/ART/single/modules/FuzzyART.jl +++ b/src/ART/single/modules/FuzzyART.jl @@ -324,15 +324,25 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + else + top_bmu = argmax(art.T) + end + # index = sortperm(art.T, rev=true) # Initialize mismatch as true mismatch_flag = true # Loop over all categories - for j = 1:art.n_categories + for jx = 1:art.n_categories # Best matching unit - bmu = index[j] + # bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end # Vigilance check - pass if art.M[bmu] >= art.threshold # If supervised and the label differed, force mismatch @@ -352,13 +362,21 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If there was no resonant category, make a new one if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + # bmu = index[1] + if art.opts.sort + bmu = index[1] + else + bmu = top_bmu + end # Get the correct label for the new category y_hat = supervised ? y : art.n_categories + 1 From 74e9232c822ddf58cf8b9cda6a1716a47aceb712 Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Mon, 9 Sep 2024 11:30:17 -0500 Subject: [PATCH 3/8] Add sort opt to DDVFA, FuzzyART, SFAM --- src/ART/distributed/modules/DDVFA.jl | 31 +++++++++++----- src/ART/single/modules/FuzzyART.jl | 33 ++++++++++------- src/ARTMAP/SFAM.jl | 53 +++++++++++++++++++++++----- 3 files changed, 89 insertions(+), 28 deletions(-) diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index 9520894..f4d11ff 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -303,9 +303,11 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # Compute the match for each category in the order of greatest activation if art.opts.sort index = sortperm(art.T, rev=true) + top_bmu = index[1] else top_bmu = argmax(art.T) end + accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Best matching unit @@ -314,6 +316,7 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal else bmu = argmax(art.T) end + # Compute the match with the similarity linkage method art.M[bmu] = similarity(art.opts.similarity, art.F2[bmu], sample, false) # If we got a match, then learn (update the category) @@ -338,13 +341,11 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # If we triggered a mismatch if mismatch_flag # Keep the bmu as the top activation despite creating a new category - if art.opts.sort - bmu = index[1] - else - bmu = top_bmu - end + bmu = top_bmu + # Get the correct label y_hat = supervised ? y : art.n_categories + 1 + # Create a new category create_category!(art, sample, y_hat) end @@ -370,7 +371,12 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: end # Sort by highest activation - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true @@ -379,7 +385,12 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Get the best-matching unit - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Get the match value of this activation art.M[bmu] = similarity(art.opts.similarity, art.F2[bmu], sample, false) # If the match satisfies the threshold criterion, then report that label @@ -390,14 +401,18 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Update the stored match and activation values of the best matching unit - bmu = index[1] + bmu = top_bmu log_art_stats!(art, bmu, true) + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end diff --git a/src/ART/single/modules/FuzzyART.jl b/src/ART/single/modules/FuzzyART.jl index 370758e..5777011 100644 --- a/src/ART/single/modules/FuzzyART.jl +++ b/src/ART/single/modules/FuzzyART.jl @@ -326,6 +326,7 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= # Sort activation function values in descending order if art.opts.sort index = sortperm(art.T, rev=true) + top_bmu = index[1] else top_bmu = argmax(art.T) end @@ -371,12 +372,7 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= # If there was no resonant category, make a new one if mismatch_flag # Keep the bmu as the top activation despite creating a new category - # bmu = index[1] - if art.opts.sort - bmu = index[1] - else - bmu = top_bmu - end + bmu = top_bmu # Get the correct label for the new category y_hat = supervised ? y : art.n_categories + 1 @@ -395,13 +391,18 @@ end # COMMON DOC: FuzzyART incremental classification method function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_bmu::Bool=false) # Preprocess the data - x = init_classify!(x, art, preprocessed) + sample = init_classify!(x, art, preprocessed) # Compute activation and match functions - activation_match!(art, x) + activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default is mismatch mismatch_flag = true @@ -409,8 +410,12 @@ function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_b # Iterate over all categories for jx in 1:art.n_categories - # Set the best matching unit - bmu = index[jx] + # Get the best-matching unit + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end # Vigilance check - pass if art.M[bmu] >= art.threshold @@ -418,12 +423,16 @@ function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_b y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end + # If we did not find a match if mismatch_flag # Report either the best matching unit or the mismatch label -1 - bmu = index[1] + bmu = top_bmu # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 diff --git a/src/ARTMAP/SFAM.jl b/src/ARTMAP/SFAM.jl index 103cecc..70e2658 100644 --- a/src/ARTMAP/SFAM.jl +++ b/src/ARTMAP/SFAM.jl @@ -70,6 +70,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ @@ -249,34 +257,49 @@ function train!(art::SFAM, x::RealVector, y::Integer ; preprocessed::Bool=false) end # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end + mismatch_flag = true accommodate_vector!(art.M, art.n_categories) for jx in 1:art.n_categories # Set the best-matching-unit index - bmu = index[jx] + # bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Compute match function art.M[bmu] = art_match(art, sample, bmu) # Current winner if art.M[bmu] >= rho_baseline if y == art.labels[bmu] # Update the weight and break - # art.W[:, index[jx]] = learn(art, sample, art.W[:, index[jx]]) learn!(art, sample, bmu) mismatch_flag = false break else # Match tracking - @debug "Match tracking" + # @debug "Match tracking" rho_baseline = art.M[bmu] + art.opts.epsilon end + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we triggered a mismatch if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu + # Create new weight vector create_category!(art, sample, y) end @@ -301,7 +324,12 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: end # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true @@ -310,7 +338,12 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: accommodate_vector!(art.M, art.n_categories) for jx in 1:art.n_categories # Set the best-matching-unit index - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Compute match function art.M[bmu] = art_match(art, sample, bmu) # Current winner @@ -318,13 +351,17 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Keep the bmu as the top activation - bmu = index[1] + bmu = top_bmu + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end From 9c5cb78da3934e89f5714389d7cf283463629bbf Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Mon, 9 Sep 2024 12:22:45 -0500 Subject: [PATCH 4/8] Document new common sort option --- docs/src/man/guide.md | 2 ++ src/ART/distributed/modules/DDVFA.jl | 2 ++ src/ART/single/modules/FuzzyART.jl | 4 ++-- src/ARTMAP/SFAM.jl | 4 ++-- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/src/man/guide.md b/docs/src/man/guide.md index 84e1dcc..9f2dcd5 100644 --- a/docs/src/man/guide.md +++ b/docs/src/man/guide.md @@ -242,6 +242,8 @@ Though most parameters differ between each ART and ARTMAP module, they all share - `display::Bool`: a flag to display or suppress progress bars and logging messages during training and testing. - `max_epochs::Int`: the maximum number of epochs to train over the data, regardless if other stopping conditions have not been met yet. +- `sort::Bool`: if a sort procedure on the activations is done before the match rule. +This is false by default for all modules, using instead an `argmax` and node deactivation strategy for evaluating the vigilance criterion, which is faster in *most* cases. Otherwise, most ART and ARTMAP modules share the following nomenclature for algorithmic parameters: diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index f4d11ff..ff2416d 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -292,6 +292,7 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # Default to mismatch mismatch_flag = true + y_hat = -1 # Compute the activation for all categories accommodate_vector!(art.T, art.n_categories) @@ -380,6 +381,7 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: # Default to mismatch mismatch_flag = true + y_hat = -1 # Iterate over the list of activations accommodate_vector!(art.M, art.n_categories) diff --git a/src/ART/single/modules/FuzzyART.jl b/src/ART/single/modules/FuzzyART.jl index 5777011..f5f0a14 100644 --- a/src/ART/single/modules/FuzzyART.jl +++ b/src/ART/single/modules/FuzzyART.jl @@ -330,20 +330,20 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= else top_bmu = argmax(art.T) end - # index = sortperm(art.T, rev=true) # Initialize mismatch as true mismatch_flag = true + y_hat = -1 # Loop over all categories for jx = 1:art.n_categories # Best matching unit - # bmu = index[jx] if art.opts.sort bmu = index[jx] else bmu = argmax(art.T) end + # Vigilance check - pass if art.M[bmu] >= art.threshold # If supervised and the label differed, force mismatch diff --git a/src/ARTMAP/SFAM.jl b/src/ARTMAP/SFAM.jl index 70e2658..1f34b9d 100644 --- a/src/ARTMAP/SFAM.jl +++ b/src/ARTMAP/SFAM.jl @@ -265,10 +265,10 @@ function train!(art::SFAM, x::RealVector, y::Integer ; preprocessed::Bool=false) end mismatch_flag = true + accommodate_vector!(art.M, art.n_categories) for jx in 1:art.n_categories # Set the best-matching-unit index - # bmu = index[jx] if art.opts.sort bmu = index[jx] else @@ -286,7 +286,6 @@ function train!(art::SFAM, x::RealVector, y::Integer ; preprocessed::Bool=false) break else # Match tracking - # @debug "Match tracking" rho_baseline = art.M[bmu] + art.opts.epsilon end elseif !art.opts.sort @@ -333,6 +332,7 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: # Default to mismatch mismatch_flag = true + y_hat = -1 # Iterate over the list of activations accommodate_vector!(art.M, art.n_categories) From a5063d4e2f18015f3fb77b769f0c150cd28af732 Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Mon, 9 Sep 2024 12:31:27 -0500 Subject: [PATCH 5/8] Add sort opt to more modules --- src/ART/distributed/modules/MergeART.jl | 8 ++++++++ src/ARTMAP/FAM.jl | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/ART/distributed/modules/MergeART.jl b/src/ART/distributed/modules/MergeART.jl index 97cf7dc..aef875c 100644 --- a/src/ART/distributed/modules/MergeART.jl +++ b/src/ART/distributed/modules/MergeART.jl @@ -91,6 +91,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ diff --git a/src/ARTMAP/FAM.jl b/src/ARTMAP/FAM.jl index 4ecbdad..682b33f 100644 --- a/src/ARTMAP/FAM.jl +++ b/src/ARTMAP/FAM.jl @@ -52,6 +52,14 @@ $(_OPTS_DOCSTRING) Display flag for progress bars. """ display::Bool = false + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ From b6e8dc0c4bce5c0ff6bb8729c6b11ef66da2ae3e Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Mon, 9 Sep 2024 12:52:14 -0500 Subject: [PATCH 6/8] Add tests, add sort opt for DVFA --- src/ART/distributed/modules/DDVFA.jl | 1 + src/ART/single/modules/DVFA.jl | 65 ++++++++++++++++++++++----- test/adaptiveresonance/common.jl | 10 +++-- test/adaptiveresonance/exceptions.jl | 2 +- test/adaptiveresonance/performance.jl | 3 +- 5 files changed, 65 insertions(+), 16 deletions(-) diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index ff2416d..1c717a1 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -326,6 +326,7 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal if supervised && (art.labels[bmu] != y) break end + # Update the weights with the sample train!(art.F2[bmu], sample, preprocessed=true) # Save the output label for the sample diff --git a/src/ART/single/modules/DVFA.jl b/src/ART/single/modules/DVFA.jl index d25ff90..7d25da4 100644 --- a/src/ART/single/modules/DVFA.jl +++ b/src/ART/single/modules/DVFA.jl @@ -75,6 +75,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ @@ -283,24 +291,34 @@ function train!(art::DVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fals # Compute the activation and match for all categories activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true # Loop over all categories - for j = 1:art.n_categories + for jx = 1:art.n_categories # Best matching unit - bmu = index[j] - # If supervised and the label differs, trigger mismatch - if supervised && (art.labels[bmu] != y) - break + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) end + # Vigilance test upper bound if art.M[bmu] >= art.threshold_ub + # If supervised and the label differs, trigger mismatch + if supervised && (art.labels[bmu] != y) + break + end + # Learn the sample learn!(art, sample, bmu) # Update sample label for output - # y_hat = supervised ? y : art.labels[bmu] y_hat = art.labels[bmu] # No mismatch mismatch_flag = false @@ -314,15 +332,20 @@ function train!(art::DVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fals # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If there was no resonant category, make a new one if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu + # Create a new category-to-cluster label y_hat = supervised ? y : art.n_clusters + 1 + # Create a new category create_category!(art, sample, y_hat) end @@ -341,23 +364,43 @@ function classify(art::DVFA, x::RealVector ; preprocessed::Bool=false, get_bmu:: # Compute activation and match functions activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end + + # Default to mismatch mismatch_flag = true + y_hat = -1 + + # Iterate over the list of activations for jx in 1:art.n_categories - bmu = index[jx] + # Get the best-matching unit + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Vigilance check - pass if art.M[bmu] >= art.threshold_ub # Current winner y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Create new weight vector - bmu = index[1] + bmu = top_bmu + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end diff --git a/test/adaptiveresonance/common.jl b/test/adaptiveresonance/common.jl index 6b62b05..750543a 100644 --- a/test/adaptiveresonance/common.jl +++ b/test/adaptiveresonance/common.jl @@ -33,7 +33,7 @@ Tests of common code for the `AdaptiveResonance.jl` package. @test_logs (:warn,) AdaptiveResonance.data_setup!(dc3, three_by_two) bad_config = DataConfig(1, 0, 3) @test_throws ErrorException linear_normalization(three_by_two, config=bad_config) -end # @testset "common.jl" +end @testset "constants.jl" begin @info "------- Constants Tests -------" @@ -57,7 +57,11 @@ end # Iterate over all modules for art in ADAPTIVERESONANCE_MODULES - art_module = art(alpha=1e-3, display=false) + art_module = art( + alpha=1e-3, + display=false, + sort=true, + ) end end @@ -77,4 +81,4 @@ end for i = 1:n_samples train!(art, x) end -end \ No newline at end of file +end diff --git a/test/adaptiveresonance/exceptions.jl b/test/adaptiveresonance/exceptions.jl index c0c609b..08261ed 100644 --- a/test/adaptiveresonance/exceptions.jl +++ b/test/adaptiveresonance/exceptions.jl @@ -37,7 +37,7 @@ Tests the edge cases and exceptions of the entire `AdaptiveResonance.jl` package end end -@testset "init_tain!" begin +@testset "init_train!" begin # Create a new FuzzyART module art = FuzzyART() diff --git a/test/adaptiveresonance/performance.jl b/test/adaptiveresonance/performance.jl index 1b90360..8d41663 100644 --- a/test/adaptiveresonance/performance.jl +++ b/test/adaptiveresonance/performance.jl @@ -16,7 +16,8 @@ A test of the performance of every ART and ARTMAP module. # All common ART options art_opts = [ # (display = true,), - (display = false,), + (display = false, sort = true,), + (display = false, sort = false,), ] # Specific ART options From 49d354d04cb405327fbdc14e888e84f58996fb6f Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Tue, 24 Sep 2024 09:25:12 -0500 Subject: [PATCH 7/8] Bump to v0.8.5 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cbe8b9a..960e9fc 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "AdaptiveResonance" uuid = "3d72adc0-63d3-4141-bf9b-84450dd0395b" authors = ["Sasha Petrenko"] description = "A Julia package for Adaptive Resonance Theory (ART) algorithms." -version = "0.8.4" +version = "0.8.5" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" From 1db5c5a44107394f9f88a297ce532f154dc1cd6a Mon Sep 17 00:00:00 2001 From: Sasha Petrenko Date: Tue, 24 Sep 2024 09:27:45 -0500 Subject: [PATCH 8/8] Up upload artifact action for paper to v4 --- .github/workflows/draft-pdf.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml index c1aa45a..7ef2eaa 100644 --- a/.github/workflows/draft-pdf.yml +++ b/.github/workflows/draft-pdf.yml @@ -21,7 +21,7 @@ jobs: # This should be the path to the paper within your repo. paper-path: paper/paper.md - name: Upload - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: paper # This is the output path where Pandoc will write the compiled