Skip to content

Commit

Permalink
added exponential distribution and vector graph - also refactored num…
Browse files Browse the repository at this point in the history
…erical analysis
  • Loading branch information
Jack Dermody committed Jul 25, 2024
1 parent a152626 commit cd71ccb
Show file tree
Hide file tree
Showing 24 changed files with 915 additions and 130 deletions.
6 changes: 3 additions & 3 deletions BrightData.Cuda/CudaProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ internal float FindStdDev(IDeviceMemoryPtr a, uint size, float mean, uint ai = 1
if (ptr != a)
ptr.Release();

return Convert.ToSingle(System.Math.Sqrt(total.Sum() / inputSize));
return MathF.Sqrt(total.Sum() / inputSize);
}
return 0f;
}
Expand Down Expand Up @@ -701,7 +701,7 @@ internal float EuclideanDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint si
{
var ret = Allocate(size, stream);
Invoke(_euclideanDistance, stream, size, a.DevicePointer, b.DevicePointer, ret.DevicePointer, size, ai, bi, ci);
return Convert.ToSingle(System.Math.Sqrt(SumValues(ret, size)));
return MathF.Sqrt(SumValues(ret, size));
}

internal float ManhattanDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint size, uint ai = 1, uint bi = 1, uint ci = 1, CuStream* stream = null)
Expand Down Expand Up @@ -730,7 +730,7 @@ internal float CosineDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint size,
else if (bb.Equals(0f))
return 0.0f;
else
return 1f - (ab / (float)System.Math.Sqrt(aa) / (float)System.Math.Sqrt(bb));
return 1f - (ab / MathF.Sqrt(aa) / MathF.Sqrt(bb));
}
finally {
buffer.Release();
Expand Down
8 changes: 4 additions & 4 deletions BrightData.UnitTests/AnalysisTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ public void DateAnalysisNoMostFrequent()
[Fact]
public void IntegerAnalysis()
{
var analysis = new[] { 1, 2, 3 }.Analyze();
var analysis = new[] { 1, 2, 3 }.AnalyzeAsDoubles();
analysis.Min.Should().Be(1);
analysis.Max.Should().Be(3);
analysis.Median.Should().Be(2);
analysis.NumDistinct.Should().Be(3);
analysis.Total.Should().Be(3);
analysis.Count.Should().Be(3);
analysis.SampleStdDev.Should().Be(1);
}

[Fact]
public void IntegerAnalysis2()
{
var analysis = new[] { 1, 2, 2, 3 }.Analyze();
var analysis = new[] { 1, 2, 2, 3 }.AnalyzeAsDoubles();
analysis.Min.Should().Be(1);
analysis.Max.Should().Be(3);
analysis.Median.Should().Be(2);
analysis.NumDistinct.Should().Be(3);
analysis.Mode.Should().Be(2);
analysis.Total.Should().Be(4);
analysis.Count.Should().Be(4);
}

[Fact]
Expand Down
12 changes: 12 additions & 0 deletions BrightData.UnitTests/SpanTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using BrightData.UnitTests.Helper;
using FluentAssertions;
using Xunit;
Expand Down Expand Up @@ -54,5 +55,16 @@ public void SearchSpan()
}
resultCount.Should().Be(1);
}

[Fact]
public void GetRankedIndices()
{
Span<float> span = stackalloc float[32];
for (var i = 0; i < 32; i++)
span[i] = 16 - i;
var indices = span.GetRankedIndices();
indices.Length.Should().Be(32);
indices.Should().ContainInConsecutiveOrder(32.AsRange().Select(i => 31 - i));
}
}
}
37 changes: 36 additions & 1 deletion BrightData.UnitTests/VectorSetTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
using BrightData.UnitTests.Helper;
using System;
using BrightData.UnitTests.Helper;
using System.Linq;
using BrightData.LinearAlgebra.VectorIndexing;
using BrightData.LinearAlgebra.VectorIndexing.Helper;
using BrightData.Types;
using FluentAssertions;
using Xunit;
using Xunit.Abstractions;

namespace BrightData.UnitTests
{
public class VectorSetTests : UnitTestBase
{
readonly ITestOutputHelper _testOutputHelper;

public VectorSetTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
}

[Fact]
public void Average()
{
Expand Down Expand Up @@ -53,5 +63,30 @@ public void Closest()
score[0].Should().Be(2);
score[1].Should().Be(1);
}

[Fact]
public void TestVectorGraphNode()
{
var node = new IndexedFixedSizeGraphNode<float>(1);
node.Index.Should().Be(1);
node.NeighbourIndices.Length.Should().Be(0);

node.TryAddNeighbour(2, 0.9f);
node.NeighbourIndices[0].Should().Be(2);
node.NeighbourWeights[0].Should().Be(0.9f);

node.TryAddNeighbour(3, 0.8f);
node.NeighbourIndices[0].Should().Be(3);
node.NeighbourWeights[0].Should().Be(0.8f);

for(var i = 4U; i <= 10; i++)
node.TryAddNeighbour(i, 1f - 0.1f * i);
node.NeighbourIndices.Length.Should().Be(8);
node.NeighbourIndices[0].Should().Be(10);
node.NeighbourIndices[1].Should().Be(9);

node.TryAddNeighbour(20, 0.5f).Should().BeTrue();
node.TryAddNeighbour(20, 0.5f).Should().BeFalse();
}
}
}
22 changes: 19 additions & 3 deletions BrightData/Analysis/CastToDoubleNumericAnalysis.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Numerics;
using BrightData.Converter;
using BrightData.Types;

