Skip to content

Commit

Permalink
feat(parser): Add support for ASSIGN_XOR in DirectX backend (#201)
Browse files Browse the repository at this point in the history
* feat(parser): Add support for ASSIGN_XOR in DirectX backend

Signed-off-by: Maharshi Basu <[email protected]>

* test: Add test for assignment operations in DirectX backend

Signed-off-by: Maharshi Basu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: updated variable dec in parser to handle type a.b op c

---------

Signed-off-by: Maharshi Basu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: samthakur587 <[email protected]>
  • Loading branch information
3 people authored Oct 12, 2024
1 parent 2e6ede3 commit d422113
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 7 deletions.
1 change: 1 addition & 0 deletions crosstl/src/backend/DirectX/DirectxLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
("MINUS_EQUALS", r"-="),
("MULTIPLY_EQUALS", r"\*="),
("DIVIDE_EQUALS", r"/="),
("ASSIGN_XOR", r"\^="),
("AND", r"&&"),
("OR", r"\|\|"),
("DOT", r"\."),
Expand Down
44 changes: 43 additions & 1 deletion crosstl/src/backend/DirectX/DirectxParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,19 @@ def parse_variable_declaration_or_assignment(self):
"BOOL",
"IDENTIFIER",
]:
# Handle variable declaration (e.g., int a = b;)
first_token = self.current_token
self.eat(self.current_token[0])
self.eat(self.current_token[0]) # Eat type or identifier
name = None

# Check for member access (e.g., a.b)
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")

if self.current_token[0] == "DOT":
name = self.parse_member_access(name)

if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
return VariableNode(first_token[1], name)
Expand All @@ -222,20 +229,40 @@ def parse_variable_declaration_or_assignment(self):
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_XOR",
]:
# Handle assignment operators (e.g., =, +=, -=, ^=, etc.)
op = self.current_token[1]
self.eat(self.current_token[0])
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(VariableNode(first_token[1], name), value, op)

elif self.current_token[0] in [
"EQUALS",
"PLUS_EQUALS",
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_XOR",
]:
# Handle assignment operators (e.g., =, +=, -=, ^=, etc.)
op = self.current_token[1]
self.eat(self.current_token[0])
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(first_token[1], value, op)

elif self.current_token[0] == "DOT":
# Handle int a.b = c; case directly
left = self.parse_member_access(first_token[1])
if self.current_token[0] in [
"EQUALS",
"PLUS_EQUALS",
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_XOR",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
Expand All @@ -245,6 +272,8 @@ def parse_variable_declaration_or_assignment(self):
else:
self.eat("SEMICOLON")
return left

# If it's not a type/identifier, it must be an expression
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
Expand Down Expand Up @@ -330,6 +359,7 @@ def parse_assignment(self):
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_XOR",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
Expand Down Expand Up @@ -406,6 +436,14 @@ def parse_unary(self):

def parse_primary(self):
if self.current_token[0] in ["IDENTIFIER", "FLOAT", "FVECTOR"]:
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")
if self.current_token[0] == "LPAREN":
return self.parse_function_call(name)
elif self.current_token[0] == "DOT":
return self.parse_member_access(name)
return VariableNode("", name)
if self.current_token[0] in ["FLOAT", "FVECTOR"]:
type_name = self.current_token[1]
self.eat(self.current_token[0])
Expand All @@ -415,6 +453,10 @@ def parse_primary(self):
elif self.current_token[0] == "NUMBER":
value = self.current_token[1]
self.eat("NUMBER")
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")
return VariableNode(value, name)
return value
elif self.current_token[0] == "LPAREN":
self.eat("LPAREN")
Expand Down
52 changes: 46 additions & 6 deletions tests/test_backend/test_directx/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_struct_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_if_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_for_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_else_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_function_call_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_else_if_codegen():
};
struct PSOutput {
float4 out_color : SV_Target0;
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
Expand All @@ -309,5 +309,45 @@ def test_else_if_codegen():
pytest.fail("Else_if statement parsing or code generation not implemented.")


def test_assignment_ops_parsing():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
if (input.in_position.r > 0.5) {
output.out_color += input.in_position;
}
if (input.in_position.r < 0.5) {
output.out_color -= float4(0.1, 0.1, 0.1, 0.1);
}
if (input.in_position.g > 0.5) {
output.out_color *= 2.0;
}
if (input.in_position.b > 0.5) {
out_color /= 2.0;
}
if (input.in_position.r == 0.5) {
uint redValue = asuint(output.out_color.r);
output.redValue ^= 0x1;
output.out_color.r = asfloat(redValue);
}
return output;
}
"""
try:
tokens = tokenize_code(code)
ast = parse_code(tokens)
generated_code = generate_code(ast)
print(generated_code)
except SyntaxError:
pytest.fail("assignment ops parsing or code generation not implemented.")


if __name__ == "__main__":
pytest.main()
37 changes: 37 additions & 0 deletions tests/test_backend/test_directx/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,42 @@ def test_else_if_tokenization():
pytest.fail("else_if tokenization not implemented.")


def test_assignment_ops_tokenization():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
if (input.in_position.r > 0.5) {
output.out_color += input.in_position;
}
if (input.in_position.r < 0.5) {
output.out_color -= float4(0.1, 0.1, 0.1, 0.1);
}
if (input.in_position.g > 0.5) {
output.out_color *= 2.0;
}
if (input.in_position.b > 0.5) {
output.out_color /= 2.0;
}
if (input.in_position.r == 0.5) {
uint redValue = asuint(output.out_color.r);
redValue ^= 0x1;
output.out_color.r = asfloat(redValue);
}
return output;
}
"""
try:
tokenize_code(code)
except SyntaxError:
pytest.fail("assign_op tokenization is not implemented.")


if __name__ == "__main__":
pytest.main()
38 changes: 38 additions & 0 deletions tests/test_backend/test_directx/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,43 @@ def test_else_if_parsing():
pytest.fail("else_if parsing not implemented.")


def test_assignment_ops_parsing():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
if (input.in_position.r > 0.5) {
output.out_color += input.in_position;
}
if (input.in_position.r < 0.5) {
output.out_color -= float4(0.1, 0.1, 0.1, 0.1);
}
if (input.in_position.g > 0.5) {
output.out_color *= 2.0;
}
if (input.in_position.b > 0.5) {
out_color /= 2.0;
}
if (input.in_position.r == 0.5) {
uint redValue = asuint(output.out_color.r);
output.redValue ^= 0x1;
output.out_color.r = asfloat(redValue);
}
return output;
}
"""
try:
tokens = tokenize_code(code)
parse_code(tokens)
except SyntaxError:
pytest.fail("assign_op parsing not implemented.")


if __name__ == "__main__":
pytest.main()

0 comments on commit d422113

Please sign in to comment.