Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] Implement binding_array function arguments #6523

Open
wants to merge 3 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ impl crate::TypeInner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure { .. } => true,
crate::TypeInner::BindingArray { .. } => true,
_ => false,
}
}
Expand Down
34 changes: 31 additions & 3 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,19 @@ impl<'w> BlockContext<'w> {
};

let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));

let load_id = self.gen_id();

// Map `binding_type_id` back to the original pointer if it is an opaque
// type.
match self.ir_module.types[binding_type].inner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure => {
self.function_arg_ids.insert(load_id, result_id);
}
_ => {}
}

block.body.push(Instruction::load(
binding_type_id,
load_id,
Expand Down Expand Up @@ -514,8 +525,19 @@ impl<'w> BlockContext<'w> {
};

let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));

let load_id = self.gen_id();

// Map `binding_type_id` back to the original pointer if it is an opaque
// type.
match self.ir_module.types[binding_type].inner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure => {
self.function_arg_ids.insert(load_id, result_id);
}
_ => {}
}

block.body.push(Instruction::load(
binding_type_id,
load_id,
Expand Down Expand Up @@ -2668,7 +2690,13 @@ impl<'w> BlockContext<'w> {
let id = self.gen_id();
self.temp_list.clear();
for &argument in arguments {
self.temp_list.push(self.cached[argument]);
// Check if we should use the `argument_id` directly or the pointer to it.
let argument_id = self.cached[argument];
let argument_id = self
.function_arg_ids
.get(&argument_id)
.map_or(argument_id, |&pointer_id| pointer_id);
self.temp_list.push(argument_id);
}

let type_id = match result {
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ impl CachedExpressions {
self.ids.resize(length, 0);
}
}

impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
type Output = Word;
fn index(&self, h: Handle<crate::Expression>) -> &Word {
Expand Down Expand Up @@ -689,6 +690,10 @@ struct BlockContext<'w> {
/// SPIR-V ids for expressions we've evaluated.
cached: CachedExpressions,

/// The pointers of the cached expressions' SPIR-V ids from [`BlockContext::cached`].
/// Only used when loaded opaque types need to be passed to a function call.
function_arg_ids: crate::FastIndexMap<Word, Word>,

/// The `Writer`'s temporary vector, for convenience.
temp_list: Vec<Word>,

Expand Down Expand Up @@ -762,6 +767,11 @@ pub struct Writer {
// retain the table here between functions to save heap allocations.
saved_cached: CachedExpressions,

// Maps the expression ids from `saved_cached` to the pointer id they were loaded from.
// Only used when opaque types need to be passed to a function call.
// This goes alongside `saved_cached`, so it too is only meaningful within a BlockContext.
function_arg_ids: crate::FastIndexMap<Word, Word>,

gl450_ext_inst_id: Word,

// Just a temporary list of SPIR-V ids
Expand Down
7 changes: 7 additions & 0 deletions naga/src/back/spv/recyclable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ impl<K, S: Clone> Recyclable for indexmap::IndexSet<K, S> {
}
}

impl<K, V, S: Clone> Recyclable for indexmap::IndexMap<K, V, S> {
fn recycle(mut self) -> Self {
self.clear();
self
}
}

impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> {
fn recycle(mut self) -> Self {
self.clear();
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl Writer {
global_variables: HandleVec::new(),
binding_map: options.binding_map.clone(),
saved_cached: CachedExpressions::default(),
function_arg_ids: crate::FastIndexMap::default(),
gl450_ext_inst_id,
temp_list: Vec::new(),
})
Expand Down Expand Up @@ -130,6 +131,7 @@ impl Writer {
cached_constants: take(&mut self.cached_constants).recycle(),
global_variables: take(&mut self.global_variables).recycle(),
saved_cached: take(&mut self.saved_cached).recycle(),
function_arg_ids: take(&mut self.function_arg_ids).recycle(),
temp_list: take(&mut self.temp_list).recycle(),
};

Expand Down Expand Up @@ -585,6 +587,7 @@ impl Writer {
function: &mut function,
// Re-use the cached expression table from prior functions.
cached: std::mem::take(&mut self.saved_cached),
function_arg_ids: std::mem::take(&mut self.function_arg_ids),

// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
Expand Down
9 changes: 9 additions & 0 deletions naga/src/front/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,12 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
}
}

crate::Expression::FunctionArgument(i) => {
match ctx.type_arena[ctx.arguments[i as usize].ty].inner {
crate::TypeInner::BindingArray { base, .. } => base,
_ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())),
}
}
ref other => return Err(Error::InvalidGlobalVar(other.clone())),
},

Expand All @@ -611,6 +617,9 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
crate::Expression::GlobalVariable(handle) => {
*self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
}
crate::Expression::FunctionArgument(i) => {
ctx.parameter_sampling[i as usize] |= sampling_bit;
}

ref other => return Err(Error::InvalidGlobalVar(other.clone())),
},
Expand Down
11 changes: 11 additions & 0 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4406,6 +4406,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
crate::Expression::FunctionArgument(i) => {
fun_parameter_sampling[i as usize] |= flags;
}
crate::Expression::Access { base, .. } => match expressions[base] {
crate::Expression::GlobalVariable(handle) => {
if let Some(sampling) = self.handle_sampling.get_mut(&handle) {
*sampling |= flags
}
}
crate::Expression::FunctionArgument(i) => {
fun_parameter_sampling[i as usize] |= flags;
}
ref other => return Err(Error::InvalidGlobalVar(other.clone())),
},
ref other => return Err(Error::InvalidGlobalVar(other.clone())),
}
}
Expand Down
4 changes: 0 additions & 4 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,9 +844,6 @@ pub enum TypeInner {
/// a binding array of samplers yields a [`Sampler`], indexing a pointer to the
/// binding array of storage buffers produces a pointer to the storage struct.
///
/// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
/// they cannot be passed as arguments to functions.
///
/// Naga's WGSL front end supports binding arrays with the type syntax
/// `binding_array<T, N>`.
///
Expand All @@ -858,7 +855,6 @@ pub enum TypeInner {
/// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
/// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
/// [`DATA`]: crate::valid::TypeFlags::DATA
/// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
/// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
BindingArray { base: Handle<Type>, size: ArraySize },
}
Expand Down
42 changes: 23 additions & 19 deletions naga/src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,27 +444,31 @@ impl<'a> ResolveContext<'a> {
space: crate::AddressSpace::Function,
})
}
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
Ti::Pointer { base, space: _ } => {
if let Ti::Atomic(scalar) = types[base].inner {
TypeResolution::Value(Ti::Scalar(scalar))
} else {
TypeResolution::Handle(base)
crate::Expression::Load { pointer } => {
let past_pointer = past(pointer)?;
match *past_pointer.inner_with(types) {
Ti::Pointer { base, space: _ } => {
if let Ti::Atomic(scalar) = types[base].inner {
TypeResolution::Value(Ti::Scalar(scalar))
} else {
TypeResolution::Handle(base)
}
}
Ti::ValuePointer {
size,
scalar,
space: _,
} => TypeResolution::Value(match size {
Some(size) => Ti::Vector { size, scalar },
None => Ti::Scalar(scalar),
}),
Ti::BindingArray { .. } => past_pointer.clone(),
ref other => {
log::error!("Pointer type {:?}", other);
return Err(ResolveError::InvalidPointer(pointer));
}
}
Ti::ValuePointer {
size,
scalar,
space: _,
} => TypeResolution::Value(match size {
Some(size) => Ti::Vector { size, scalar },
None => Ti::Scalar(scalar),
}),
ref other => {
log::error!("Pointer type {:?}", other);
return Err(ResolveError::InvalidPointer(pointer));
}
},
}
crate::Expression::ImageSample {
image,
gather: Some(_),
Expand Down
1 change: 1 addition & 0 deletions naga/src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ impl GlobalOrArgument {
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
_ => return Err(ExpressionError::ExpectedGlobalOrArgument),
},
_ => return Err(ExpressionError::ExpectedGlobalOrArgument),
Expand Down
9 changes: 9 additions & 0 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ impl super::Validator {
.flags
.contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
Ti::ValuePointer { .. } => {}
Ti::BindingArray { .. } => {}
ref other => {
log::error!("Loading {:?}", other);
return Err(ExpressionError::InvalidPointerType(pointer));
Expand Down Expand Up @@ -1677,6 +1678,14 @@ impl super::Validator {
_ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
}
}
Ex::FunctionArgument(i) => {
let array_ty = function.arguments[i as usize].ty;

match module.types[array_ty].inner {
crate::TypeInner::BindingArray { base, .. } => Ok(base),
_ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
}
}
_ => Err(ExpressionError::ExpectedGlobalVariable),
}
}
Expand Down
6 changes: 4 additions & 2 deletions naga/src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,12 @@ impl super::Validator {
}
Ti::BindingArray { base, size } => {
let type_info_mask = match size {
crate::ArraySize::Constant(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
crate::ArraySize::Constant(_) => {
TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT
}
crate::ArraySize::Dynamic => {
// Final type is non-sized
TypeFlags::HOST_SHAREABLE
TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT
}
};
let base_info = &self.types[base.index()];
Expand Down