Expand All @@ -8,12 +9,13 @@ namespace BrightData.Analysis
/// Used to cast other numeric types to doubles for numeric analysis
/// </summary>
/// <typeparam name="T"></typeparam>
internal class CastToDoubleNumericAnalysis<T>(uint writeCount = Consts.MaxWriteCount) : IDataAnalyser<T>
where T : struct
internal class CastToDoubleNumericAnalysis<T>(uint writeCount = Consts.MaxWriteCount) : IDataAnalyser<T>, INumericAnalysis<T>
where T : unmanaged, INumber<T>
{
readonly ConvertToDouble<T> _converter = new();
ulong _count;

Check warning on line 16 in BrightData/Analysis/CastToDoubleNumericAnalysis.cs

View workflow job for this annotation

GitHub Actions / build

The field 'CastToDoubleNumericAnalysis<T>._count' is never used

public NumericAnalyser Analysis { get; } = new(writeCount);
public NumericAnalyser<double> Analysis { get; } = new(writeCount);

public void Add(T val)
{
Expand All @@ -35,5 +37,19 @@ public void WriteTo(MetaData metadata)
{
Analysis.WriteTo(metadata);
}

public T L1Norm => T.CreateSaturating(Analysis.L1Norm);
public T L2Norm => T.CreateSaturating(Analysis.L2Norm);
public T Min => T.CreateSaturating(Analysis.Min);
public T Max => T.CreateSaturating(Analysis.Max);
public T Mean => T.CreateSaturating(Analysis.Mean);
public T? SampleVariance => Analysis.SampleVariance.HasValue ? T.CreateSaturating(Analysis.SampleVariance.Value) : null;
public T? PopulationVariance => Analysis.PopulationVariance.HasValue ? T.CreateSaturating(Analysis.PopulationVariance.Value) : null;
public uint NumDistinct => Analysis.NumDistinct;
public T? SampleStdDev => Analysis.SampleStdDev.HasValue ? T.CreateSaturating(Analysis.SampleStdDev.Value) : null;
public T? PopulationStdDev => Analysis.PopulationStdDev.HasValue ? T.CreateSaturating(Analysis.PopulationStdDev.Value) : null;
public ulong Count => Analysis.Count;
public T? Median => Analysis.Median.HasValue ? T.CreateSaturating(Analysis.Median.Value) : null;
public T? Mode => Analysis.Mode.HasValue ? T.CreateSaturating(Analysis.Mode.Value) : null;
}
}
21 changes: 12 additions & 9 deletions BrightData/Analysis/LinearBinnedFrequencyAnalysis.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
using System;
using System.Collections.Generic;
using System.Numerics;

namespace BrightData.Analysis
{
/// <summary>
/// Binned frequency analysis
/// </summary>
internal class LinearBinnedFrequencyAnalysis(double min, double max, uint numBins)
internal class LinearBinnedFrequencyAnalysis<T>(T min, T max, uint numBins)
where T : unmanaged, INumber<T>, IBinaryFloatingPointIeee754<T>
{
readonly double _step = (max - min) / numBins;
readonly T _step = (max - min) / T.CreateTruncating(numBins);
readonly ulong[] _bins = new ulong[numBins];
ulong _belowRange = 0, _aboveRange = 0;

public void Add(double val)
public void Add(T val)
{
if (double.IsNaN(val))
if (T.IsNaN(val))
return;

if (val < min)
Expand All @@ -29,23 +31,24 @@ public void Add(double val)
}
}

public IEnumerable<(double Start, double End, ulong Count)> ContinuousFrequency
public IEnumerable<(T Start, T End, ulong Count)> ContinuousFrequency
{
get
{
if (_belowRange > 0)
yield return (double.NegativeInfinity, min, _belowRange);
yield return (T.NegativeInfinity, min, _belowRange);
var index = 0;
foreach (var c in _bins) {
var val = T.CreateTruncating(index);
yield return (
min + (index * _step),
min + (index + 1) * _step,
min + (val * _step),
min + (val + T.One) * _step,
c
);
++index;
}
if(_aboveRange > 0)
yield return (max, double.PositiveInfinity, _aboveRange);
yield return (max, T.PositiveInfinity, _aboveRange);
}
}
}
Expand Down
Loading

0 comments on commit cd71ccb

Please sign in to comment.