Skip to content

Commit

Permalink
tensor + forward pass working
Browse files Browse the repository at this point in the history
  • Loading branch information
Shubhamai committed Apr 23, 2024
1 parent 7ed4a79 commit a3142af
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 30 deletions.
14 changes: 11 additions & 3 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub enum BinaryOp {

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UnaryOp {
Negate,
Negate,
Not, // ! - logical not
}

Expand Down Expand Up @@ -368,8 +368,7 @@ fn expr_bp(lexer: &mut Lexer, min_bp: u8) -> ASTNode {

fn prefix_binding_power(op: Ops) -> ((), u8) {
match op {
| Ops::UnaryOp(UnaryOp::Not)
| Ops::UnaryOp(UnaryOp::Negate) => ((), 15),
Ops::UnaryOp(UnaryOp::Not) | Ops::UnaryOp(UnaryOp::Negate) => ((), 15),
_ => panic!("bad op: {:?}", op),
}
}
Expand Down Expand Up @@ -474,6 +473,9 @@ mod tests {
"(/ (. (. x (relu (. a (b (+ 0 2))) (- 2 1))) (max 0)) 2)"
);

let s = expr("x.relu(a.sigmoid(0+2))");
assert_eq!(s, "(. x (relu (. a (sigmoid (+ 0 2)))))");

let s = expr("a == b");
assert_eq!(s, "(== a b)");

Expand Down Expand Up @@ -512,5 +514,11 @@ mod tests {

let s = parse("a += (c == 4);");
assert_eq!(s, "a = (+ a (== c 4))");

let s = parse("a -= 4;");
assert_eq!(s, "a = (- a 4)");

let s = parse("a *= (5 != 4);");
assert_eq!(s, "a = (* a (!= 5 4))");
}
}
3 changes: 3 additions & 0 deletions src/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub enum OpCode {
OpGetGlobal,
OpSetGlobal,
OpPower,

OpCall,
}

impl std::fmt::Display for OpCode {
Expand All @@ -47,6 +49,7 @@ impl std::fmt::Display for OpCode {
OpCode::OpGetGlobal => write!(f, "OP_GET_GLOBAL"),
OpCode::OpSetGlobal => write!(f, "OP_SET_GLOBAL"),
OpCode::OpPower => write!(f, "OP_POWER"),
OpCode::OpCall => write!(f, "OP_CALL"),
}
}
}
Expand Down
26 changes: 24 additions & 2 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
ast::{ASTNode, BinaryOp, Ops, PostfixOp, UnaryOp},
chunk::{Chunk, OpCode, VectorType},
interner::Interner,
tensor::Tensor,
value::ValueType,
};

