Skip to content

Commit

Permalink
fix: added system semantics for crossgl (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
samthakur587 authored Sep 20, 2024
1 parent 1e7b255 commit 6729060
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 86 deletions.
36 changes: 11 additions & 25 deletions crosstl/src/backend/slang/SlangAst.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def __init__(
functions,
global_vars,
cbuffers,
shader_type,
):
self.shader_type = shader_type
self.imports = imports
self.exports = exports
self.structs = structs
Expand All @@ -41,26 +39,6 @@ def __repr__(self):
return f"ShaderNode(imports={self.imports}, exports={self.exports}, structs={self.structs}, typedefs={self.typedefs}, functions={self.functions}), global_vars={self.global_vars}, cbuffers={self.cbuffers}"


class shaderTypeNode:
"""
Represents a shader type node in the AST.
Attributes:
vertex (bool): The vertex shader type
fragment (bool): The fragment shader type
compute (bool): The compute shader type
"""

def __init__(self, vertex=False, fragment=False, compute=False):
self.vertex = vertex
self.fragment = fragment
self.compute = compute

def __repr__(self):
return f"shaderTypeNode(vertex={self.vertex}, fragment={self.fragment}, compute={self.compute})"


class ImportNode(ASTNode):
"""
Expand Down Expand Up @@ -148,17 +126,25 @@ class FunctionNode(ASTNode):
"""

def __init__(
self, return_type, name, params, body, is_generic=False, type_function="custom"
self,
return_type,
name,
params,
body,
is_generic=False,
qualifier=None,
semantic=None,
):
self.return_type = return_type
self.name = name
self.params = params
self.body = body
self.is_generic = is_generic
self.type_function = type_function
self.qualifier = qualifier
self.semantic = semantic

def __repr__(self):
return f"FunctionNode(return_type='{self.return_type}', name='{self.name}', params={self.params}, body={self.body}, is_generic={self.is_generic}), type_function={self.type_function}"
return f"FunctionNode(return_type='{self.return_type}', name='{self.name}', params={self.params}, body={self.body}, is_generic={self.is_generic}, qualifier={self.qualifier}, semantic={self.semantic})"


class VariableNode(ASTNode):
Expand Down
147 changes: 128 additions & 19 deletions crosstl/src/backend/slang/SlangCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,102 @@ def __init__(self):
"TextureCube": "samplerCube",
}

self.semantic_map = {
# Vertex inputs position
"POSITION": "in_Position",
"POSITION0": "in_Position0",
"POSITION1": "in_Position1",
"POSITION2": "in_Position2",
"POSITION3": "in_Position3",
"POSITION4": "in_Position4",
"POSITION5": "in_Position5",
"POSITION6": "in_Position6",
"POSITION7": "in_Position7",
# Vertex inputs normal
"NORMAL": "in_Normal",
"NORMAL0": "in_Normal0",
"NORMAL1": "in_Normal1",
"NORMAL2": "in_Normal2",
"NORMAL3": "in_Normal3",
"NORMAL4": "in_Normal4",
"NORMAL5": "in_Normal5",
"NORMAL6": "in_Normal6",
"NORMAL7": "in_Normal7",
# Vertex inputs tangent
"TANGENT": "in_Tangent",
"TANGENT0": "in_Tangent0",
"TANGENT1": "in_Tangent1",
"TANGENT2": "in_Tangent2",
"TANGENT3": "in_Tangent3",
"TANGENT4": "in_Tangent4",
"TANGENT5": "in_Tangent5",
"TANGENT6": "in_Tangent6",
"TANGENT7": "in_Tangent7",
# Vertex inputs binormal
"BINORMAL": "in_Binormal",
"BINORMAL0": "in_Binormal0",
"BINORMAL1": "in_Binormal1",
"BINORMAL2": "in_Binormal2",
"BINORMAL3": "in_Binormal3",
"BINORMAL4": "in_Binormal4",
"BINORMAL5": "in_Binormal5",
"BINORMAL6": "in_Binormal6",
"BINORMAL7": "in_Binormal7",
# Vertex inputs color
"COLOR": "Color",
"COLOR0": "Color0",
"COLOR1": "Color1",
"COLOR2": "Color2",
"COLOR3": "Color3",
"COLOR4": "Color4",
"COLOR5": "Color5",
"COLOR6": "Color6",
"COLOR7": "Color7",
# Vertex inputs texcoord
"TEXCOORD": "TexCoord",
"TEXCOORD0": "TexCoord0",
"TEXCOORD1": "TexCoord1",
"TEXCOORD2": "TexCoord2",
"TEXCOORD3": "TexCoord3",
"TEXCOORD4": "TexCoord4",
"TEXCOORD5": "TexCoord5",
"TEXCOORD6": "TexCoord6",
# Vertex inputs instance
"FRONT_FACE": "gl_IsFrontFace",
"PRIMITIVE_ID": "gl_PrimitiveID",
"INSTANCE_ID": "gl_InstanceID",
"VERTEX_ID": "gl_VertexID",
# Vertex outputs
"SV_Position": "Out_Position",
"SV_Position0": "Out_Position0",
"SV_Position1": "Out_Position1",
"SV_Position2": "Out_Position2",
"SV_Position3": "Out_Position3",
"SV_Position4": "Out_Position4",
"SV_Position5": "Out_Position5",
"SV_Position6": "Out_Position6",
"SV_Position7": "Out_Position7",
# Fragment inputs
"SV_Target": "Out_Color",
"SV_Target0": "Out_Color0",
"SV_Target1": "Out_Color1",
"SV_Target2": "Out_Color2",
"SV_Target3": "Out_Color3",
"SV_Target4": "Out_Color4",
"SV_Target5": "Out_Color5",
"SV_Target6": "Out_Color6",
"SV_Target7": "Out_Color7",
"SV_Depth": "Out_Depth",
"SV_Depth0": "Out_Depth0",
"SV_Depth1": "Out_Depth1",
"SV_Depth2": "Out_Depth2",
"SV_Depth3": "Out_Depth3",
"SV_Depth4": "Out_Depth4",
"SV_Depth5": "Out_Depth5",
"SV_Depth6": "Out_Depth6",
"SV_Depth7": "Out_Depth7",
}

