diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml index c1aa45a3..7ef2eaab 100644 --- a/.github/workflows/draft-pdf.yml +++ b/.github/workflows/draft-pdf.yml @@ -21,7 +21,7 @@ jobs: # This should be the path to the paper within your repo. paper-path: paper/paper.md - name: Upload - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: paper # This is the output path where Pandoc will write the compiled diff --git a/Project.toml b/Project.toml index cbe8b9af..960e9fc9 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "AdaptiveResonance" uuid = "3d72adc0-63d3-4141-bf9b-84450dd0395b" authors = ["Sasha Petrenko"] description = "A Julia package for Adaptive Resonance Theory (ART) algorithms." -version = "0.8.4" +version = "0.8.5" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/docs/src/man/guide.md b/docs/src/man/guide.md index 84e1dccd..9f2dcd5a 100644 --- a/docs/src/man/guide.md +++ b/docs/src/man/guide.md @@ -242,6 +242,8 @@ Though most parameters differ between each ART and ARTMAP module, they all share - `display::Bool`: a flag to display or suppress progress bars and logging messages during training and testing. - `max_epochs::Int`: the maximum number of epochs to train over the data, regardless if other stopping conditions have not been met yet. +- `sort::Bool`: if a sort procedure on the activations is done before the match rule. +This is false by default for all modules, using instead an `argmax` and node deactivation strategy for evaluating the vigilance criterion, which is faster in *most* cases. Otherwise, most ART and ARTMAP modules share the following nomenclature for algorithmic parameters: diff --git a/src/ART/distributed/modules/DDVFA.jl b/src/ART/distributed/modules/DDVFA.jl index ac6f3d72..1c717a1a 100644 --- a/src/ART/distributed/modules/DDVFA.jl +++ b/src/ART/distributed/modules/DDVFA.jl @@ -91,6 +91,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end # ----------------------------------------------------------------------------- @@ -228,6 +236,7 @@ function DDVFA(opts::opts_DDVFA) activation=opts.activation, match=opts.match, update=opts.update, + sort=opts.sort, ) # Construct the DDVFA module @@ -283,6 +292,7 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # Default to mismatch mismatch_flag = true + y_hat = -1 # Compute the activation for all categories accommodate_vector!(art.T, art.n_categories) @@ -292,11 +302,22 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal end # Compute the match for each category in the order of greatest activation - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end + accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Best matching unit - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Compute the match with the similarity linkage method art.M[bmu] = similarity(art.opts.similarity, art.F2[bmu], sample, false) # If we got a match, then learn (update the category) @@ -305,6 +326,7 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal if supervised && (art.labels[bmu] != y) break end + # Update the weights with the sample train!(art.F2[bmu], sample, preprocessed=true) # Save the output label for the sample @@ -312,15 +334,20 @@ function train!(art::DDVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fal # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we triggered a mismatch if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu + # Get the correct label y_hat = supervised ? y : art.n_categories + 1 + # Create a new category create_category!(art, sample, y_hat) end @@ -346,16 +373,27 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: end # Sort by highest activation - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true + y_hat = -1 # Iterate over the list of activations accommodate_vector!(art.M, art.n_categories) for jx = 1:art.n_categories # Get the best-matching unit - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Get the match value of this activation art.M[bmu] = similarity(art.opts.similarity, art.F2[bmu], sample, false) # If the match satisfies the threshold criterion, then report that label @@ -366,14 +404,18 @@ function classify(art::DDVFA, x::RealVector ; preprocessed::Bool=false, get_bmu: y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Update the stored match and activation values of the best matching unit - bmu = index[1] + bmu = top_bmu log_art_stats!(art, bmu, true) + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end diff --git a/src/ART/distributed/modules/MergeART.jl b/src/ART/distributed/modules/MergeART.jl index 97cf7dcc..aef875c6 100644 --- a/src/ART/distributed/modules/MergeART.jl +++ b/src/ART/distributed/modules/MergeART.jl @@ -91,6 +91,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ diff --git a/src/ART/single/modules/DVFA.jl b/src/ART/single/modules/DVFA.jl index d25ff90f..7d25da4e 100644 --- a/src/ART/single/modules/DVFA.jl +++ b/src/ART/single/modules/DVFA.jl @@ -75,6 +75,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ @@ -283,24 +291,34 @@ function train!(art::DVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fals # Compute the activation and match for all categories activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true # Loop over all categories - for j = 1:art.n_categories + for jx = 1:art.n_categories # Best matching unit - bmu = index[j] - # If supervised and the label differs, trigger mismatch - if supervised && (art.labels[bmu] != y) - break + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) end + # Vigilance test upper bound if art.M[bmu] >= art.threshold_ub + # If supervised and the label differs, trigger mismatch + if supervised && (art.labels[bmu] != y) + break + end + # Learn the sample learn!(art, sample, bmu) # Update sample label for output - # y_hat = supervised ? y : art.labels[bmu] y_hat = art.labels[bmu] # No mismatch mismatch_flag = false @@ -314,15 +332,20 @@ function train!(art::DVFA, x::RealVector ; y::Integer=0, preprocessed::Bool=fals # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If there was no resonant category, make a new one if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu + # Create a new category-to-cluster label y_hat = supervised ? y : art.n_clusters + 1 + # Create a new category create_category!(art, sample, y_hat) end @@ -341,23 +364,43 @@ function classify(art::DVFA, x::RealVector ; preprocessed::Bool=false, get_bmu:: # Compute activation and match functions activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end + + # Default to mismatch mismatch_flag = true + y_hat = -1 + + # Iterate over the list of activations for jx in 1:art.n_categories - bmu = index[jx] + # Get the best-matching unit + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Vigilance check - pass if art.M[bmu] >= art.threshold_ub # Current winner y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Create new weight vector - bmu = index[1] + bmu = top_bmu + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end diff --git a/src/ART/single/modules/FuzzyART.jl b/src/ART/single/modules/FuzzyART.jl index 1526c200..f5f0a148 100644 --- a/src/ART/single/modules/FuzzyART.jl +++ b/src/ART/single/modules/FuzzyART.jl @@ -86,6 +86,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ @@ -316,15 +324,26 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Initialize mismatch as true mismatch_flag = true + y_hat = -1 # Loop over all categories - for j = 1:art.n_categories + for jx = 1:art.n_categories # Best matching unit - bmu = index[j] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Vigilance check - pass if art.M[bmu] >= art.threshold # If supervised and the label differed, force mismatch @@ -344,13 +363,16 @@ function train!(art::FuzzyART, x::RealVector ; y::Integer=0, preprocessed::Bool= # No mismatch mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If there was no resonant category, make a new one if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu # Get the correct label for the new category y_hat = supervised ? y : art.n_categories + 1 @@ -369,13 +391,18 @@ end # COMMON DOC: FuzzyART incremental classification method function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_bmu::Bool=false) # Preprocess the data - x = init_classify!(x, art, preprocessed) + sample = init_classify!(x, art, preprocessed) # Compute activation and match functions - activation_match!(art, x) + activation_match!(art, sample) # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default is mismatch mismatch_flag = true @@ -383,8 +410,12 @@ function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_b # Iterate over all categories for jx in 1:art.n_categories - # Set the best matching unit - bmu = index[jx] + # Get the best-matching unit + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end # Vigilance check - pass if art.M[bmu] >= art.threshold @@ -392,12 +423,16 @@ function classify(art::FuzzyART, x::RealVector ; preprocessed::Bool=false, get_b y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end + # If we did not find a match if mismatch_flag # Report either the best matching unit or the mismatch label -1 - bmu = index[1] + bmu = top_bmu # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 diff --git a/src/ARTMAP/FAM.jl b/src/ARTMAP/FAM.jl index 4ecbdada..682b33f9 100644 --- a/src/ARTMAP/FAM.jl +++ b/src/ARTMAP/FAM.jl @@ -52,6 +52,14 @@ $(_OPTS_DOCSTRING) Display flag for progress bars. """ display::Bool = false + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ diff --git a/src/ARTMAP/SFAM.jl b/src/ARTMAP/SFAM.jl index 103cecc3..1f34b9d0 100644 --- a/src/ARTMAP/SFAM.jl +++ b/src/ARTMAP/SFAM.jl @@ -70,6 +70,14 @@ $(_OPTS_DOCSTRING) Selected weight update function. """ update::Symbol = :basic_update + + """ + Flag to sort the F2 nodes by activation before the match phase + + When true, the F2 nodes are sorted by activation before match. + When false, an iterative argmax and inhibition procedure is used to find the best-matching unit. + """ + sort::Bool = false end """ @@ -249,34 +257,48 @@ function train!(art::SFAM, x::RealVector, y::Integer ; preprocessed::Bool=false) end # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end + mismatch_flag = true + accommodate_vector!(art.M, art.n_categories) for jx in 1:art.n_categories # Set the best-matching-unit index - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Compute match function art.M[bmu] = art_match(art, sample, bmu) # Current winner if art.M[bmu] >= rho_baseline if y == art.labels[bmu] # Update the weight and break - # art.W[:, index[jx]] = learn(art, sample, art.W[:, index[jx]]) learn!(art, sample, bmu) mismatch_flag = false break else # Match tracking - @debug "Match tracking" rho_baseline = art.M[bmu] + art.opts.epsilon end + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we triggered a mismatch if mismatch_flag # Keep the bmu as the top activation despite creating a new category - bmu = index[1] + bmu = top_bmu + # Create new weight vector create_category!(art, sample, y) end @@ -301,16 +323,27 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: end # Sort activation function values in descending order - index = sortperm(art.T, rev=true) + if art.opts.sort + index = sortperm(art.T, rev=true) + top_bmu = index[1] + else + top_bmu = argmax(art.T) + end # Default to mismatch mismatch_flag = true + y_hat = -1 # Iterate over the list of activations accommodate_vector!(art.M, art.n_categories) for jx in 1:art.n_categories # Set the best-matching-unit index - bmu = index[jx] + if art.opts.sort + bmu = index[jx] + else + bmu = argmax(art.T) + end + # Compute match function art.M[bmu] = art_match(art, sample, bmu) # Current winner @@ -318,13 +351,17 @@ function classify(art::SFAM, x::RealVector ; preprocessed::Bool=false, get_bmu:: y_hat = art.labels[bmu] mismatch_flag = false break + elseif !art.opts.sort + # Remove the top activation + art.T[bmu] = 0.0 end end # If we did not find a resonant category if mismatch_flag # Keep the bmu as the top activation - bmu = index[1] + bmu = top_bmu + # Report either the best matching unit or the mismatch label -1 y_hat = get_bmu ? art.labels[bmu] : -1 end diff --git a/test/adaptiveresonance/common.jl b/test/adaptiveresonance/common.jl index 6b62b055..750543a2 100644 --- a/test/adaptiveresonance/common.jl +++ b/test/adaptiveresonance/common.jl @@ -33,7 +33,7 @@ Tests of common code for the `AdaptiveResonance.jl` package. @test_logs (:warn,) AdaptiveResonance.data_setup!(dc3, three_by_two) bad_config = DataConfig(1, 0, 3) @test_throws ErrorException linear_normalization(three_by_two, config=bad_config) -end # @testset "common.jl" +end @testset "constants.jl" begin @info "------- Constants Tests -------" @@ -57,7 +57,11 @@ end # Iterate over all modules for art in ADAPTIVERESONANCE_MODULES - art_module = art(alpha=1e-3, display=false) + art_module = art( + alpha=1e-3, + display=false, + sort=true, + ) end end @@ -77,4 +81,4 @@ end for i = 1:n_samples train!(art, x) end -end \ No newline at end of file +end diff --git a/test/adaptiveresonance/exceptions.jl b/test/adaptiveresonance/exceptions.jl index c0c609b0..08261eda 100644 --- a/test/adaptiveresonance/exceptions.jl +++ b/test/adaptiveresonance/exceptions.jl @@ -37,7 +37,7 @@ Tests the edge cases and exceptions of the entire `AdaptiveResonance.jl` package end end -@testset "init_tain!" begin +@testset "init_train!" begin # Create a new FuzzyART module art = FuzzyART() diff --git a/test/adaptiveresonance/performance.jl b/test/adaptiveresonance/performance.jl index 1b90360f..8d416638 100644 --- a/test/adaptiveresonance/performance.jl +++ b/test/adaptiveresonance/performance.jl @@ -16,7 +16,8 @@ A test of the performance of every ART and ARTMAP module. # All common ART options art_opts = [ # (display = true,), - (display = false,), + (display = false, sort = true,), + (display = false, sort = false,), ] # Specific ART options