chore(website): add back on_link_find_callback

j-mendez committed Aug 29, 2023
1 parent 298f95b commit 9cefa7b

Showing 2 changed files with 82 additions and 26 deletions.
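The restored hook is stored as a plain `fn` pointer on `Website` and fires once for every discovered link, receiving the link (plus the page HTML on the scrape path) and returning the possibly rewritten pair that the crawler records. A minimal sketch of a conforming callback, assuming `CaseInsensitiveString` is re-exported at the crate root; the function name is illustrative:

```rust
use spider::CaseInsensitiveString;

// Matches the field type added in this commit:
// fn(CaseInsensitiveString, Option<String>) -> (CaseInsensitiveString, Option<String>)
fn log_link(
    link: CaseInsensitiveString,
    html: Option<String>,
) -> (CaseInsensitiveString, Option<String>) {
    // Observe the link; whatever is returned is what gets recorded as visited.
    println!("link found: {}", link.inner());
    (link, html)
}
```

Because the field is a function pointer rather than a boxed closure, only free functions or non-capturing closures can be assigned to it.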
17 changes: 11 additions & 6 deletions spider/README.md
@@ -16,7 +16,7 @@ This is a basic async example crawling a web page, add spider to your `Cargo.toml`:

```toml
[dependencies]
spider = "1.37.0"
spider = "1.36.5"
```

And then the code:
@@ -52,6 +52,7 @@ website.configuration.delay = 0; // Defaults to 0 ms due to concurrency handling
website.configuration.request_timeout = None; // Defaults to 15000 ms
website.configuration.http2_prior_knowledge = false; // Enable if you know the webserver supports http2
website.configuration.user_agent = Some("myapp/version".into()); // Defaults to using a random agent
+website.on_link_find_callback = Some(|s, html| { println!("link target: {}", s); (s, html)}); // Callback to run on each link find
website.configuration.blacklist_url.get_or_insert(Default::default()).push("https://choosealicense.com/licenses/".into());
website.configuration.proxies.get_or_insert(Default::default()).push("socks5://10.1.1.1:12345".into()); // Defaults to none - proxy list.

@@ -71,6 +72,10 @@ website
.with_request_timeout(None)
.with_http2_prior_knowledge(false)
.with_user_agent(Some("myapp/version".into()))
+    .with_on_link_find_callback(Some(|link, html| {
+        println!("link target: {}", link.inner());
+        (link, html)
+    }))
.with_headers(None)
.with_blacklist_url(Some(Vec::from(["https://choosealicense.com/licenses/".into()])))
.with_proxies(None);
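
Putting the restored builder hook together end to end; a minimal sketch, assuming `spider` re-exports `tokio` and that the `with_` methods chain on `&mut Website` as the snippet above suggests:

```rust
use spider::tokio;
use spider::website::Website;

#[tokio::main]
async fn main() {
    let mut website = Website::new("https://choosealicense.com");
    website
        .with_user_agent(Some("myapp/version".into()))
        .with_on_link_find_callback(Some(|link, html| {
            // Non-capturing closure, so it coerces to the `fn` pointer field.
            println!("link target: {}", link.inner());
            (link, html)
        }));
    website.crawl().await;
}
```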
@@ -82,7 +87,7 @@ We have a couple optional feature flags. Regex blacklisting, jemalloc backend, gl…

```toml
[dependencies]
spider = { version = "1.37.0", features = ["regex", "ua_generator"] }
spider = { version = "1.36.5", features = ["regex", "ua_generator"] }
```

1. `ua_generator`: Enables auto generating a random real User-Agent.
@@ -104,7 +109,7 @@ Move processing to a worker, drastically increases performance even if worker is…

```toml
[dependencies]
-spider = { version = "1.37.0", features = ["decentralized"] }
+spider = { version = "1.36.5", features = ["decentralized"] }
```

```sh
@@ -125,7 +130,7 @@ Use the subscribe method to get a broadcast channel.

```toml
[dependencies]
-spider = { version = "1.37.0", features = ["sync"] }
+spider = { version = "1.36.5", features = ["sync"] }
```
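
A short sketch of the subscribe pattern, assuming `subscribe(capacity)` returns an `Option` around a `broadcast::Receiver<Page>` derived from the `channel` pair added to `Website` in this commit:

```rust
use spider::tokio;
use spider::website::Website;

#[tokio::main]
async fn main() {
    let mut website = Website::new("https://choosealicense.com");
    // Subscribe before crawling so no pages are missed; capacity 16 is arbitrary.
    let mut rx = website.subscribe(16).unwrap();

    let handle = tokio::spawn(async move {
        while let Ok(page) = rx.recv().await {
            println!("streamed: {}", page.get_url());
        }
    });

    website.crawl().await;
    handle.abort();
}
```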

```rust,no_run
@@ -155,7 +160,7 @@ Allow regex for blacklisting routes

```toml
[dependencies]
-spider = { version = "1.37.0", features = ["regex"] }
+spider = { version = "1.36.5", features = ["regex"] }
```
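
A sketch of regex blacklisting, assuming that with the `regex` feature the entries pushed to `configuration.blacklist_url` are matched as patterns rather than literal prefixes:

```rust
use spider::tokio;
use spider::website::Website;

#[tokio::main]
async fn main() {
    let mut website = Website::new("https://choosealicense.com");
    // Skip any route whose URL matches the pattern.
    website
        .configuration
        .blacklist_url
        .get_or_insert(Default::default())
        .push("/licenses/".into());
    website.crawl().await;
}
```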

```rust,no_run
@@ -182,7 +187,7 @@ If you are performing large workloads you may need to control the crawler by ena…

```toml
[dependencies]
-spider = { version = "1.37.0", features = ["control"] }
+spider = { version = "1.36.5", features = ["control"] }
```

```rust
91 changes: 71 additions & 20 deletions spider/src/website.rs
@@ -109,6 +109,10 @@ pub struct Website {
domain: Box<CaseInsensitiveString>,
/// the domain url parsed
domain_parsed: Option<Box<Url>>,
+    /// callback when a link is found.
+    pub on_link_find_callback: Option<
+        fn(CaseInsensitiveString, Option<String>) -> (CaseInsensitiveString, Option<String>),
+    >,
/// subscribe and broadcast changes
channel: Option<Arc<(broadcast::Sender<Page>, broadcast::Receiver<Page>)>>,
}
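
Since `links_visited` stores whatever the callback returns (see the `insert(match ...)` hunks below), the hook can rewrite links before they are recorded. A sketch of a normalizing callback, assuming `CaseInsensitiveString` implements `From<&str>` as the `.into()` calls in the README suggest:

```rust
use spider::CaseInsensitiveString;

fn strip_fragment(
    link: CaseInsensitiveString,
    html: Option<String>,
) -> (CaseInsensitiveString, Option<String>) {
    // Record links without their #fragment so URL variants dedupe together.
    let base = link.inner().split('#').next().unwrap_or_default();
    (base.into(), html)
}
```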
@@ -126,6 +130,7 @@ impl Website {
Ok(u) => Some(Box::new(crate::page::convert_abs_path(&u, "/"))),
_ => None,
},
+            on_link_find_callback: None,
channel: None,
}
}
@@ -510,7 +515,14 @@
{
let page = Page::new(&self.domain.inner(), &client).await;

-self.links_visited.insert(*self.domain.clone());
+self.links_visited.insert(match self.on_link_find_callback {
+    Some(cb) => {
+        let c = cb(*self.domain.clone(), None);
+
+        c.0
+    }
+    _ => *self.domain.clone(),
+});

let links = HashSet::from(page.links(&base).await);

@@ -555,7 +567,14 @@
)
.await;

-self.links_visited.insert(*self.domain.clone());
+self.links_visited.insert(match self.on_link_find_callback {
+    Some(cb) => {
+        let c = cb(*self.domain.to_owned(), None);
+
+        c.0
+    }
+    _ => *self.domain.to_owned(),
+});

match &self.channel {
Some(c) => {
@@ -614,8 +633,12 @@
let u = page.get_url();
let u = if u.is_empty() { link } else { u.into() };

-self.links_visited.insert(u);
+let link_result = match self.on_link_find_callback {
+    Some(cb) => cb(u, None),
+    _ => (u, None),
+};

+self.links_visited.insert(link_result.0);
match &self.channel {
Some(c) => {
match c.0.send(page.clone()) {
@@ -661,8 +684,12 @@
if self.is_allowed_default(&link.inner(), &blacklist_url) {
let page = Page::new(&link.inner(), &client).await;
let u = page.get_url().into();
+let link_result = match self.on_link_find_callback {
+    Some(cb) => cb(u),
+    _ => u,
+};

-self.links_visited.insert(u);
+self.links_visited.insert(link_result);
let page_links = HashSet::from(page.links(&base).await);

links.extend(page_links);
@@ -717,6 +744,7 @@
if selectors.is_some() {
let mut interval = Box::pin(tokio::time::interval(Duration::from_millis(10)));
let throttle = Box::pin(self.get_delay());
+let on_link_find_callback = self.on_link_find_callback;
let shared = Arc::new((
client,
unsafe { selectors.unwrap_unchecked() },
@@ -766,7 +794,12 @@

set.spawn_on(
async move {
let page = Page::new(&link.as_ref(), &shared.0).await;
let link_result = match on_link_find_callback {
Some(cb) => cb(link, None),
_ => (link, None),
};
let page =
Page::new(&link_result.0.as_ref(), &shared.0).await;
let page_links = page.links(&shared.1).await;

match &shared.2 {
@@ -815,7 +848,7 @@
let domain = self.domain.inner().as_str();
let mut interval = Box::pin(tokio::time::interval(Duration::from_millis(10)));
let throttle = Box::pin(self.get_delay());

+let on_link_find_callback = self.on_link_find_callback;
// http worker verify
let http_worker = std::env::var("SPIDER_WORKER")
.unwrap_or_else(|_| "http:".to_string())
@@ -865,7 +898,11 @@

set.spawn_on(
async move {
-let link_results = link.as_ref();
+let link_results = match on_link_find_callback {
+    Some(cb) => cb(link, None),
+    _ => (link, None),
+};
+let link_results = link_results.0.as_ref();
let page = Page::new(
&if http_worker && link_results.starts_with("https") {
link_results
@@ -920,6 +957,7 @@
let selectors = unsafe { selectors.unwrap_unchecked() };
let delay = Box::from(self.configuration.delay);
let delay_enabled = self.configuration.delay > 0;
+let on_link_find_callback = self.on_link_find_callback;
let mut interval = tokio::time::interval(Duration::from_millis(10));

let mut new_links: HashSet<CaseInsensitiveString> = HashSet::new();
@@ -952,7 +990,11 @@
tokio::time::sleep(Duration::from_millis(*delay)).await;
}
let link = link.clone();
let page = Page::new(&link.as_ref(), &client).await;
let link_result = match on_link_find_callback {
Some(cb) => cb(link, None),
_ => (link, None),
};
let page = Page::new(&link_result.0.as_ref(), &client).await;
let page_links = page.links(&selectors).await;
task::yield_now().await;
new_links.extend(page_links);
@@ -993,12 +1035,13 @@
if selectors.is_some() {
self.pages = Some(Box::new(Vec::new()));
let delay = self.configuration.delay;
+let on_link_find_callback = self.on_link_find_callback;
let mut interval = tokio::time::interval(Duration::from_millis(10));
let selectors = Arc::new(unsafe { selectors.unwrap_unchecked() });
let throttle = Duration::from_millis(delay);

let mut links: HashSet<CaseInsensitiveString> = HashSet::from([*self.domain.clone()]);
let mut set: JoinSet<(CaseInsensitiveString, Option<Bytes>)> = JoinSet::new();
let mut set: JoinSet<(CaseInsensitiveString, Page)> = JoinSet::new();

// crawl while links exists
loop {
@@ -1032,10 +1075,20 @@
set.spawn(async move {
drop(permit);
let page = crate::utils::fetch_page_html(&link.as_ref(), &client).await;
let page = build(&link.as_ref(), page);

let (link, _) = match on_link_find_callback {
Some(cb) => {
let c = cb(link, Some(page.get_html()));

c
}
_ => (link, Some(page.get_html())),
};

match &channel {
Some(c) => {
match c.0.send(build(&link.as_ref(), page.clone())) {
match c.0.send(page.clone()) {
_ => (),
};
}
@@ -1055,16 +1108,14 @@
while let Some(res) = set.join_next().await {
match res {
Ok(msg) => {
-if msg.1.is_some() {
-    let page = build(&msg.0.as_ref(), msg.1);
-    let page_links = page.links(&*selectors).await;
-    links.extend(&page_links - &self.links_visited);
-    task::yield_now().await;
-    match self.pages.as_mut() {
-        Some(p) => p.push(page),
-        _ => (),
-    };
-}
+let page = msg.1;
+let page_links = page.links(&*selectors).await;
+links.extend(&page_links - &self.links_visited);
+task::yield_now().await;
+match self.pages.as_mut() {
+    Some(p) => p.push(page.clone()),
+    _ => (),
+};
}
_ => (),
};
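
On the scrape path above, the callback now also receives the fetched body via `page.get_html()`, so it can react to content as well as the URL. A sketch, with the helper name illustrative:

```rust
use spider::CaseInsensitiveString;

fn count_forms(
    link: CaseInsensitiveString,
    html: Option<String>,
) -> (CaseInsensitiveString, Option<String>) {
    if let Some(body) = &html {
        // Cheap substring scan; a real check would parse the HTML.
        println!("{}: {} form tag(s)", link.inner(), body.matches("<form").count());
    }
    (link, html)
}
```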