Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 5, 2024
1 parent d27bc43 commit 8fc8400
Showing 1 changed file with 50 additions and 41 deletions.
91 changes: 50 additions & 41 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,7 @@ pytorch.Execution = class extends python.Execution {
if (expression.target.type === 'id' && expression.target.value === 'uninitialized') {
const type = this.type(expression.args[0]);
const node = this._graph.create('prim::Uninitialized');
node.setSourceRange(expression.location);
this.graph.insertNode(node);
const value = node.addOutput();
value.setType(type);
Expand Down Expand Up @@ -2159,9 +2160,11 @@ pytorch.Execution = class extends python.Execution {
const input = node.inputs()[0].node();
if (input.kind() === 'prim::TupleConstruct') {
const value = input.inputs()[index];
const node = value.node();
if (node.kind() === 'prim::Constant') {
return pytorch.Utility.constant(node, 'value');
const constant = value.node();
if (constant.kind() === 'prim::Constant') {
state.push(node);
state.push(constant);
return pytorch.Utility.constant(constant, 'value');
}
}
}
Expand Down Expand Up @@ -2325,8 +2328,8 @@ pytorch.Execution = class extends python.Execution {
block(statements, context) {
const torch = this.torch;
statements = Array.prototype.slice.call(statements);
while (statements.length > 0) {
if (statements.length > 1) {
for (let i = 0; i < statements.length;) {
if (i < statements.length - 1) {
const containsVariableReference = (statements, value) => {
if (statements) {
for (const statement of statements) {
Expand All @@ -2340,30 +2343,45 @@ pytorch.Execution = class extends python.Execution {
}
return false;
};
const [assign, condition] = statements;
const assign = statements[i];
const condition = statements[i + 1];
// _x = <expr>
// if _x:
// ...
if (assign.type === '=' && condition.type === 'if' &&
assign.target.type === 'id' && condition.test.type === 'id' &&
assign.target.value === condition.test.value &&
!containsVariableReference(statements.slice(2), condition.test.value) &&
(!statements[1].body || !containsVariableReference(statements[1].body.statements), condition.test.value) &&
(!statements[1].orelse || !containsVariableReference(statements[1].orelse.statements, condition.test.value))) {
statements.shift();
statements[0] = {
!containsVariableReference(statements.slice(i + 2), condition.test.value) &&
(!condition.body || !containsVariableReference(condition.body.statements), condition.test.value) &&
(!condition.orelse || !containsVariableReference(condition.orelse.statements, condition.test.value))) {
statements.splice(i, 2, {
location: condition.location,
type: 'if',
test: assign.expression,
body: condition.body,
orelse: condition.orelse,
location: condition.location,
};
});
}
}
const [condition] = statements;
const condition = statements[i];
if (condition.type === 'if') {
const state = [];
let test = this.static(condition.test, context, state);
if (test === null) {
test = false;
} else if (typeof test === 'boolean') {
test = test === true;
} else if (Number.isInteger(test)) {
test = test !== 0;
} else if (typeof test === 'string') {
test = test && test.length > 0;
}
if (test === true) {
statements.splice(i, 1, ...condition.body.statements);
} else if (test === false) {
statements.splice(i, 1, ...condition.orelse.statements);
}

const count = new Map();
for (const node of state) {
if (count.has(node)) {
Expand All @@ -2377,33 +2395,20 @@ pytorch.Execution = class extends python.Execution {
node.destroy();
}
}
if (test === null) {
test = false;
} else if (typeof test === 'boolean') {
test = test === true;
} else if (Number.isInteger(test)) {
test = test !== 0;
} else if (typeof test === 'string') {
test = test && test.length > 0;
}
if (test === true) {
statements.shift();
statements = condition.body.statements.concat(statements);
continue;
}
if (test === false) {
statements.shift();
statements = condition.orelse.statements.concat(statements);

if (test === true || test === false) {
continue;
}
}
if (statements.length > 0) {
const statement = statements.shift();
if (i < statements.length) {
const statement = statements[i];
if (statement.type === 'if') {
const test = this.expression(statement.test, context);
const condition = statement;
const test = this.expression(condition.test, context);
if (test instanceof torch.Value && test.type() instanceof torch.BoolType) {
const refs = new Set();
for (const statement of statements) {
for (let j = i + 1; j < statements.length; j++) {
const statement = statements[j];
if (!statement.refs) {
this.variables(statement, statement);
}
Expand Down Expand Up @@ -2446,23 +2451,23 @@ pytorch.Execution = class extends python.Execution {
}
return value.type();
};
this.variables(statement, statement);
this.variables(condition, condition);
const node = this._graph.create('prim::If');
node.setSourceRange(statement.location);
this.graph.insertNode(node);
node.addInput(test);
const prev = this._graph.insertPoint();
const true_block = node.addBlock();
this._graph.setInsertPoint(true_block);
let vars = __variables(statement.body.statements.concat(statement.orelse.statements));
let vars = __variables(condition.body.statements.concat(statement.orelse.statements));
vars = new Map(Array.from(vars).map((name) => [name, {}]));
this.block(statement.body.statements, context);
this.block(condition.body.statements, context);
for (const [name, entry] of vars) {
entry.body = context.get(name);
}
const false_block = node.addBlock();
this._graph.setInsertPoint(false_block);
this.block(statement.orelse.statements, context);
this.block(condition.orelse.statements, context);
for (const [name, entry] of vars) {
entry.orelse = context.get(name);
}
Expand Down Expand Up @@ -2502,6 +2507,7 @@ pytorch.Execution = class extends python.Execution {
}
value.setType(type);
}
i++;
continue;
}
throw new pytorch.Error("Unsupported condition.");
Expand All @@ -2510,6 +2516,7 @@ pytorch.Execution = class extends python.Execution {
if (value !== undefined) {
return value;
}
i++;
}
}
return undefined;
Expand Down Expand Up @@ -2601,6 +2608,7 @@ pytorch.Execution = class extends python.Execution {
return super.call(target, name, args, context);
}
const torch = this.torch;
const builtins = this.builtins;
if (name === '__new__') {
const identifier = pytorch.Utility.target(target);
if (identifier) {
Expand Down Expand Up @@ -2773,7 +2781,7 @@ pytorch.Execution = class extends python.Execution {
} else {
const value = this.variable(v);
value.value = v;
if (!value.type() && v instanceof this.builtins.dict) {
if (!value.type() && v instanceof builtins.dict) {
value.setType(type);
}
input = value;
Expand Down Expand Up @@ -2964,6 +2972,7 @@ pytorch.Execution = class extends python.Execution {

isType(obj, type, N) {
const torch = this.torch;
const builtins = this.builtins;
switch (type.str()) {
case 'Tensor':
return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null ||
Expand Down Expand Up @@ -3105,7 +3114,7 @@ pytorch.Execution = class extends python.Execution {
return true;
}
}
if (obj instanceof this.builtins.dict) {
if (obj instanceof builtins.dict) {
return true;
}
return false;
Expand Down

0 comments on commit 8fc8400

Please sign in to comment.