Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use KMeans++ For ColorScape Generation #3164

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions API.Benchmark/ImageServiceBenchmark.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using System;
using System.IO;
using System.Linq;
using System.Numerics;
using System.Collections.Generic;
using System.Drawing;
using System.Collections.Concurrent;
using System.Text.RegularExpressions;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Order;
using NetVips;
using Image = NetVips.Image;



namespace API.Benchmark;

[MemoryDiagnoser]
[Orderer(SummaryOrderPolicy.FastestToSlowest)]
[RankColumn]
public class ImageBenchmarks
{
private readonly string _testDirectoryColorScapes = "C:/Users/User/Documents/GitHub/Kavita/API.Tests/Services/Test Data/ImageService/ColorScapes";

private List<List<Vector3>> allRgbPixels;

[GlobalSetup]
public void Setup()
{
allRgbPixels = new List<List<Vector3>>();

var imageFiles = Directory.GetFiles(_testDirectoryColorScapes, "*.*")
.Where(file => !file.EndsWith("html"))
.Where(file => !file.Contains("_output") && !file.Contains("_baseline"))
.ToList();

foreach (var imagePath in imageFiles)
{
using var image = Image.NewFromFile(imagePath);
// Resize the image to speed up processing
var resizedImage = image.Resize(0.1);
// Convert image to RGB array
var pixels = resizedImage.WriteToMemory().ToArray();
// Convert to list of Vector3 (RGB)
var rgbPixels = new List<Vector3>();

for (var i = 0; i < pixels.Length - 2; i += 3)
{
rgbPixels.Add(new Vector3(pixels[i], pixels[i + 1], pixels[i + 2]));
}

// Add the rgbPixels list to allRgbPixels
allRgbPixels.Add(rgbPixels);
}
}

[Benchmark]
public void CalculateColorScape_original()
{
foreach (var rgbPixels in allRgbPixels)
{
Original_KMeansClustering(rgbPixels, 4);
}
}

[Benchmark]
public void CalculateColorScape_optimized()
{
foreach (var rgbPixels in allRgbPixels)
{
Services.ImageService.KMeansClustering(rgbPixels, 4);
}
}

private static List<Vector3> Original_KMeansClustering(List<Vector3> points, int k, int maxIterations = 100)
{
var random = new Random();
var centroids = points.OrderBy(x => random.Next()).Take(k).ToList();

for (var i = 0; i < maxIterations; i++)
{
var clusters = new List<Vector3>[k];
for (var j = 0; j < k; j++)
{
clusters[j] = [];
}

foreach (var point in points)
{
var nearestCentroidIndex = centroids
.Select((centroid, index) => new { Index = index, Distance = Vector3.DistanceSquared(centroid, point) })
.OrderBy(x => x.Distance)
.First().Index;
clusters[nearestCentroidIndex].Add(point);
}

var newCentroids = clusters.Select(cluster =>
cluster.Count != 0 ? new Vector3(
cluster.Average(p => p.X),
cluster.Average(p => p.Y),
cluster.Average(p => p.Z)
) : Vector3.Zero
).ToList();

if (centroids.SequenceEqual(newCentroids))
break;

centroids = newCentroids;
}

return centroids;
}

}
25 changes: 24 additions & 1 deletion API.Benchmark/ParserBenchmarks.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.IO;
using System.Text.RegularExpressions;
using BenchmarkDotNet.Attributes;
Expand Down Expand Up @@ -36,7 +37,22 @@ private static string Normalize(string name)
var normalized = NormalizeRegex.Replace(name, string.Empty).ToLower();
return string.IsNullOrEmpty(normalized) ? name : normalized;
}
private static readonly ConcurrentDictionary<string, string> NormalizedCache =
new ConcurrentDictionary<string, string>();

private static string New_Normalize(string name)
{
// Check cache first
if (NormalizedCache.TryGetValue(name, out string cachedResult))
{
return cachedResult;
}
string normalized = NormalizeRegex.Replace(name, string.Empty).Trim().ToLowerInvariant();

// Add to cache
NormalizedCache.TryAdd(name, normalized);
return normalized;
}


[Benchmark]
Expand All @@ -47,7 +63,14 @@ public void TestNormalizeName()
Normalize(name);
}
}

[Benchmark]
public void TestNormalizeName_New()
{
foreach (var name in _names)
{
New_Normalize(name);
}
}

[Benchmark]
public void TestIsEpub()
Expand Down
178 changes: 139 additions & 39 deletions API/Services/ImageService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -579,45 +579,145 @@ private static bool IsColorCloseToWhiteOrBlack(Vector3 color)
return lightness is > WhiteThreshold or < BlackThreshold;
}

