Skip to content

Commit

Permalink
Merge pull request #887 from dfinity/ulan/image-classification-simd
Browse files Browse the repository at this point in the history
Use Wasm SIMD in the image classification example
  • Loading branch information
ulan authored Jun 13, 2024
2 parents 4f44852 + b6ac198 commit 1fd4b2c
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 40 deletions.
83 changes: 69 additions & 14 deletions rust/image-classification/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions rust/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@
This is an ICP smart contract that accepts an image from the user and runs image classification inference.
The smart contract consists of two canisters:

- the backend canister embeds the [the Tract ONNX inference engine](https://github.com/sonos/tract) with [the MobileNet v2-7 model](https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet). It provides a `classify()` endpoint for the frontend code to call.
- the backend canister embeds the [the Tract ONNX inference engine](https://github.com/sonos/tract) with [the MobileNet v2-7 model](https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet).
It provides `classify()` and `classify_query()` endpoints for the frontend code to call.
The former endpoint is used for replicated execution (running on all nodes) whereas the latter runs only on a single node.
- the frontend canister contains the Web assets such as HTML, JS, CSS that are served to the browser.

Note that currently Wasm execution is not optimized for this workload.
A single call executes about 24B instructions (~10s).

This is expected to improve in the future with:

- faster deterministic floating-point operations.
- Wasm SIMD (Single-Instruction Multiple Data).

The ICP mainnet subnets and `dfx` running a replica version older than [463296](https://dashboard.internetcomputer.org/release/463296c0bc82ad5999b70245e5f125c14ba7d090) may fail with an instruction-limit-exceeded error.
This example uses Wasm SIMD instructions that are available in `dfx` version `0.20.2-beta.0` or newer.

# Dependencies

Expand Down Expand Up @@ -45,12 +39,18 @@ Install NodeJS dependencies for the frontend:
npm install
```

Install `wasm-opt`:

```
cargo install wasm-opt
```

# Build

```
dfx start --background
dfx deploy
```

If the deployment is successfull, the it will show the `frontend` URL.
If the deployment is successful, the it will show the `frontend` URL.
Open that URL in browser to interact with the smart contract.
6 changes: 1 addition & 5 deletions rust/image-classification/dfx.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
"package": "backend",
"type": "custom",
"wasm": "target/wasm32-wasi/release/backend-ic.wasm",
"build": [
"cargo build --release --target=wasm32-wasi",
"wasi2ic ./target/wasm32-wasi/release/backend.wasm ./target/wasm32-wasi/release/backend-ic.wasm"
]

"build": [ "bash build.sh" ]
},
"frontend": {
"dependencies": [
Expand Down
2 changes: 1 addition & 1 deletion rust/image-classification/src/backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ prost = "0.11.0"
prost-types = "0.11.0"
image = { version = "0.24", features = ["png"], default-features = false }
serde = { version = "1.0", features = ["derive"] }
tract-onnx = { git = "https://github.com/sonos/tract", version = "=0.21.2-pre" }
tract-onnx = { git = "https://github.com/sonos/tract", rev = "2a2914ac29390cc08963301c9f3d437b52dd321a" }
ic-stable-structures = "0.6"
ic-wasi-polyfill = { git = "https://github.com/wasm-forge/ic-wasi-polyfill", version = "0.3.17" }
1 change: 1 addition & 0 deletions rust/image-classification/src/backend/backend.did
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ type ClassificationResult = variant {

service : {
"classify": (image: blob) -> (ClassificationResult);
"classify_query": (image: blob) -> (ClassificationResult) query;
}
20 changes: 17 additions & 3 deletions rust/image-classification/src/backend/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::cell::RefCell;
use candid::{CandidType, Deserialize};
use ic_stable_structures::{memory_manager::{MemoryId, MemoryManager}, DefaultMemoryImpl};
use ic_stable_structures::{
memory_manager::{MemoryId, MemoryManager},
DefaultMemoryImpl,
};
use std::cell::RefCell;

mod onnx;

Expand Down Expand Up @@ -31,7 +34,7 @@ enum ClassificationResult {
Err(ClassificationError),
}

#[ic_cdk::update]
#[ic_cdk::query]
fn classify(image: Vec<u8>) -> ClassificationResult {
let result = match onnx::classify(image) {
Ok(result) => ClassificationResult::Ok(result),
Expand All @@ -42,6 +45,17 @@ fn classify(image: Vec<u8>) -> ClassificationResult {
result
}

#[ic_cdk::query]
fn classify_query(image: Vec<u8>) -> ClassificationResult {
let result = match onnx::classify(image) {
Ok(result) => ClassificationResult::Ok(result),
Err(err) => ClassificationResult::Err(ClassificationError {
message: err.to_string(),
}),
};
result
}

#[ic_cdk::init]
fn init() {
let wasi_memory = MEMORY_MANAGER.with(|m| m.borrow().get(WASI_MEMORY_ID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ type ClassificationResult = variant {

service : {
"classify": (image: blob) -> (ClassificationResult);
"classify_query": (image: blob) -> (ClassificationResult) query;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export type ClassificationResult = { 'Ok' : Array<Classification> } |
{ 'Err' : ClassificationError };
export interface _SERVICE {
'classify' : ActorMethod<[Uint8Array | number[]], ClassificationResult>,
'classify_query' : ActorMethod<[Uint8Array | number[]], ClassificationResult>,
}
export declare const idlFactory: IDL.InterfaceFactory;
export declare const init: (args: { IDL: typeof IDL }) => IDL.Type[];
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ export const idlFactory = ({ IDL }) => {
});
return IDL.Service({
'classify' : IDL.Func([IDL.Vec(IDL.Nat8)], [ClassificationResult], []),
'classify_query' : IDL.Func(
[IDL.Vec(IDL.Nat8)],
[ClassificationResult],
['query'],
),
});
};
export const init = ({ IDL }) => { return []; };
9 changes: 4 additions & 5 deletions rust/image-classification/src/frontend/assets/main.css
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ textarea {
flex-flow: row;
justify-content: left;
align-items: center;
margin-bottom: 20px;
}

.toggle-switch {
Expand All @@ -71,7 +72,6 @@ textarea {
bottom: 0;
background-color: #ccc;
border-radius: 18px;
transition: 0.4s;
}

.slider:before {
Expand All @@ -83,7 +83,6 @@ textarea {
bottom: 1px;
background-color: white;
border-radius: 50%;
transition: 0.4s;
}

input:checked+.slider {
Expand Down Expand Up @@ -129,15 +128,15 @@ li {

@keyframes astrodance {
0% {
transform: translate(-50%,-50%) rotate(-20deg)
transform: translate(-50%, -50%) rotate(-20deg)
}

50% {
transform: translate(-50%,-50%) rotate(10deg)
transform: translate(-50%, -50%) rotate(10deg)
}

to {
transform: translate(-50%,-50%) rotate(-20deg)
transform: translate(-50%, -50%) rotate(-20deg)
}
}

Expand Down
9 changes: 9 additions & 0 deletions rust/image-classification/src/frontend/src/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ <h1>ICP image classification</h1>
<input id="file" class="file" name="file" type="file" accept="image/png, image/jpeg" />
<div id="container">
<div id="message"></div>
<div class="option invisible" id="replicated_option">
<label>
<div class="toggle-switch">
<input type="checkbox" id="replicated">
<span class="slider"></span>
</div>
&nbsp; replicated execution
</label>
</div>
<img id="loader" src="loader.svg" class="loader invisible" />
<button id="classify" class="clean-button invisible" disabled>Go!</button>
</div>
Expand Down
Loading

0 comments on commit 1fd4b2c

Please sign in to comment.