def generate(self, ast):
code = "shader main {\n"
if ast.imports:
Expand All @@ -58,7 +154,7 @@ def generate(self, ast):
if isinstance(node, StructNode):
code += f" struct {node.name} {{\n"
for member in node.members:
code += f" {self.map_type(member.vtype)} {member.name};\n"
code += f" {self.map_type(member.vtype)} {member.name} {self.map_semantic(member.semantic)};\n"
code += " }\n"
# Generate global variables
for node in ast.global_vars:
Expand All @@ -70,22 +166,24 @@ def generate(self, ast):

# Generate custom functions
for func in ast.functions:
function_type_node = func.type_function
if function_type_node == "custom":
if func.qualifier == "vertex":
code += " // Vertex Shader\n"
code += " vertex {\n"
code += self.generate_function(func)
code += " }\n\n"
elif func.qualifier == "fragment":
code += " // Fragment Shader\n"
code += " fragment {\n"
code += self.generate_function(func)
code += " }\n\n"

elif func.qualifier == "compute":
code += " // Compute Shader\n"
code += " compute {\n"
code += self.generate_function(func)
code += " }\n\n"
else:
if function_type_node.vertex:
code += f"vertex {{\n"
code += self.generate_function(func)
code += f"}}\n"
elif function_type_node.fragment:
code += f"fragment {{\n"
code += self.generate_function(func)
code += f"}}\n"
elif function_type_node.compute:
code += f"compute {{\n"
code += self.generate_function(func)
code += f"}}\n"
code += self.generate_function(func)

code += "}\n"
return code
Expand All @@ -100,10 +198,15 @@ def generate_cbuffers(self, ast):
code += " }\n"
return code

def generate_function(self, func):
params = ", ".join(f"{self.map_type(p.vtype)} {p.name}" for p in func.params)
code = f" {self.map_type(func.return_type)} {func.name}({params}) {{\n"
code += self.generate_function_body(func.body, indent=2)
def generate_function(self, func, indent=1):
code = " "
code += " " * indent
params = ", ".join(
f"{self.map_type(p.vtype)} {p.name} {self.map_semantic(p.semantic)}"
for p in func.params
)
code += f" {self.map_type(func.return_type)} {func.name}({params}) {self.map_semantic(func.semantic)} {{\n"
code += self.generate_function_body(func.body, indent=indent + 1)
code += " }\n\n"
return code

Expand Down Expand Up @@ -199,3 +302,9 @@ def map_type(self, slang_type):
if slang_type:
return self.type_map.get(slang_type, slang_type)
return slang_type

