diff --git a/compiler/plc_ast/src/ast.rs b/compiler/plc_ast/src/ast.rs index 83a218d5ec..7aa3a0a546 100644 --- a/compiler/plc_ast/src/ast.rs +++ b/compiler/plc_ast/src/ast.rs @@ -600,7 +600,6 @@ pub enum AstStatement { DefaultValue(DefaultValue), // Literals Literal(AstLiteral), - CastStatement(CastStatement), MultipliedStatement(MultipliedStatement), // Expressions ReferenceExpr(ReferenceExpr), @@ -735,9 +734,6 @@ impl Debug for AstNode { } AstStatement::ContinueStatement(..) => f.debug_struct("ContinueStatement").finish(), AstStatement::ExitStatement(..) => f.debug_struct("ExitStatement").finish(), - AstStatement::CastStatement(CastStatement { target, type_name }) => { - f.debug_struct("CastStatement").field("type_name", type_name).field("target", target).finish() - } AstStatement::ReferenceExpr(ReferenceExpr { access, base }) => { f.debug_struct("ReferenceExpr").field("kind", access).field("base", base).finish() } diff --git a/compiler/plc_ast/src/lib.rs b/compiler/plc_ast/src/lib.rs index 7e7d78a7de..33febadee0 100644 --- a/compiler/plc_ast/src/lib.rs +++ b/compiler/plc_ast/src/lib.rs @@ -7,3 +7,4 @@ pub mod control_statements; pub mod literals; mod pre_processor; pub mod provider; +pub mod visitor; diff --git a/compiler/plc_ast/src/visitor.rs b/compiler/plc_ast/src/visitor.rs new file mode 100644 index 0000000000..0aab9108ea --- /dev/null +++ b/compiler/plc_ast/src/visitor.rs @@ -0,0 +1,714 @@ +//! This module defines the `AstVisitor` trait and its associated macros. +//! The `AstVisitor` trait provides a set of methods for traversing and visiting ASTs + +use crate::ast::AstNode; +use crate::ast::*; +use crate::control_statements::{AstControlStatement, ConditionalBlock, ReturnStatement}; +use crate::literals::AstLiteral; + +/// Macro that calls the visitor's `visit` method for every AstNode in the passed iterator `iter`. +macro_rules! visit_all_nodes { + ($visitor:expr, $iter:expr) => { + for node in $iter { + $visitor.visit(node); + } + }; +} + +/// Macro that calls the visitor's `visit` method for every AstNode in the passed sequence of nodes. +macro_rules! visit_nodes { + ($visitor:expr, $($node:expr),*) => { + $( + $visitor.visit($node); + )* + }; +} + +/// The `Walker` implements the traversal of the AST nodes and Ast-related objects (e.g. CompilationUnit). +/// The `walk` method is called on the object to visit its children. +/// If the object passed to a `AstVisitor`'s `visit` method implements the `Walker` trait, +/// a call to the it's walk function continues the visiting process on its children. +/// +/// Spliting the traversal logic into a separate trait allows to call the default traversal logic +/// from the visitor while overriding the visitor's `visit` method for specific nodes. +/// +/// # Example +/// ``` +/// use plc_ast::ast::AstNode; +/// use plc_ast::visitor::Walker; +/// use plc_ast::visitor::AstVisitor; +/// +/// struct MyAssignment { +/// left: AstNode, +/// right: AstNode, +/// } +/// +/// impl Walker for MyAssignment { +/// fn walk(&self, visitor: &mut V) +/// where +/// V: AstVisitor, +/// { +/// visitor.visit(&self.right); +/// visitor.visit(&self.left); +/// } +/// } +/// ``` +/// +pub trait Walker { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor; +} + +/// The `AstVisitor` trait provides a set of methods for visiting different types of AST nodes. +/// Implementors can individually override the methods they are interested in. When overriding a method, +/// make sure to call `walk` on the visited statement to visit its children. DO NOT call walk on +/// the node itself to avoid a recursion (last parameter). Implementors may also decide to not call +/// the statement's `walk` method to avoid visiting the children of the statement. +/// +/// The visitor offers strongly typed `visit_X` functions for every node type. The function's signature +/// is `fn visit_X(&mut self, stmt: &X, node: &AstNode)`. The `stmt` parameter is the unwrapped, typed +/// node and the `node` parameter is the `AstNode` wrapping the stmt. The `AstNode` node offers access to location +/// information and the AstId. Note that some nodes are not wrapped in an `AstNode` node (e.g. `CompilationUnit`) +/// and therefore only the strongly typed node is passed to the `visit_X` function. +/// +/// # Example +/// ``` +/// use plc_ast::{ +/// ast::{Assignment, AstNode}, +/// visitor::{AstVisitor, Walker}, +/// }; +/// +/// struct AssignmentCounter { +/// count: usize, +/// } +/// +/// impl AstVisitor for AssignmentCounter { +/// fn visit_assignment(&mut self, stmt: &Assignment, _node: &AstNode) { +/// self.count += 1; +/// // visit child nodes +/// stmt.walk(self); +/// } +/// +/// fn visit_output_assignment(&mut self, stmt: &Assignment, _node: &AstNode) { +/// self.count += 1; +/// // visit child nodes +/// stmt.walk(self); +/// } +/// } +/// ``` +pub trait AstVisitor: Sized { + /// Visits this `AstNode`. The default implementation calls the `walk` method on the node + /// and will eventually call the strongly typed `visit` method for the node (e.g. visit_assignment + /// if the node is an `AstStatement::Assignment`). + /// # Arguments + /// * `node` - The `AstNode` node to visit. + fn visit(&mut self, node: &AstNode) { + node.walk(self) + } + + /// Visits a `CompilationUnit` node. + /// Make sure to call `walk` on the `CompilationUnit` node to visit its children. + /// # Arguments + /// * `unit` - The unwraped, typed `CompilationUnit` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_compilation_unit(&mut self, unit: &CompilationUnit) { + unit.walk(self) + } + + /// Visits an `Implementation` node. + /// Make sure to call `walk` on the `Implementation` node to visit its children. + /// # Arguments + /// * `implementation` - The unwraped, typed `Implementation` node to visit. + fn visit_implementation(&mut self, implementation: &Implementation) { + implementation.walk(self); + } + + /// Visits a `DataTypeDeclaration` node. + /// Make sure to call `walk` on the `VariableBlock` node to visit its children. + /// # Arguments + /// * `block` - The unwraped, typed `VariableBlock` node to visit. + fn visit_variable_block(&mut self, block: &VariableBlock) { + block.walk(self) + } + + /// Visits a `Variable` node. + /// Make sure to call `walk` on the `Variable` node to visit its children. + /// # Arguments + /// * `variable` - The unwraped, typed `Variable` node to visit. + fn visit_variable(&mut self, variable: &Variable) { + variable.walk(self); + } + + /// Visits an enum element `AstNode` node. + /// Make sure to call `walk` on the `AstNode` node to visit its children. + /// # Arguments + /// * `element` - The unwraped, typed `AstNode` node to visit. + fn visit_enum_element(&mut self, element: &AstNode) { + element.walk(self); + } + + /// Visits a `DataTypeDeclaration` node. + /// Make sure to call `walk` on the `DataTypeDeclaration` node to visit its children. + /// # Arguments + /// * `data_type_declaration` - The unwraped, typed `DataTypeDeclaration` node to visit. + fn visit_data_type_declaration(&mut self, data_type_declaration: &DataTypeDeclaration) { + data_type_declaration.walk(self); + } + + /// Visits a `UserTypeDeclaration` node. + /// Make sure to call `walk` on the `UserTypeDeclaration` node to visit its children. + /// # Arguments + /// * `user_type` - The unwraped, typed `UserTypeDeclaration` node to visit. + fn visit_user_type_declaration(&mut self, user_type: &UserTypeDeclaration) { + user_type.walk(self); + } + + /// Visits a `UserTypeDeclaration` node. + /// Make sure to call `walk` on the `DataType` node to visit its children. + /// # Arguments + /// * `data_type` - The unwraped, typed `DataType` node to visit. + fn visit_data_type(&mut self, data_type: &DataType) { + data_type.walk(self); + } + + /// Visits a `Pou` node. + /// Make sure to call `walk` on the `Pou` node to visit its children. + /// # Arguments + /// * `pou` - The unwraped, typed `Pou` node to visit. + fn visit_pou(&mut self, pou: &Pou) { + pou.walk(self); + } + + /// Visits an `EmptyStatement` node. + /// # Arguments + /// * `stmt` - The unwraped, typed `EmptyStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_empty_statement(&mut self, _stmt: &EmptyStatement, _node: &AstNode) {} + + /// Visits a `DefaultValue` node. + /// # Arguments + /// * `stmt` - The unwraped, typed `DefaultValue` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_default_value(&mut self, _stmt: &DefaultValue, _node: &AstNode) {} + + /// Visits an `AstLiteral` node. + /// Make sure to call `walk` on the `AstLiteral` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `AstLiteral` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_literal(&mut self, stmt: &AstLiteral, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `MultipliedStatement` node. + /// Make sure to call `walk` on the `MultipliedStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `MultipliedStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_multiplied_statement(&mut self, stmt: &MultipliedStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `ReferenceExpr` node. + /// Make sure to call `walk` on the `ReferenceExpr` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `ReferenceExpr` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_reference_expr(&mut self, stmt: &ReferenceExpr, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits an `Identifier` node. + /// Make sure to call `walk` on the `Identifier` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `Identifier` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_identifier(&mut self, _stmt: &str, _node: &AstNode) {} + + /// Visits a `DirectAccess` node. + /// Make sure to call `walk` on the `DirectAccess` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `DirectAccess` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_direct_access(&mut self, stmt: &DirectAccess, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `HardwareAccess` node. + /// Make sure to call `walk` on the `HardwareAccess` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `HardwareAccess` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_hardware_access(&mut self, stmt: &HardwareAccess, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `BinaryExpression` node. + /// Make sure to call `walk` on the `BinaryExpression` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `BinaryExpression` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_binary_expression(&mut self, stmt: &BinaryExpression, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `UnaryExpression` node. + /// Make sure to call `walk` on the `UnaryExpression` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `UnaryExpression` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_unary_expression(&mut self, stmt: &UnaryExpression, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits an `ExpressionList` node. + /// Make sure to call `walk` on the `Vec` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `ExpressionList` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_expression_list(&mut self, stmt: &Vec, _node: &AstNode) { + visit_all_nodes!(self, stmt); + } + + /// Visits a `ParenExpression` node. + /// Make sure to call `walk` on the inner `AstNode` node to visit its children. + /// # Arguments + /// * `inner` - The unwraped, typed inner `AstNode` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_paren_expression(&mut self, inner: &AstNode, _node: &AstNode) { + inner.walk(self) + } + + /// Visits a `RangeStatement` node. + /// Make sure to call `walk` on the `RangeStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `RangeStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_range_statement(&mut self, stmt: &RangeStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `VlaRangeStatement` node. + /// # Arguments + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_vla_range_statement(&mut self, _node: &AstNode) {} + + /// Visits an `Assignment` node. + /// Make sure to call `walk` on the `Assignment` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `Assignment` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_assignment(&mut self, stmt: &Assignment, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits an `OutputAssignment` node. + /// Make sure to call `walk` on the `Assignment` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `Assignment` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_output_assignment(&mut self, stmt: &Assignment, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `CallStatement` node. + /// Make sure to call `walk` on the `CallStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `CallStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_call_statement(&mut self, stmt: &CallStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits an `AstControlStatement` node. + /// Make sure to call `walk` on the `AstControlStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `AstControlStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_control_statement(&mut self, stmt: &AstControlStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `CaseCondition` node. + /// Make sure to call `walk` on the child-`AstNode` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `CaseCondition` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_case_condition(&mut self, child: &AstNode, _node: &AstNode) { + child.walk(self) + } + + /// Visits an `ExitStatement` node. + /// # Arguments + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_exit_statement(&mut self, _node: &AstNode) {} + + /// Visits a `ContinueStatement` node. + /// # Arguments + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_continue_statement(&mut self, _node: &AstNode) {} + + /// Visits a `ReturnStatement` node. + /// Make sure to call `walk` on the `ReturnStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `ReturnStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_return_statement(&mut self, stmt: &ReturnStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `JumpStatement` node. + /// Make sure to call `walk` on the `JumpStatement` node to visit its children. + /// # Arguments + /// * `stmt` - The unwraped, typed `JumpStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_jump_statement(&mut self, stmt: &JumpStatement, _node: &AstNode) { + stmt.walk(self) + } + + /// Visits a `LabelStatement` node. + /// # Arguments + /// * `stmt` - The unwraped, typed `LabelStatement` node to visit. + /// * `node` - The wrapped `AstNode` node to visit. Offers access to location information and AstId + fn visit_label_statement(&mut self, _stmt: &LabelStatement, _node: &AstNode) {} +} + +/// Helper method that walks through a slice of `ConditionalBlock` and applies the visitor's `walk` method to each node. +fn walk_conditional_blocks(visitor: &mut V, blocks: &[ConditionalBlock]) +where + V: AstVisitor, +{ + for b in blocks { + visit_nodes!(visitor, &b.condition); + visit_all_nodes!(visitor, &b.body); + } +} + +impl Walker for AstLiteral { + fn walk(&self, _visitor: &mut V) + where + V: AstVisitor, + { + // do nothing + } +} + +impl Walker for MultipliedStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visitor.visit(&self.element) + } +} + +impl Walker for ReferenceExpr { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + if let Some(base) = &self.base { + visitor.visit(base); + } + + match &self.access { + ReferenceAccess::Member(t) | ReferenceAccess::Index(t) | ReferenceAccess::Cast(t) => { + visitor.visit(t) + } + _ => {} + } + } +} + +impl Walker for DirectAccess { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.index); + } +} + +impl Walker for HardwareAccess { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_all_nodes!(visitor, &self.address); + } +} + +impl Walker for BinaryExpression { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.left, &self.right); + } +} + +impl Walker for UnaryExpression { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.value); + } +} + +impl Walker for Assignment { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.left, &self.right); + } +} + +impl Walker for RangeStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.start, &self.end); + } +} + +impl Walker for CallStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.operator); + if let Some(params) = &self.parameters { + visit_nodes!(visitor, params); + } + } +} + +impl Walker for AstControlStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + match self { + AstControlStatement::If(stmt) => { + walk_conditional_blocks(visitor, &stmt.blocks); + visit_all_nodes!(visitor, &stmt.else_block); + } + AstControlStatement::WhileLoop(stmt) | AstControlStatement::RepeatLoop(stmt) => { + visit_nodes!(visitor, &stmt.condition); + visit_all_nodes!(visitor, &stmt.body); + } + AstControlStatement::ForLoop(stmt) => { + visit_nodes!(visitor, &stmt.counter, &stmt.start, &stmt.end); + visit_all_nodes!(visitor, &stmt.by_step); + visit_all_nodes!(visitor, &stmt.body); + } + AstControlStatement::Case(stmt) => { + visit_nodes!(visitor, &stmt.selector); + walk_conditional_blocks(visitor, &stmt.case_blocks); + visit_all_nodes!(visitor, &stmt.else_block); + } + } + } +} + +impl Walker for ReturnStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_all_nodes!(visitor, &self.condition); + } +} + +impl Walker for JumpStatement { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_nodes!(visitor, &self.condition, &self.target); + } +} + +impl Walker for AstNode { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + let node = self; + match &self.stmt { + AstStatement::EmptyStatement(stmt) => visitor.visit_empty_statement(stmt, node), + AstStatement::DefaultValue(stmt) => visitor.visit_default_value(stmt, node), + AstStatement::Literal(stmt) => visitor.visit_literal(stmt, node), + AstStatement::MultipliedStatement(stmt) => visitor.visit_multiplied_statement(stmt, node), + AstStatement::ReferenceExpr(stmt) => visitor.visit_reference_expr(stmt, node), + AstStatement::Identifier(stmt) => visitor.visit_identifier(stmt, node), + AstStatement::DirectAccess(stmt) => visitor.visit_direct_access(stmt, node), + AstStatement::HardwareAccess(stmt) => visitor.visit_hardware_access(stmt, node), + AstStatement::BinaryExpression(stmt) => visitor.visit_binary_expression(stmt, node), + AstStatement::UnaryExpression(stmt) => visitor.visit_unary_expression(stmt, node), + AstStatement::ExpressionList(stmt) => visitor.visit_expression_list(stmt, node), + AstStatement::ParenExpression(stmt) => visitor.visit_paren_expression(stmt, node), + AstStatement::RangeStatement(stmt) => visitor.visit_range_statement(stmt, node), + AstStatement::VlaRangeStatement => visitor.visit_vla_range_statement(node), + AstStatement::Assignment(stmt) => visitor.visit_assignment(stmt, node), + AstStatement::OutputAssignment(stmt) => visitor.visit_output_assignment(stmt, node), + AstStatement::CallStatement(stmt) => visitor.visit_call_statement(stmt, node), + AstStatement::ControlStatement(stmt) => visitor.visit_control_statement(stmt, node), + AstStatement::CaseCondition(stmt) => visitor.visit_case_condition(stmt, node), + AstStatement::ExitStatement(_stmt) => visitor.visit_exit_statement(node), + AstStatement::ContinueStatement(_stmt) => visitor.visit_continue_statement(node), + AstStatement::ReturnStatement(stmt) => visitor.visit_return_statement(stmt, node), + AstStatement::JumpStatement(stmt) => visitor.visit_jump_statement(stmt, node), + AstStatement::LabelStatement(stmt) => visitor.visit_label_statement(stmt, node), + } + } +} + +impl Walker for CompilationUnit { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + for block in &self.global_vars { + visitor.visit_variable_block(block); + } + + for user_type in &self.user_types { + visitor.visit_user_type_declaration(user_type); + } + + for pou in &self.units { + visitor.visit_pou(pou); + } + + for i in &self.implementations { + visitor.visit_implementation(i); + } + } +} + +impl Walker for UserTypeDeclaration { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visitor.visit_data_type(&self.data_type); + visit_all_nodes!(visitor, &self.initializer); + } +} + +impl Walker for VariableBlock { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + for v in self.variables.iter() { + visitor.visit_variable(v); + } + } +} + +impl Walker for Variable { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + visit_all_nodes!(visitor, &self.address); + visitor.visit_data_type_declaration(&self.data_type_declaration); + visit_all_nodes!(visitor, &self.initializer); + } +} + +impl Walker for DataType { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + match self { + DataType::StructType { variables, .. } => { + for v in variables.iter() { + visitor.visit_variable(v); + } + } + DataType::EnumType { elements, .. } => { + for ele in flatten_expression_list(elements) { + visitor.visit_enum_element(ele); + } + } + DataType::SubRangeType { bounds, .. } => { + visit_all_nodes!(visitor, bounds); + } + DataType::ArrayType { bounds, referenced_type, .. } => { + visitor.visit(bounds); + visitor.visit_data_type_declaration(referenced_type); + } + DataType::PointerType { referenced_type, .. } => { + visitor.visit_data_type_declaration(referenced_type); + } + DataType::StringType { size, .. } => { + visit_all_nodes!(visitor, size); + } + DataType::VarArgs { referenced_type, .. } => { + if let Some(data_type_declaration) = referenced_type { + visitor.visit_data_type_declaration(data_type_declaration); + } + } + DataType::GenericType { .. } => { + //no further visits + } + } + } +} + +impl Walker for DataTypeDeclaration { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + if let DataTypeDeclaration::DataTypeDefinition { data_type, .. } = self { + visitor.visit_data_type(data_type); + } + } +} + +impl Walker for Option +where + T: Walker, +{ + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + if let Some(node) = self { + node.walk(visitor); + } + } +} + +impl Walker for Pou { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + for block in &self.variable_blocks { + visitor.visit_variable_block(block); + } + + self.return_type.as_ref().inspect(|rt| visitor.visit_data_type_declaration(rt)); + } +} + +impl Walker for Implementation { + fn walk(&self, visitor: &mut V) + where + V: AstVisitor, + { + for n in &self.statements { + visitor.visit(n); + } + } +} diff --git a/src/builtins.rs b/src/builtins.rs index 4deca18bb1..de6de63e59 100644 --- a/src/builtins.rs +++ b/src/builtins.rs @@ -846,20 +846,6 @@ fn generate_variable_length_array_bound_function<'ink>( let offset = if is_lower { (value - 1) as u64 * 2 } else { (value - 1) as u64 * 2 + 1 }; llvm.i32_type().const_int(offset, false) } - AstStatement::CastStatement(data) => { - let ExpressionValue::RValue(value) = generator.generate_expression_value(&data.target)? else { - unreachable!() - }; - - if !value.is_int_value() { - return Err(Diagnostic::codegen_error( - format!("Expected INT value, found {}", value.get_type()), - location, - )); - }; - - value.into_int_value() - } // e.g. LOWER_BOUND(arr, idx + 3) _ => { let expression_value = generator.generate_expression(params[1])?; diff --git a/src/codegen/generators/expression_generator.rs b/src/codegen/generators/expression_generator.rs index 404526afd6..5db678157f 100644 --- a/src/codegen/generators/expression_generator.rs +++ b/src/codegen/generators/expression_generator.rs @@ -1751,7 +1751,6 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } // if there is just one assignment, this may be an struct-initialization (TODO this is not very elegant :-/ ) AstStatement::Assignment { .. } => self.generate_literal_struct(literal_statement), - AstStatement::CastStatement(data) => self.generate_expression_value(&data.target), _ => Err(cannot_generate_literal()), } } diff --git a/src/parser/tests.rs b/src/parser/tests.rs index 8ee0e295b6..f9879edf50 100644 --- a/src/parser/tests.rs +++ b/src/parser/tests.rs @@ -5,6 +5,7 @@ use plc_ast::{ use plc_source::source_location::SourceLocation; // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder +mod ast_visitor_tests; mod class_parser_tests; mod container_parser_tests; mod control_parser_tests; diff --git a/src/parser/tests/ast_visitor_tests.rs b/src/parser/tests/ast_visitor_tests.rs new file mode 100644 index 0000000000..c97533c520 --- /dev/null +++ b/src/parser/tests/ast_visitor_tests.rs @@ -0,0 +1,628 @@ +use plc_ast::{ + ast::LinkageType, + provider::IdProvider, + visitor::{AstVisitor, Walker}, +}; +use plc_source::source_location::SourceLocationFactory; + +use crate::{lexer, parser}; + +/// This is a simple visitor that collects all identifiers and literals in a given body +/// It is used by unit tests, to easily see if all identifiers (even the deeply nested ones) were visited +/// and therefore that all subtrees of the AST were visited. +/// +/// e.g. if we have a source code like this: +/// ``` +/// PROGRAM prg +/// foo (a := b, c => 4); +/// END_PROGRAM +/// ``` +/// The visitor should collect the following identifiers/literals: "foo", "a", "b", "c" and "4" +#[derive(Default)] +struct IdentifierCollector { + identifiers: Vec, +} + +impl AstVisitor for IdentifierCollector { + fn visit_identifier(&mut self, stmt: &str, _node: &plc_ast::ast::AstNode) { + self.identifiers.push(stmt.to_string()); + } + + fn visit_literal(&mut self, stmt: &plc_ast::literals::AstLiteral, _node: &plc_ast::ast::AstNode) { + self.identifiers.push(stmt.get_literal_value()); + } +} + +/// Helper function to create a vector of strings with all characters in the range from start to end +fn get_character_range(start: char, end: char) -> Vec { + (start as u8..=end as u8).map(|c| c as char).map(|c| c.to_string()).collect() +} + +/// Helper function to collect all identifiers in a given source code +/// using the IdentifierCollector visitor +fn collect_identifiers(src: &str) -> IdentifierCollector { + let mut visitor = IdentifierCollector::default(); + visit(src, &mut visitor); + visitor.identifiers.sort(); + visitor +} + +/// Helper function to visit a given source code with a given visitor +fn visit(src: &str, visitor: &mut impl AstVisitor) { + let id_provider = IdProvider::default(); + let (compilation_unit, _) = parser::parse( + lexer::lex_with_ids(src, id_provider.clone(), SourceLocationFactory::internal(src)), + LinkageType::Internal, + "test.st", + ); + + visitor.visit_compilation_unit(&compilation_unit) +} + +#[test] +fn test_visit_arithmetic_expressions() { + // GIVEN a source code with arithmetic expressions + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a; + b; + c := NOT d; + e := f MOD g; + h := (i / j / k - (l + m)) * n; + END_PROGRAM", + ); + // THEN we expect to also visit subexpressions in binary and unary expressions + assert_eq!(get_character_range('a', 'n'), visitor.identifiers); +} + +#[test] +fn test_visit_expression_list() { + // GIVEN a source code with an expression list + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a,b,c; + END_PROGRAM", + ); + // THEN we expect to visit all identifiers in the expression list + assert_eq!(get_character_range('a', 'c'), visitor.identifiers); +} + +#[test] +fn test_if_statement() { + // GIVEN a source code with an if statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + IF a THEN + b := c; + ELSIF d THEN + e := f; + ELSE + g := h; + END_IF; + END_PROGRAM", + ); + // THEN we expect to visit the condition, the body, the elseif condition and body and the else body + assert_eq!(get_character_range('a', 'h'), visitor.identifiers); +} + +#[test] +fn test_visit_for_loop_statement() { + // GIVEN a source code with a for loop statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + FOR a := b TO c BY d DO + e; + f; + END_FOR; + END_PROGRAM", + ); + // THEN we expect to visit the loop variable, the start, end and step expressions and the loop body + assert_eq!(get_character_range('a', 'f'), visitor.identifiers); +} + +#[test] +fn test_visit_while_loop_statement() { + // GIVEN a source code with a while loop statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + WHILE a < b DO + c; + d; + END_WHILE; + END_PROGRAM", + ); + // THEN we expect to visit the condition and the loop body + assert_eq!(get_character_range('a', 'd'), visitor.identifiers); +} +#[test] +fn test_visit_repeat_loop_statement() { + // GIVEN a source code with a repeat loop statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + REPEAT + a; + b; + UNTIL c > d; + END_PROGRAM", + ); + // THEN we expect to visit the loop body and the condition + assert_eq!(get_character_range('a', 'd'), visitor.identifiers); +} + +#[test] +fn test_visit_case_statement() { + // GIVEN a source code with a case statement + + let visitor = collect_identifiers( + " + PROGRAM prg + CASE a OF + b: + c; + d; + e, f: + g; + h; + ELSE + i; + j; + END_CASE; + END_PROGRAM", + ); + // THEN we expect to visit the case expression, the case labels, their bodies and the else body + assert_eq!(get_character_range('a', 'j'), visitor.identifiers); +} + +#[test] +fn test_visit_multiplied_statement() { + // GIVEN a source code with a multiplied statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + 3(a+b); + END_PROGRAM", + ); + // THEN we expect to visit the multiplied expression and its subexpressions + assert_eq!(get_character_range('a', 'b'), visitor.identifiers); +} + +#[test] +fn test_visist_array_expressions() { + // GIVEN a source code with array expressions + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a[b]; + c[d,e+f]; + g[h+i][j+k]; + END_PROGRAM", + ); + // THEN we expect to visit the array expressions and the array-accessor expressions + assert_eq!(get_character_range('a', 'k'), visitor.identifiers); +} + +#[test] +fn test_visit_range_statement() { + // GIVEN a source code with range statements + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a..b; + END_PROGRAM", + ); + // THEN we expect to visit the start and end expressions of the range + assert_eq!(get_character_range('a', 'b'), visitor.identifiers); +} + +#[test] +fn test_visit_assignment_expressions() { + // GIVEN a source code with assignment expressions + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a := b; + c => d; + e =>; + END_PROGRAM", + ); + // THEN we expect to visit the left and right side of the assignment expressions + assert_eq!(get_character_range('a', 'e'), visitor.identifiers); +} + +#[test] +fn test_visit_direct_access_statement_expressions() { + // GIVEN a source code with direct access expressions + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + %IW1.2.3; + %MD4; + END_PROGRAM", + ); + // THEN we expect to visit all segments of the direct access + assert_eq!(get_character_range('1', '4'), visitor.identifiers); +} + +#[test] +fn test_visit_call_statements() { + // GIVEN a source code with call statements + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a(); + b(c,d); + e(f:=(g), h=>i); + END_PROGRAM", + ); + // THEN we expect to visit the function name and all arguments + assert_eq!(get_character_range('a', 'i'), visitor.identifiers); +} + +#[test] +fn test_visit_return_statement() { + // GIVEN a source code with a return statement + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + FUNCTION prg : INT + RETURN a + b; + END_PROGRAM", + ); + // THEN we expect to visit the return expression + assert_eq!(get_character_range('a', 'b'), visitor.identifiers); +} + +#[test] +fn test_visit_into_var_global() { + // GIVEN a source code with a var_global section + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + VAR_GLOBAL + a : INT := c; + c : INT := d; + END_VAR", + ); + // THEN we expect to visit all initializers (variable names are no AstStatements!) + assert_eq!(get_character_range('c', 'd'), visitor.identifiers); +} + +#[test] +fn test_visit_data_type_declaration() { + // GIVEN a visitor that collects variables, enum elements and range expressions + struct FieldCollector { + fields: Vec, + } + + // This is a simple visitor that collects all field names in a datatype + impl AstVisitor for FieldCollector { + fn visit_variable(&mut self, variable: &plc_ast::ast::Variable) { + self.fields.push(variable.name.clone()); + variable.walk(self); + } + + fn visit_enum_element(&mut self, element: &plc_ast::ast::AstNode) { + if let Some(name) = element.get_flat_reference_name() { + self.fields.push(name.to_string()); + } + element.walk(self); + } + + fn visit_range_statement( + &mut self, + stmt: &plc_ast::ast::RangeStatement, + _node: &plc_ast::ast::AstNode, + ) { + if let Some((start, end)) = + stmt.start.get_flat_reference_name().zip(stmt.end.get_flat_reference_name()) + { + self.fields.push(start.to_string()); + self.fields.push(end.to_string()); + } + stmt.walk(self); + } + } + let mut visitor = FieldCollector { fields: vec![] }; + // WHEN we visit a source code with a complex datatype + visit( + " + TYPE myStruct: STRUCT + a, b, c: DINT; + s: STRING; + e: (enum1, enum2, enum3); + END_STRUCT; + END_TYPE + + TYPE MyEnum: (myEnum1, myEnum2, myEnum3); + END_TYPE + + TYPE MySubRange: INT(max..min); END_TYPE + + TYPE MyArray: ARRAY[start..end] OF INT; END_TYPE + ", + &mut visitor, + ); + + visitor.fields.sort(); + // THEN we expect to visit all fields, enum elements and range expressions + assert_eq!( + vec![ + "a", "b", "c", "e", "end", "enum1", "enum2", "enum3", "max", "min", "myEnum1", "myEnum2", + "myEnum3", "s", "start" + ], + visitor.fields + ); +} + +#[test] +fn test_count_assignments() { + // GIVEN a visitor that counts assignments + struct AssignmentCounter { + count: usize, + } + + impl AstVisitor for AssignmentCounter { + fn visit_assignment(&mut self, stmt: &plc_ast::ast::Assignment, _node: &plc_ast::ast::AstNode) { + self.count += 1; + stmt.walk(self) + } + + fn visit_output_assignment( + &mut self, + stmt: &plc_ast::ast::Assignment, + _node: &plc_ast::ast::AstNode, + ) { + self.count += 1; + stmt.walk(self) + } + } + + let id_provider = IdProvider::default(); + let (compilation_unit, _) = parser::parse( + lexer::lex_with_ids( + " + PROGRAM prg + a := b; + c => d; + e := f; + foo(a := baz(x := 2, z => 3)); + END_PROGRAM", + id_provider.clone(), + SourceLocationFactory::internal(""), + ), + LinkageType::Internal, + "test.st", + ); + + let mut visitor = AssignmentCounter { count: 0 }; + // WHEN we visit a source code with assignments + for st in &compilation_unit.implementations[0].statements { + visitor.visit(st); + } + // THEN we expect to visit all assignments + assert_eq!(6, visitor.count); +} + +#[test] +fn test_visit_datatype_initializers_statement() { + // GIVEN a source code with datatype initializers + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + TYPE MyStruct: STRUCT + field1: DINT := a; + field2: DINT := (b + c); + field3: ARRAY[1..3] OF DINT := 4(d); + field4: ARRAY[4..7] OF DINT := (e, f, g, h); + field5: (i := j, k := l) := m; + + END_STRUCT + END_TYPE", + ); + // THEN we expect to visit all initializers and enum elements + let mut expected = ["1", "3", "4", "7"].iter().map(|c| c.to_string()).collect::>(); + + expected.extend(get_character_range('a', 'm')); + assert_eq!(expected, visitor.identifiers); +} + +#[test] +fn test_visit_array_declaration_statement() { + // GIVEN a source code with array declarations + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + TYPE MyArray: ARRAY[(a+b)..(c+d)] OF INT; END_TYPE", + ); + // THEN we expect to visit the start and end expressions of the array + assert_eq!(get_character_range('a', 'd'), visitor.identifiers); +} + +#[test] +fn test_visit_qualified_expressions() { + // GIVEN a source code with qualified expressions + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + a.b; + c.d^.e; + f.g[h].i; + END_PROGRAM", + ); + // THEN we expect to visit all segments in the qualified expressions + assert_eq!(get_character_range('a', 'i'), visitor.identifiers); +} + +#[test] +fn test_visit_variable_block() { + // GIVEN a source code with a variable block + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + VAR_INPUT + a : INT := X; + END_VAR + VAR_OUTPUT + b : INT := Y; + END_VAR + VAR CONSTANT + c : INT + END_VAR + END_PROGRAM", + ); + // THEN we expect to visit all variables and their initializers + assert_eq!(get_character_range('X', 'Y'), visitor.identifiers); +} + +#[test] +fn test_visit_continue_exit() { + // THIS test is mainly here to cover the default visit implementation of Continue, Exit and EmptyStatement + let visitor = collect_identifiers( + " + PROGRAM prg + CONTINUE; + EXIT; + ; + END_PROGRAM", + ); + assert_eq!(0, visitor.identifiers.len()); +} + +#[test] +fn test_visit_default_value() { + // GIVEN a Visitor that visits default values + struct DefaultValueCollector { + visited: bool, + } + + // This is a simple visitor that collects all field names in a datatype + impl AstVisitor for DefaultValueCollector { + fn visit_default_value(&mut self, _stmt: &plc_ast::ast::DefaultValue, _node: &plc_ast::ast::AstNode) { + self.visited = true; + } + } + + let mut visitor = DefaultValueCollector { visited: false }; + // WHEN we visit a source code with a default value + visit( + " + VAR_GLOBAL CONSTANT + a : INT; + END_VAR + ", + &mut visitor, + ); + // THEN we expect to visit the default value + assert!(visitor.visited); +} + +#[test] +fn test_visit_direct_access() { + // GIVEN a Visitor that visits direct accesses + struct Visited { + visited: bool, + } + + impl AstVisitor for Visited { + fn visit_direct_access(&mut self, _stmt: &plc_ast::ast::DirectAccess, _node: &plc_ast::ast::AstNode) { + self.visited = true; + } + } + + let mut visitor = Visited { visited: false }; + // WHEN we visit a source code with a direct access + visit( + " + PROGRAM prg + x.1; + ", + &mut visitor, + ); + // THEN we expect to visit the direct access + assert!(visitor.visited); + + let v = collect_identifiers( + " + PROGRAM prg + x.1; + ", + ); + assert_eq!(vec!["1", "x"], v.identifiers); +} + +#[test] +fn test_invalid_case_condition() { + // this tests ensures that we visit "invalid" statements. (see parser's behavior in parse_statement) + struct Visited { + visited: bool, + } + + impl AstVisitor for Visited { + fn visit_case_condition(&mut self, _child: &plc_ast::ast::AstNode, _node: &plc_ast::ast::AstNode) { + self.visited = true; + } + } + + let mut visitor = Visited { visited: false }; + + visit( + " + PROGRAM prg + x: + ", + &mut visitor, + ); + assert!(visitor.visited); +} + +#[test] +fn test_visit_string_declaration() { + // GIVEN a source code with a string declaration + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + VAR + str: STRING(X); + END_VAR + ", + ); + // THEN we expect to visit the string length + assert_eq!(vec!["X"], visitor.identifiers); +} + +#[test] +fn test_visit_pointer_declaration() { + // GIVEN a source code with a pointer declaration + // WHEN we visit all nodes in the AST + let visitor = collect_identifiers( + " + PROGRAM prg + VAR + str: POINTER TO ARRAY[a..b] OF INT := c; + END_VAR + ", + ); + // THEN we expect to visit the pointer type and the initializer + assert_eq!(get_character_range('a', 'c'), visitor.identifiers); +} diff --git a/src/resolver.rs b/src/resolver.rs index 3e6a29dcf8..6bf25261a2 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -11,9 +11,8 @@ use std::hash::Hash; use plc_ast::{ ast::{ self, flatten_expression_list, Assignment, AstFactory, AstId, AstNode, AstStatement, - BinaryExpression, CastStatement, CompilationUnit, DataType, DataTypeDeclaration, DirectAccessType, - JumpStatement, Operator, Pou, ReferenceAccess, ReferenceExpr, TypeNature, UserTypeDeclaration, - Variable, + BinaryExpression, CompilationUnit, DataType, DataTypeDeclaration, DirectAccessType, JumpStatement, + Operator, Pou, ReferenceAccess, ReferenceExpr, TypeNature, UserTypeDeclaration, Variable, }, control_statements::{AstControlStatement, ReturnStatement}, literals::{Array, AstLiteral, StringValue}, @@ -1442,57 +1441,6 @@ impl<'i> TypeAnnotator<'i> { AstStatement::CallStatement(..) => { self.visit_call_statement(statement, ctx); } - AstStatement::CastStatement(CastStatement { target, type_name }, ..) => { - //see if this type really exists - let data_type = self.index.find_effective_type_info(type_name); - let statement_to_annotation = if let Some(DataTypeInformation::Enum { name, .. }) = data_type - { - //enum cast - self.visit_statement(&ctx.with_qualifier(name.to_string()), target); - //use the type of the target - let type_name = self.annotation_map.get_type_or_void(target, self.index).get_name(); - vec![(statement, type_name.to_string())] - } else if let Some(t) = data_type { - // special handling for unlucky casted-strings where caste-type does not match the literal encoding - // ´STRING#"abc"´ or ´WSTRING#'abc'´ - match (t, target.as_ref().get_stmt()) { - ( - DataTypeInformation::String { encoding: StringEncoding::Utf8, .. }, - AstStatement::Literal(AstLiteral::String(StringValue { - value, - is_wide: is_wide @ true, - })), - ) - | ( - DataTypeInformation::String { encoding: StringEncoding::Utf16, .. }, - AstStatement::Literal(AstLiteral::String(StringValue { - value, - is_wide: is_wide @ false, - })), - ) => { - // visit the target-statement as if the programmer used the correct quotes to prevent - // a utf16 literal-global-variable that needs to be casted back to utf8 or vice versa - self.visit_statement( - ctx, - &AstNode::new_literal( - AstLiteral::new_string(value.clone(), !is_wide), - target.get_id(), - target.get_location(), - ), - ); - } - _ => {} - } - vec![(statement, t.get_name().to_string()), (target, t.get_name().to_string())] - } else { - //unknown type? what should we do here? - self.visit_statement(ctx, target); - vec![] - }; - for (stmt, annotation) in statement_to_annotation { - self.annotate(stmt, StatementAnnotation::value(annotation)); - } - } AstStatement::ReferenceExpr(data, ..) => { self.visit_reference_expr(&data.access, data.base.as_deref(), statement, ctx); } diff --git a/src/validation/statement.rs b/src/validation/statement.rs index 4e6d8af18c..d995ed440f 100644 --- a/src/validation/statement.rs +++ b/src/validation/statement.rs @@ -58,18 +58,6 @@ pub fn visit_statement( AstStatement::Literal(AstLiteral::Array(Array { elements: Some(elements) })) => { visit_statement(validator, elements.as_ref(), context); } - AstStatement::CastStatement(data) => { - if let AstStatement::Literal(literal) = data.target.get_stmt() { - validate_cast_literal( - validator, - literal, - statement, - &data.type_name, - &statement.get_location(), - context, - ); - } - } AstStatement::MultipliedStatement(data) => { visit_statement(validator, &data.element, context); }