Skip to content

Commit

Permalink
feat: semantic detection PoC
Browse files Browse the repository at this point in the history
  • Loading branch information
brokad committed Aug 23, 2021
1 parent 127519f commit 0e9e05e
Show file tree
Hide file tree
Showing 15 changed files with 1,393 additions and 12 deletions.
398 changes: 393 additions & 5 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
members = [
"gen",
"core",
"semdet",
"synth",
"dist/playground"
]
Expand Down
2 changes: 1 addition & 1 deletion core/src/graph/string/faker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Default for Locale {
#[derive(Clone, Default, Deserialize, Debug, Serialize, PartialEq, Eq)]
pub struct FakerArgs {
#[serde(default)]
locales: Vec<Locale>,
pub locales: Vec<Locale>,
}

type FakerFunction = for<'r> fn(&'r mut dyn RngCore, &FakerArgs) -> String;
Expand Down
41 changes: 41 additions & 0 deletions semdet/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[package]
name = "semantic-detection"
version = "0.1.0"
edition = "2018"
authors = [
"Damien Broka <[email protected]>",
]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
name = "semantic_detection"
crate-type=["lib", "dylib"]

[features]
default = [ "dummy" ]
train = [ "pyo3" ]
dummy = [ ]
torch = [ "tch" ]

[dependencies.arrow]
version = "5.1.0"

[dependencies.fake]
version = "2.4.1"
features = ["http"]

[dependencies.pyo3]
version = "0.14.2"
optional = true
features = [ "extension-module" ]

[dependencies.tch]
version = "0.5.0"
optional = true

[dependencies.ndarray]
version = "0.15.3"

[dev-dependencies.rand]
version = "0.8.4"
22 changes: 22 additions & 0 deletions semdet/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::env;
use std::fs;
use std::io::Result;
use std::path::{Path, PathBuf};

fn main() -> Result<()> {
let pretrained_path = env::var_os("PRETRAINED")
.map(PathBuf::from)
.unwrap_or_else(|| Path::new("train").join("dummy.tch"));
let target_path = PathBuf::from(env::var_os("OUT_DIR").unwrap()).join("pretrained.tch");
eprintln!(
"attempting to copy pretrained weights:\n\t<- {}\n\t-> {}",
pretrained_path.to_str().unwrap(),
target_path.to_str().unwrap()
);
fs::copy(&pretrained_path, &target_path)?;
println!(
"cargo:rustc-env=PRETRAINED={}",
target_path.to_str().unwrap()
);
Ok(())
}
102 changes: 102 additions & 0 deletions semdet/src/decode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use ndarray::{ArrayView, Ix1};

use std::convert::Infallible;

/// Trait for functions that produce a value from an input [`Array`](ndarray::Array) of prescribed
/// shape.
///
/// The type parameter `D` should probably be a [`Dimension`](ndarray::Dimension) for implementations
/// to be useful.
pub trait Decoder<D> {
type Err: std::error::Error + 'static;

/// The type of values returned.
type Value;

/// Compute and return a [`Self::Value`](Decoder::Value) from the input `tensor`.
///
/// Implementations are allowed to panic if `tensor.shape() != self.shape()`.
fn decode(&self, tensor: ArrayView<f32, D>) -> Result<Self::Value, Self::Err>;

/// The shape that is required of a valid input of this decoder.
fn shape(&self) -> D;
}

impl<'d, D, Dm> Decoder<Dm> for &'d D
where
D: Decoder<Dm>,
{
type Err = D::Err;
type Value = D::Value;

fn decode(&self, tensor: ArrayView<f32, Dm>) -> Result<Self::Value, Self::Err> {
<D as Decoder<Dm>>::decode(self, tensor)
}

fn shape(&self) -> Dm {
<D as Decoder<Dm>>::shape(self)
}
}

pub struct MaxIndexDecoder<S> {
index: Vec<S>,
}

impl<S> MaxIndexDecoder<S> {
/// # Panics
///
/// If `index` is empty.
pub fn from_vec(index: Vec<S>) -> Self {
assert!(
!index.is_empty(),
"passed `index` to `from_values` must not be empty"
);
Self { index }
}
}

impl<S> Decoder<Ix1> for MaxIndexDecoder<S>
where
S: Clone,
{
type Err = Infallible;
type Value = Option<S>;

fn decode(&self, tensor: ArrayView<f32, Ix1>) -> Result<Self::Value, Self::Err> {
let (idx, by) = tensor
.iter()
.enumerate()
.max_by(|(_, l), (_, r)| l.total_cmp(r))
.unwrap();
if *by > (1. / tensor.len() as f32) {
let value = self.index.get(idx).unwrap().clone();
Ok(Some(value))
} else {
Ok(None)
}
}

fn shape(&self) -> Ix1 {
Ix1(self.index.len())
}
}

#[cfg(test)]
pub mod tests {
use super::Decoder;
use super::MaxIndexDecoder;

use ndarray::{Array, Ix1};

#[test]
fn decoder_max_index() {
let decoder = MaxIndexDecoder::from_vec((0..10).collect());

for idx in 0..10 {
let mut input = Array::zeros(Ix1(10));
*input.get_mut(idx).unwrap() = 1.;
let output = decoder.decode(input.view()).unwrap();
assert_eq!(output, Some(idx));
}
}
}
Loading

0 comments on commit 0e9e05e

Please sign in to comment.