def map_semantic(self, semantic):
if semantic is not None:
return f"@ {self.semantic_map.get(semantic, semantic)}"
else:
return ""
5 changes: 1 addition & 4 deletions crosstl/src/backend/slang/SlangLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
("COMMENT_MULTI", r"/\*[\s\S]*?\*/"),
("STRUCT", r"\bstruct\b"),
("CBUFFER", r"\bcbuffer\b"),
("TYPE_SHADER", r'\[shader\("(vertex|fragment|compute)"\)\]'),
("SHADER", r"\bshader\b"),
("STRING", r'"(?:\\.|[^"\\])*"'),
("TEXTURE2D", r"\bTexture2D\b"),
Expand All @@ -30,10 +31,6 @@
("BREAK", r"\bbreak\b"),
("CONTINUE", r"\bcontinue\b"),
("REGISTER", r"\bregister\b"),
(
"SEMANTIC",
r":\s*[A-Za-z_][A-Za-z0-9_]*",
), # Correctly capturing the entire semantic token
("STRING", r'"[^"]*"'),
("IDENTIFIER", r"[a-zA-Z_][a-zA-Z0-9_]*"),
("NUMBER", r"\d+(\.\d+)?"),
Expand Down
60 changes: 22 additions & 38 deletions crosstl/src/backend/slang/SlangParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def parse_shader(self):
typedefs = []
cbuffers = []
global_variables = []
shader_type = None
while self.current_token[0] != "EOF":
if self.current_token[0] == "IMPORT":
imports.append(self.parse_import())
Expand All @@ -48,12 +47,10 @@ def parse_shader(self):
cbuffers.append(self.parse_cbuffer())
elif self.current_token[0] == "TYPEDEF":
typedefs.append(self.parse_typedef())
elif (
self.current_token[0] == "LBRACKET"
and self.tokens[self.pos + 1][0] == "SHADER"
):
shader_type = self.parse_shader_type()
functions.append(self.parse_function(shader_type))
elif self.current_token[0] == "TYPE_SHADER":
type_shader = self.current_token[1].split('"')[1]
self.eat("TYPE_SHADER")
functions.append(self.parse_function(type_shader))
elif self.current_token[0] in [
"VOID",
"FLOAT",
Expand All @@ -78,30 +75,8 @@ def parse_shader(self):
functions,
global_variables,
cbuffers,
shader_type,
)

def parse_shader_type(self):
self.eat("LBRACKET")
self.eat("SHADER")
self.eat("LPAREN")
vertex = False
fragment = False
compute = False
type_sh = self.current_token[1].split('"')[1]
if type_sh == "vertex":
vertex = True
self.eat(self.current_token[0])
elif type_sh == "fragment":
fragment = True
self.eat(self.current_token[0])
elif type_sh == "compute":
compute = True
self.eat("COMPUTE")
self.eat("RPAREN")
self.eat("RBRACKET")
return shaderTypeNode(vertex, fragment, compute)

def is_function(self):
# Look ahead to check if there's a left parenthesis after the identifier
current_pos = self.pos
Expand Down Expand Up @@ -171,9 +146,10 @@ def parse_struct(self):
var_name = self.current_token[1]
self.eat("IDENTIFIER")
semantic = None
if self.current_token[0] == "SEMANTIC":
if self.current_token[0] == "COLON":
self.eat("COLON")
semantic = self.current_token[1]
self.eat("SEMANTIC")
self.eat(self.current_token[0])
self.eat("SEMICOLON")
members.append(VariableNode(vtype, var_name, semantic))
self.eat("RBRACE")
Expand All @@ -188,7 +164,7 @@ def parse_typedef(self):
self.eat("SEMICOLON")
return TypedefNode(original_type, new_type)

def parse_function(self, shader_type="custom"):
def parse_function(self, shader_type=None):
is_generic = False
if self.current_token[0] == "GENERIC":
is_generic = True
Expand All @@ -200,10 +176,15 @@ def parse_function(self, shader_type="custom"):
self.eat("LPAREN")
params = self.parse_parameters()
self.eat("RPAREN")
if self.current_token[0] == "SEMANTIC":
self.eat("SEMANTIC")
semantic = None
if self.current_token[0] == "COLON":
self.eat("COLON")
semantic = self.current_token[1]
self.eat(self.current_token[0])
body = self.parse_block()
return FunctionNode(return_type, name, params, body, is_generic, shader_type)
return FunctionNode(
return_type, name, params, body, is_generic, shader_type, semantic
)

def parse_parameters(self):
params = []
Expand All @@ -213,12 +194,15 @@ def parse_parameters(self):
self.eat(self.current_token[0])
name = self.current_token[1]
self.eat("IDENTIFIER")
semantic = None
if self.current_token[0] == "IDENTIFIER":
struct_def = self.current_token[1]
self.eat("IDENTIFIER")
if self.current_token[0] == "SEMANTIC":
self.eat("SEMANTIC")
params.append(VariableNode(vtype + struct_def, name))
if self.current_token[0] == "COLON":
self.eat("COLON")
semantic = self.current_token[1]
self.eat(self.current_token[0])
params.append(VariableNode(vtype + struct_def, name, semantic))
if self.current_token[0] == "COMMA":
self.eat("COMMA")
return params
Expand Down

0 comments on commit 6729060

Please sign in to comment.