Skip to content

Commit

Permalink
Fixed multi-dimensional RNG number generation (#808).
Browse files Browse the repository at this point in the history
  • Loading branch information
m4rs-mt authored May 29, 2022
1 parent 966419e commit 2eae285
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
29 changes: 13 additions & 16 deletions Src/ILGPU.Algorithms/Random/RNG.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU Algorithms
// Copyright (c) 2021 ILGPU Project
// Copyright (c) 2021-2022 ILGPU Project
// www.ilgpu.net
//
// File: RNG.cs
Expand Down Expand Up @@ -158,17 +158,17 @@ public abstract void FillUniform(
/// <summary>
/// The maximum number of parallel groups.
/// </summary>
private readonly int groupSize;
private readonly int maxNumParallelWarps;

/// <summary>
/// Initializes the RNG view.
/// </summary>
/// <param name="providers">The random providers.</param>
/// <param name="numParallelGroups">The maximum number of parallel groups.</param>
internal RNGView(ArrayView<TRandomProvider> providers, int numParallelGroups)
/// <param name="numParallelWarps">The maximum number of parallel warps.</param>
internal RNGView(ArrayView<TRandomProvider> providers, int numParallelWarps)
{
randomProviders = providers;
groupSize = numParallelGroups;
maxNumParallelWarps = numParallelWarps;
}

#endregion
Expand All @@ -183,17 +183,17 @@ internal RNGView(ArrayView<TRandomProvider> providers, int numParallelGroups)
private readonly ref TRandomProvider GetRandomProvider()
{
// Compute the global warp index
int groupOffset = Stride3D.DenseXY.ComputeElementIndex(
Grid.Index,
Grid.Dimension) % groupSize;
int warpOffset = Group.LinearIndex;
int warpIdx = groupOffset * Warp.WarpSize + warpOffset / Warp.WarpSize;
int groupIndex = Group.LinearIndex;
int warpIndex = Warp.ComputeWarpIdx(groupIndex);
int groupStride = XMath.DivRoundUp(Group.Dimension.Size, Warp.WarpSize);
int groupOffset = Grid.LinearIndex * groupStride;
int providerIndex = groupOffset + warpIndex;

// Access the underlying provider
Trace.Assert(
warpIdx < randomProviders.Length,
providerIndex < randomProviders.Length,
"Current warp does not have a valid RNG provider");
return ref randomProviders[warpIdx];
return ref randomProviders[providerIndex];
}

/// <summary>
Expand Down Expand Up @@ -403,14 +403,11 @@ public RNGView<TRandomProvider> GetViewViaThreads(int numThreads) =>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public RNGView<TRandomProvider> GetView(int numWarps)
{
// Ensure that the number of warps is a multiple of the warp size.
int numGroups = XMath.DivRoundUp(numWarps, Accelerator.WarpSize);
numWarps = numGroups * Accelerator.WarpSize;
Trace.Assert(
numWarps > 0 && numWarps <= randomProvidersPerWarp.Length,
"Invalid number of warps");
var subView = randomProvidersPerWarp.View.SubView(0, numWarps);
return new RNGView<TRandomProvider>(subView, numGroups);
return new RNGView<TRandomProvider>(subView, numWarps);
}

/// <summary>
Expand Down
16 changes: 15 additions & 1 deletion Src/ILGPU/Grid.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU
// Copyright (c) 2017-2021 ILGPU Project
// Copyright (c) 2017-2022 ILGPU Project
// www.ilgpu.net
//
// File: Grid.cs
Expand Down Expand Up @@ -101,6 +101,13 @@ public static int DimZ
/// <returns>The grid dimension.</returns>
public static Index3D Dimension => new Index3D(DimX, DimY, DimZ);

/// <summary>
/// Returns the linear grid index of the current group within the current
/// thread grid.
/// </summary>
public static int LinearIndex =>
Stride3D.DenseXY.ComputeElementIndex(Index, Dimension);

/// <summary>
/// Returns the global index.
/// </summary>
Expand All @@ -115,6 +122,13 @@ public static int DimZ
Index,
Group.Index);

/// <summary>
/// Returns the linear thread index of the current thread within the current
/// thread grid.
/// </summary>
public static int GlobalLinearIndex =>
LinearIndex * Group.Dimension.Size + Group.LinearIndex;

#endregion

#region Methods
Expand Down

0 comments on commit 2eae285

Please sign in to comment.