Expand Down Expand Up @@ -33,7 +34,7 @@ impl Compiler {
ASTNode::Number(n) => {
self.chunk.write(VectorType::Code(OpCode::OpConstant));

let constant = self.chunk.add_constant(ValueType::Number(n));
let constant = self.chunk.add_constant(ValueType::Tensor(Tensor::new(n)));
self.chunk.write(VectorType::Constant(constant));
}
ASTNode::Boolean(b) => {
Expand Down Expand Up @@ -105,7 +106,14 @@ impl Compiler {
Ops::PostfixOp(PostfixOp::STAR_STAR) => {
self.chunk.write(VectorType::Code(OpCode::OpPower))
}
_ => println!("Invalid operator"), // TODO: handle this error
Ops::PostfixOp(PostfixOp::Call) => {
println!("Call");
self.chunk.write(VectorType::Code(OpCode::OpCall));
self.chunk
.write(VectorType::Constant(self.chunk.constants.len() - 1));
// TODO: need for testing for this - a.relu(c.relu()), a.relu().relu()
}
x => println!("Invalid operator {:?}", x),
}
}
ASTNode::Print(expr) => {
Expand Down Expand Up @@ -139,6 +147,20 @@ impl Compiler {
self.chunk.write(VectorType::Code(OpCode::OpSetGlobal));
self.chunk.write(VectorType::Constant(global));
}
ASTNode::Callee(iden, args) => {
println!("Callee");
let global = self
.chunk
.add_constant(ValueType::Identifier(self.interner.intern_string(iden)));
self.chunk.write(VectorType::Constant(global));

// for arg in args {
// self.visit(arg.clone());
// }

// self.chunk.write(VectorType::Code(OpCode::OpCall));
// self.chunk.write(VectorType::Constant(args.len()));
}
_ => println!("Invalid ASTNode"), // TODO: handle this error
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ impl Debug {
chunk::OpCode::OpConstant
| chunk::OpCode::OpDefineGlobal
| chunk::OpCode::OpGetGlobal
| chunk::OpCode::OpSetGlobal,
| chunk::OpCode::OpSetGlobal
| chunk::OpCode::OpCall,
) => {
let constant = self.chunk.code[offset + 1];
match constant {
Expand All @@ -69,7 +70,7 @@ impl Debug {
return offset + 2;
}
chunk::VectorType::Constant(_) => {
return offset +1;
return offset + 1;
}
}
}
Expand Down
22 changes: 17 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod compiler;
mod debug;
mod interner;
mod scanner;
mod tensor;
mod value;
mod vm;

Expand Down Expand Up @@ -63,7 +64,9 @@ fn run_repl() {
let mut lexer = Lexer::new(input.to_string());

let out = Parser::new(&mut lexer).parse();
println!("{:?}", out);
for stmt in out.iter() {
println!("{};", stmt);
}

let mut compiler = compiler::Compiler::new();
let (bytecode, interner) = compiler.compile(out);
Expand All @@ -85,7 +88,10 @@ pub fn run_source(src: &str) -> InterpretResult {
let mut lexer = Lexer::new(src.to_string());

let out = Parser::new(&mut lexer).parse();
println!("{:?}", out);
for stmt in out.iter() {
println!("{:?}", stmt);
}
println!("-------------");

let mut compiler = compiler::Compiler::new();
let (bytecode, interner) = compiler.compile(out);
Expand All @@ -104,7 +110,7 @@ pub fn run_source(src: &str) -> InterpretResult {

#[cfg(test)]
mod tests {
use crate::{run_source, value::ValueType, vm::InterpretResult};
use crate::{run_source, tensor::Tensor, value::ValueType, vm::InterpretResult};

#[test]
fn test_micrograd_example() {
Expand All @@ -115,14 +121,20 @@ mod tests {
let d = a * b + b**3;
c += c + 1;
c += 1 + c + (-a);
print(c == -1);
d += d * 2 + (b + a).relu();
d += 3 * d + (b - a).relu();
let e = c - d;
let f = e**2;
let g = f / 2.0;
g += 10.0 / f;
print(g) // prints 24.7041, the outcome of this forward pass
"#;

let out = run_source(&src);

assert_eq!(
out,
InterpretResult::InterpretOk(vec![ValueType::Boolean(true)])
InterpretResult::InterpretOk(vec![ValueType::Tensor(Tensor::new(24.70408163265306))])
);
}
}
91 changes: 91 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#[derive(Debug, Clone, Copy)]
pub struct Tensor {
data: f64,
pub grad: f64,
pub shape: [usize; 1],
}

// display tensor
impl std::fmt::Display for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}, grad: {}", self.data, self.grad)
}
}

impl Tensor {
pub fn new(data: f64) -> Self {
Tensor {
data,
grad: 0.0,
shape: [0],
}
}

pub fn powf(&self, exp: Tensor) -> Self {
Tensor::new(self.data.powf(exp.data))
}

pub fn relu(&self) -> Self {
Tensor::new(self.data.max(0.0))
}

pub fn backward(&mut self) {
Tensor::new(self.grad);
}

pub fn grad(&mut self) {
self.grad = 1.;
}
}

impl std::ops::Add for Tensor {
type Output = Self;

fn add(self, other: Self) -> Self {
Tensor::new(self.data + other.data)
}
}

impl std::ops::Sub for Tensor {
type Output = Self;

fn sub(self, other: Self) -> Self {
Tensor::new(self.data - other.data)
}
}

impl std::ops::Mul for Tensor {
type Output = Self;

fn mul(self, other: Self) -> Self {
Tensor::new(self.data * other.data)
}
}

impl std::ops::Div for Tensor {
type Output = Self;

fn div(self, other: Self) -> Self {
Tensor::new(self.data / other.data)
}
}

impl std::ops::Neg for Tensor {
type Output = Self;

fn neg(self) -> Self {
Tensor::new(-self.data)
}
}

impl PartialEq for Tensor {
fn eq(&self, other: &Self) -> bool {
self.data == other.data
}
}

impl PartialOrd for Tensor {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.data.partial_cmp(&other.data)
}
}
21 changes: 10 additions & 11 deletions src/value.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use crate::interner::StringObjIdx;
use crate::{interner::StringObjIdx, tensor::Tensor};

#[derive(Debug, Clone, Copy)]
pub enum ValueType {
Number(f64), // TODO: Ideally, it should be seperate types for int and float (maybe?)
Tensor(Tensor), // TODO: Ideally, it should be seperate types for int and float (maybe?)
String(StringObjIdx),
Identifier(StringObjIdx),
Boolean(bool),
Nil,
// Lists, Dicts, Tensors, etc.
}


impl std::fmt::Display for ValueType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
ValueType::Number(n) => write!(f, "num->{}", n),
ValueType::Tensor(n) => write!(f, "num->{}", n),
ValueType::String(s) => write!(f, "str->{}", s),
ValueType::Identifier(s) => write!(f, "iden->{}", s),
ValueType::Boolean(b) => write!(f, "bool->{}", b),
Expand All @@ -29,7 +28,7 @@ impl std::ops::Add for ValueType {

fn add(self, other: Self) -> Self {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => ValueType::Number(a + b),
(ValueType::Tensor(a), ValueType::Tensor(b)) => ValueType::Tensor(a + b),
(ValueType::String(a), ValueType::String(b)) => ValueType::String(a + b),
_ => panic!("Operands must be numbers."),
}
Expand All @@ -41,7 +40,7 @@ impl std::ops::Sub for ValueType {

fn sub(self, other: Self) -> Self {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => ValueType::Number(a - b),
(ValueType::Tensor(a), ValueType::Tensor(b)) => ValueType::Tensor(a - b),
_ => panic!("Operands must be numbers."),
}
}
Expand All @@ -52,7 +51,7 @@ impl std::ops::Mul for ValueType {

fn mul(self, other: Self) -> Self {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => ValueType::Number(a * b),
(ValueType::Tensor(a), ValueType::Tensor(b)) => ValueType::Tensor(a * b),
_ => panic!("Operands must be numbers."),
}
}
Expand All @@ -63,7 +62,7 @@ impl std::ops::Div for ValueType {

fn div(self, other: Self) -> Self {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => ValueType::Number(a / b),
(ValueType::Tensor(a), ValueType::Tensor(b)) => ValueType::Tensor(a / b),
_ => panic!("Operands must be numbers."),
}
}
Expand All @@ -74,7 +73,7 @@ impl std::ops::Neg for ValueType {

fn neg(self) -> Self {
match self {
ValueType::Number(n) => ValueType::Number(-n),
ValueType::Tensor(n) => ValueType::Tensor(-n),
_ => panic!("Operand must be a number."),
}
}
Expand All @@ -95,7 +94,7 @@ impl std::ops::Not for ValueType {
impl std::cmp::PartialEq for ValueType {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => a == b,
(ValueType::Tensor(a), ValueType::Tensor(b)) => a == b,
(ValueType::Boolean(a), ValueType::Boolean(b)) => a == b,
(ValueType::Nil, ValueType::Nil) => true,
_ => false,
Expand All @@ -106,7 +105,7 @@ impl std::cmp::PartialEq for ValueType {
impl std::cmp::PartialOrd for ValueType {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
match (self, other) {
(ValueType::Number(a), ValueType::Number(b)) => a.partial_cmp(b),
(ValueType::Tensor(a), ValueType::Tensor(b)) => a.partial_cmp(b),
_ => None,
}
}
Expand Down
Loading

0 comments on commit a3142af

Please sign in to comment.