Skip to content

Commit

Permalink
Downstream only
Browse files Browse the repository at this point in the history
  • Loading branch information
alexroan committed Aug 14, 2024
1 parent 500daa2 commit b22a6ef
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 557 deletions.
563 changes: 280 additions & 283 deletions aderyn_core/src/context/investigator/callgraph_tests.rs

Large diffs are not rendered by default.

186 changes: 23 additions & 163 deletions aderyn_core/src/context/investigator/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ use crate::{

use super::StandardInvestigatorVisitor;

#[derive(PartialEq)]
pub enum StandardInvestigationStyle {
/// Picks the regular call graph (forward)
Downstream,

/// Picks the reverse call graph
Upstream,

/// Picks both the call graphs (choose this if upstream side effects also need to be tracked)
BothWays,
}

pub struct StandardInvestigator {
/// Ad-hoc Nodes that we would like to explore downstream from.
pub entry_points: Vec<NodeID>,
Expand All @@ -38,31 +26,16 @@ pub struct StandardInvestigator {
/// and only consists of [`crate::ast::FunctionDefinition`] and [`crate::ast::ModifierDefinition`]
/// These are nodes that are the *actual* starting points for traversal in the graph
pub forward_surface_points: Vec<NodeID>,

/// Same as the forward one, but acts on reverse graph.
pub backward_surface_points: Vec<NodeID>,

/// Decides what graph type to chose from [`WorkspaceContext`]
pub investigation_style: StandardInvestigationStyle,
}

#[derive(PartialEq, Clone, Copy)]
enum CurrentDFSVector {
Forward, // Going downstream
Backward, // Going upstream
UpstreamSideEffect, // Going downstream from upstream nodes
}

impl StandardInvestigator {
/// Creates a [`StandardInvestigator`] by exploring paths from given nodes. This is the starting point.
pub fn for_specific_nodes(
context: &WorkspaceContext,
nodes: &[&ASTNode],
investigation_style: StandardInvestigationStyle,
) -> super::Result<StandardInvestigator> {
let mut entry_points = vec![];
let mut forward_surface_points = vec![];
let mut backward_surface_points = vec![];

// Construct entry points
for &node in nodes {
Expand All @@ -89,42 +62,17 @@ impl StandardInvestigator {
}
}

// Construct backward surface points
for &node in nodes {
if node.node_type() == NodeType::FunctionDefinition
|| node.node_type() == NodeType::ModifierDefinition
{
if let Some(id) = node.id() {
backward_surface_points.push(id);
}
} else {
let parent_surface_point = node
.closest_ancestor_of_type(context, NodeType::FunctionDefinition)
.or_else(|| {
node.closest_ancestor_of_type(context, NodeType::ModifierDefinition)
});
if let Some(parent_surface_point) = parent_surface_point {
if let Some(parent_surface_point_id) = parent_surface_point.id() {
backward_surface_points.push(parent_surface_point_id);
}
}
}
}

Ok(StandardInvestigator {
entry_points,
forward_surface_points,
backward_surface_points,
investigation_style,
})
}

pub fn new(
context: &WorkspaceContext,
nodes: &[&ASTNode],
investigation_style: StandardInvestigationStyle,
) -> super::Result<StandardInvestigator> {
Self::for_specific_nodes(context, nodes, investigation_style)
Self::for_specific_nodes(context, nodes)
}

/// Visit the entry points and all the plausible function definitions and modifier definitions that
Expand Down Expand Up @@ -168,68 +116,25 @@ impl StandardInvestigator {

// Keep track of visited node IDs during DFS from surface nodes
let mut visited_downstream = HashSet::new();
let mut visited_upstream = HashSet::new();
let mut visited_upstream_side_effects = HashSet::new();

// Now decide, which points to visit upstream or downstream
if self.investigation_style == StandardInvestigationStyle::BothWays
|| self.investigation_style == StandardInvestigationStyle::Downstream
{
// Visit the subgraph starting from surface points
for surface_point_id in &self.forward_surface_points {
self.dfs_and_visit_subgraph(
*surface_point_id,
&mut visited_downstream,
context,
forward_callgraph,
visitor,
CurrentDFSVector::Forward,
None,
)?;
}
}

if self.investigation_style == StandardInvestigationStyle::BothWays
|| self.investigation_style == StandardInvestigationStyle::Upstream
{
// Visit the subgraph starting from surface points
for surface_point_id in &self.backward_surface_points {
self.dfs_and_visit_subgraph(
*surface_point_id,
&mut visited_upstream,
context,
reverse_callgraph,
visitor,
CurrentDFSVector::Backward,
None,
)?;
}
// Visit the subgraph starting from surface points
for surface_point_id in &self.forward_surface_points {
self.dfs_and_visit_subgraph(
*surface_point_id,
&mut visited_downstream,
context,
forward_callgraph,
visitor,
None,
)?;
}

// Collect already visited nodes so that we don't repeat visit calls on them
// while traversing through side effect nodes.
let mut blacklisted = HashSet::new();
let mut blacklisted: HashSet<i64> = HashSet::new();
blacklisted.extend(visited_downstream.iter());
blacklisted.extend(visited_upstream.iter());
blacklisted.extend(self.entry_points.iter());

if self.investigation_style == StandardInvestigationStyle::BothWays {
// Visit the subgraph from the upstream points (go downstream in forward graph)
// but do not re-visit the upstream nodes or the downstream nodes again

for surface_point_id in &visited_upstream {
self.dfs_and_visit_subgraph(
*surface_point_id,
&mut visited_upstream_side_effects,
context,
forward_callgraph,
visitor,
CurrentDFSVector::UpstreamSideEffect,
Some(&blacklisted),
)?;
}
}

Ok(())
}

Expand All @@ -241,7 +146,6 @@ impl StandardInvestigator {
context: &WorkspaceContext,
callgraph: &WorkspaceCallGraph,
visitor: &mut T,
current_investigation_direction: CurrentDFSVector,
blacklist: Option<&HashSet<NodeID>>,
) -> super::Result<()>
where
Expand All @@ -255,20 +159,10 @@ impl StandardInvestigator {

if let Some(blacklist) = blacklist {
if !blacklist.contains(&node_id) {
self.make_relevant_visit_call(
context,
node_id,
visitor,
current_investigation_direction,
)?;
self.make_relevant_visit_call(context, node_id, visitor)?;
}
} else {
self.make_relevant_visit_call(
context,
node_id,
visitor,
current_investigation_direction,
)?;
self.make_relevant_visit_call(context, node_id, visitor)?;
}

if let Some(pointing_to) = callgraph.graph.get(&node_id) {
Expand All @@ -279,7 +173,6 @@ impl StandardInvestigator {
context,
callgraph,
visitor,
current_investigation_direction,
blacklist,
)?;
}
Expand All @@ -292,7 +185,6 @@ impl StandardInvestigator {
context: &WorkspaceContext,
node_id: NodeID,
visitor: &mut T,
current_investigation_direction: CurrentDFSVector,
) -> super::Result<()>
where
T: StandardInvestigatorVisitor,
Expand All @@ -304,47 +196,15 @@ impl StandardInvestigator {
return Ok(());
}

match current_investigation_direction {
CurrentDFSVector::Forward => {
if let ASTNode::FunctionDefinition(function) = node {
visitor
.visit_downstream_function_definition(function)
.map_err(|_| super::Error::DownstreamFunctionDefinitionVisitError)?;
}
if let ASTNode::ModifierDefinition(modifier) = node {
visitor
.visit_downstream_modifier_definition(modifier)
.map_err(|_| super::Error::DownstreamModifierDefinitionVisitError)?;
}
}
CurrentDFSVector::Backward => {
if let ASTNode::FunctionDefinition(function) = node {
visitor
.visit_upstream_function_definition(function)
.map_err(|_| super::Error::UpstreamFunctionDefinitionVisitError)?;
}
if let ASTNode::ModifierDefinition(modifier) = node {
visitor
.visit_upstream_modifier_definition(modifier)
.map_err(|_| super::Error::UpstreamModifierDefinitionVisitError)?;
}
}
CurrentDFSVector::UpstreamSideEffect => {
if let ASTNode::FunctionDefinition(function) = node {
visitor
.visit_upstream_side_effect_function_definition(function)
.map_err(|_| {
super::Error::UpstreamSideEffectFunctionDefinitionVisitError
})?;
}
if let ASTNode::ModifierDefinition(modifier) = node {
visitor
.visit_upstream_side_effect_modifier_definition(modifier)
.map_err(|_| {
super::Error::UpstreamSideEffectModifierDefinitionVisitError
})?;
}
}
if let ASTNode::FunctionDefinition(function) = node {
visitor
.visit_downstream_function_definition(function)
.map_err(|_| super::Error::DownstreamFunctionDefinitionVisitError)?;
}
if let ASTNode::ModifierDefinition(modifier) = node {
visitor
.visit_downstream_modifier_definition(modifier)
.map_err(|_| super::Error::DownstreamModifierDefinitionVisitError)?;
}
}

Expand Down
1 change: 0 additions & 1 deletion aderyn_core/src/detect/high/contract_locks_ether.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ mod contract_eth_helper {
let investigator = StandardInvestigator::new(
context,
funcs.iter().collect::<Vec<_>>().as_slice(),
StandardInvestigationStyle::Downstream,
)
.ok()?;

Expand Down
10 changes: 2 additions & 8 deletions aderyn_core/src/detect/high/delegate_call_no_address_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use std::error::Error;
use crate::ast::NodeID;

use crate::capture;
use crate::context::investigator::{
StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor,
};
use crate::context::investigator::{StandardInvestigator, StandardInvestigatorVisitor};
use crate::detect::detector::IssueDetectorNamePool;
use crate::detect::helpers;
use crate::{
Expand All @@ -30,11 +28,7 @@ impl IssueDetector for DelegateCallOnUncheckedAddressDetector {
has_delegate_call_on_non_state_variable_address: false,
context,
};
let investigator = StandardInvestigator::new(
context,
&[&(func.into())],
StandardInvestigationStyle::Downstream,
)?;
let investigator = StandardInvestigator::new(context, &[&(func.into())])?;
investigator.investigate(context, &mut tracker)?;

if tracker.has_delegate_call_on_non_state_variable_address
Expand Down
8 changes: 2 additions & 6 deletions aderyn_core/src/detect/high/msg_value_in_loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use crate::ast::{ASTNode, Expression, NodeID};

use crate::capture;
use crate::context::browser::ExtractMemberAccesses;
use crate::context::investigator::{
StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor,
};
use crate::context::investigator::{StandardInvestigator, StandardInvestigatorVisitor};
use crate::detect::detector::IssueDetectorNamePool;
use crate::{
context::workspace_context::WorkspaceContext,
Expand Down Expand Up @@ -72,9 +70,7 @@ impl IssueDetector for MsgValueUsedInLoopDetector {

fn uses_msg_value(context: &WorkspaceContext, ast_node: &ASTNode) -> Option<bool> {
let mut tracker = MsgValueTracker::default();
let investigator =
StandardInvestigator::new(context, &[ast_node], StandardInvestigationStyle::Downstream)
.ok()?;
let investigator = StandardInvestigator::new(context, &[ast_node]).ok()?;

investigator.investigate(context, &mut tracker).ok()?;
Some(tracker.has_msg_value)
Expand Down
10 changes: 2 additions & 8 deletions aderyn_core/src/detect/high/out_of_order_retryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ use crate::ast::{Expression, MemberAccess, NodeID};

use crate::capture;
use crate::context::browser::ExtractFunctionCalls;
use crate::context::investigator::{
StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor,
};
use crate::context::investigator::{StandardInvestigator, StandardInvestigatorVisitor};
use crate::detect::detector::IssueDetectorNamePool;
use crate::detect::helpers;
use crate::{
Expand All @@ -29,11 +27,7 @@ impl IssueDetector for OutOfOrderRetryableDetector {
let mut tracker = OutOfOrderRetryableTracker {
number_of_retry_calls: 0,
};
let investigator = StandardInvestigator::new(
context,
&[&(func.into())],
StandardInvestigationStyle::Downstream,
)?;
let investigator = StandardInvestigator::new(context, &[&(func.into())])?;
investigator.investigate(context, &mut tracker)?;
if tracker.number_of_retry_calls >= 2 {
capture!(self, context, func);
Expand Down
10 changes: 2 additions & 8 deletions aderyn_core/src/detect/high/send_ether_no_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use std::error::Error;
use crate::ast::NodeID;

use crate::capture;
use crate::context::investigator::{
StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor,
};
use crate::context::investigator::{StandardInvestigator, StandardInvestigatorVisitor};
use crate::context::workspace_context::ASTNode;
use crate::detect::detector::IssueDetectorNamePool;
use crate::detect::helpers;
Expand All @@ -27,11 +25,7 @@ impl IssueDetector for SendEtherNoChecksDetector {
fn detect(&mut self, context: &WorkspaceContext) -> Result<bool, Box<dyn Error>> {
for func in helpers::get_implemented_external_and_public_functions(context) {
let mut tracker = MsgSenderAndCallWithValueTracker::default();
let investigator = StandardInvestigator::new(
context,
&[&(func.into())],
StandardInvestigationStyle::Downstream,
)?;
let investigator = StandardInvestigator::new(context, &[&(func.into())])?;
investigator.investigate(context, &mut tracker)?;

if tracker.sends_native_eth && !tracker.has_msg_sender_checks {
Expand Down
Loading

0 comments on commit b22a6ef

Please sign in to comment.