private static List<Vector3> KMeansClustering(List<Vector3> points, int k, int maxIterations = 100)
{
var random = new Random();
var centroids = points.OrderBy(x => random.Next()).Take(k).ToList();

for (var i = 0; i < maxIterations; i++)
{
var clusters = new List<Vector3>[k];
for (var j = 0; j < k; j++)
{
clusters[j] = [];
}

foreach (var point in points)
{
var nearestCentroidIndex = centroids
.Select((centroid, index) => new { Index = index, Distance = Vector3.DistanceSquared(centroid, point) })
.OrderBy(x => x.Distance)
.First().Index;
clusters[nearestCentroidIndex].Add(point);
}

var newCentroids = clusters.Select(cluster =>
cluster.Count != 0 ? new Vector3(
cluster.Average(p => p.X),
cluster.Average(p => p.Y),
cluster.Average(p => p.Z)
) : Vector3.Zero
).ToList();

if (centroids.SequenceEqual(newCentroids))
break;

centroids = newCentroids;
}

return centroids;
}

public static List<Vector3> KMeansClustering(List<Vector3> points, int k, int maxIterations = 100)
{
// Initialize centroids using k-means++ for better starting positions
var centroids = InitializeCentroidsKMeansPlusPlus(points, k);

var assignments = new int[points.Count];
var clusters = new List<int>[k];
for (int i = 0; i < k; i++)
{
clusters[i] = new List<int>();
}

for (var iteration = 0; iteration < maxIterations; iteration++)
{
bool centroidsChanged = false;

foreach (var cluster in clusters)
{
cluster.Clear();
}

// Assign points to the nearest centroid
Parallel.For(0, points.Count, i =>
{
var point = points[i];
int nearestCentroidIndex = 0;
float minDistanceSquared = float.MaxValue;

for (int c = 0; c < k; c++)
{
var centroid = centroids[c];
float dx = point.X - centroid.X;
float dy = point.Y - centroid.Y;
float dz = point.Z - centroid.Z;
float distanceSquared = dx * dx + dy * dy + dz * dz;

if (distanceSquared < minDistanceSquared)
{
minDistanceSquared = distanceSquared;
nearestCentroidIndex = c;
}
}

assignments[i] = nearestCentroidIndex;
});

// Build clusters
for (int i = 0; i < points.Count; i++)
{
clusters[assignments[i]].Add(i);
}

// Update centroids
for (int c = 0; c < k; c++)
{
var cluster = clusters[c];
if (cluster.Count == 0)
continue;

float sumX = 0, sumY = 0, sumZ = 0;
foreach (var index in cluster)
{
var point = points[index];
sumX += point.X;
sumY += point.Y;
sumZ += point.Z;
}

var count = cluster.Count;
var newCentroid = new Vector3(sumX / count, sumY / count, sumZ / count);

// Check if centroids have changed significantly
if (!IsCentroidConverged(centroids[c], newCentroid))
{
centroidsChanged = true;
centroids[c] = newCentroid;
}
}

if (!centroidsChanged)
break;
}

return centroids;
}
// K-means++ initialization for better starting centroids
private static List<Vector3> InitializeCentroidsKMeansPlusPlus(List<Vector3> points, int k)
{
var random = new Random();
var centroids = new List<Vector3> { points[random.Next(points.Count)] };
var distances = new float[points.Count];

for (int i = 1; i < k; i++)
{
float totalDistance = 0;
for (int p = 0; p < points.Count; p++)
{
var point = points[p];
var minDistance = float.MaxValue;

foreach (var centroid in centroids)
{
var dx = point.X - centroid.X;
var dy = point.Y - centroid.Y;
var dz = point.Z - centroid.Z;
var distanceSquared = dx * dx + dy * dy + dz * dz;

if (distanceSquared < minDistance)
{
minDistance = distanceSquared;
}
}
distances[p] = minDistance;
totalDistance += minDistance;
}

var targetDistance = random.NextDouble() * totalDistance;
totalDistance = 0;

for (int p = 0; p < points.Count; p++)
{
totalDistance += distances[p];
if (totalDistance >= targetDistance)
{
centroids.Add(points[p]);
break;
}
}
}

return centroids;
}

// Helper method to check centroid convergence with a tolerance
private static bool IsCentroidConverged(Vector3 oldCentroid, Vector3 newCentroid, float tolerance = 0.0001f)
{
return Vector3.DistanceSquared(oldCentroid, newCentroid) <= tolerance * tolerance;
}

public static List<Vector3> SortByBrightness(List<Vector3> colors)
{
return colors.OrderBy(c => 0.299 * c.X + 0.587 * c.Y + 0.114 * c.Z).ToList();
Expand Down
15 changes: 14 additions & 1 deletion API/Services/Tasks/Scanner/Parser/Parser.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Immutable;
using System.Collections.Concurrent;
using System.Globalization;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -1106,9 +1107,21 @@ public static float MaxNumberFromRange(string range)
}
}

private static readonly ConcurrentDictionary<string, string> NormalizedCache =
new ConcurrentDictionary<string, string>();

public static string Normalize(string name)
{
return NormalizeRegex.Replace(name, string.Empty).Trim().ToLower();
// Check cache first
if (NormalizedCache.TryGetValue(name, out string cachedResult))
F0x1 marked this conversation as resolved.
Show resolved Hide resolved
{
return cachedResult;
}
string normalized = NormalizeRegex.Replace(name, string.Empty).Trim().ToLowerInvariant();

// Add to cache
NormalizedCache.TryAdd(name, normalized);
return normalized;
}

/// <summary>
Expand Down