diff --git a/src/batch_normalization.js b/src/batch_normalization.js index 2e2e206..06f50d8 100644 --- a/src/batch_normalization.js +++ b/src/batch_normalization.js @@ -20,7 +20,7 @@ export function batchNormalization(input, mean, variance, {axis=1, scale, bias, // The output tensor has the same shape as the input tensor. let output = new Tensor(input.shape); const shape = new Array(input.rank).fill(1); - shape[axis] = -1; + shape[axis] = null; output = sub(input, reshape(mean, shape)); output = div(output, pow(add(reshape(variance, shape), new Scalar(epsilon)), new Scalar(0.5))); diff --git a/src/gru.js b/src/gru.js index 137b407..40cd627 100644 --- a/src/gru.js +++ b/src/gru.js @@ -73,7 +73,7 @@ export function gru(input, weight, recurrentWeight, steps, hiddenSize, cellInput, cellWeight[slot], cellRecurrentWeight[slot], cellHidden[slot], hiddenSize, {bias: cellBias[slot], recurrentBias: cellRecurrentBias[slot], resetAfter, layout, activations}), - [1, -1, hiddenSize]); + [1, null, hiddenSize]); cellOutput = (cellOutput ? concat([cellOutput, result], 0) : result); } @@ -81,7 +81,7 @@ export function gru(input, weight, recurrentWeight, steps, hiddenSize, hiddenState = cellOutput; if (returnSequence) { - cellOutput = reshape(cellOutput, [1, numDirections, -1, hiddenSize]); + cellOutput = reshape(cellOutput, [1, numDirections, null, hiddenSize]); sequence = (sequence ? concat([sequence, cellOutput], 0) : cellOutput); } diff --git a/src/reduce.js b/src/reduce.js index 5f8aa27..390fa47 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -1,7 +1,9 @@ 'use strict'; +import {pow} from './binary.js'; import {squeeze} from './squeeze.js'; -import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {abs, exp, log} from './unary.js'; +import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; import {validateReduceParams} from './lib/validate-input.js'; /** @@ -16,9 +18,6 @@ function reduce(input, reduceFunc, {keepDimensions = false, axes} = {}) { const outputShape = input.shape.slice(); for (let i = 0; i < inpAxes.length; ++i) { - if (inpAxes[i] === -1) { - inpAxes[i] = input.rank - 1; - } outputShape[inpAxes[i]] = 1; } @@ -123,3 +122,60 @@ export function reduceSum(input, options = {}) { return reduce(input, (previousValue, currentValue) => previousValue + currentValue, options); } + +/** + * Compute the sum of the square of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceSumSquare(input, options = {}) { + return reduceSum(pow(input, new Scalar(2)), options); +} + +/** + * Compute the L1 norm of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceL1(input, options = {}) { + return reduceSum(abs(input), options); +} + +/** + * Compute the L2 norm of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceL2(input, options = {}) { + const intermediateResult = reduceSumSquare(input, options); + if (intermediateResult.rank === 0) { + return new Tensor( + [], + [Math.pow(intermediateResult.getValueByIndex(0), 0.5)]); + } else { + return pow(intermediateResult, new Scalar(0.5)); + } +} + +/** + * Compute the log value of the sum of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceLogSum(input, options = {}) { + return log(reduceSum(input, options)); +} + +/** + * Compute the log value of the sum of the exponent of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceLogSumExp(input, options = {}) { + return log(reduceSum(exp(input), options)); +} diff --git a/src/reshape.js b/src/reshape.js index e1a428f..4cb0ffc 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -12,7 +12,7 @@ export function reshape(input, newShape) { let minusOneAxis; let elements = 1; for (let i = 0; i < newShape.length; ++i) { - if (newShape[i] === -1) { + if (newShape[i] === null) { minusOneAxis = i; } else if (newShape[i] > 0) { elements *= newShape[i]; diff --git a/src/slice.js b/src/slice.js index 90c3910..ecd0c26 100644 --- a/src/slice.js +++ b/src/slice.js @@ -20,7 +20,7 @@ export function slice(input, starts, sizes, {axes} = {}) { const axesLen = axes.length; const outputShape = input.shape.slice(); for (let i = 0; i < axesLen; ++i) { - const axis = axes[i] >= 0 ? axes[i] : axes[i] + rank; + const axis = axes[i]; const size = input.shape[axis]; const start = starts[i]; startsForAllAxes[axis] = start >= 0 ? start : start + size; diff --git a/src/split.js b/src/split.js index f2c9685..581965e 100644 --- a/src/split.js +++ b/src/split.js @@ -14,16 +14,14 @@ export function split(input, splits, {axis = 0} = {}) { validateSplitParams(...arguments); const outputs = []; let sliceSizes = []; - const rank = input.rank; - const inpAxis = axis >=0 ? axis : rank + axis; if (typeof splits === 'number') { - sliceSizes = new Array(splits).fill(input.shape[inpAxis] / splits); + sliceSizes = new Array(splits).fill(input.shape[axis] / splits); } else if (splits instanceof Array) { sliceSizes = splits.slice(); } let start = 0; for (const size of sliceSizes) { - outputs.push(slice(input, [start], [size], {axes: [inpAxis]})); + outputs.push(slice(input, [start], [size], {axes: [axis]})); start += size; } return outputs; diff --git a/test/reduce_test.js b/test/reduce_test.js index 1e9a734..27c8a87 100644 --- a/test/reduce_test.js +++ b/test/reduce_test.js @@ -58,15 +58,6 @@ describe('test reduce', function() { {shape: [3, 2], values: [100., 200., 300., 400., 500., 600.]}); }); - it('reduceMax negative axes do not keep dims', function() { - testReduce( - 'Max', {axes: [-1], keepDimensions: false}, { - shape: [3, 2, 2], - values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], - }, - {shape: [3, 2], values: [100., 200., 300., 400., 500., 600.]}); - }); - it('reduceMax axes0 keep dims', function() { testReduce( 'Max', {axes: [0], keepDimensions: true}, { @@ -94,15 +85,6 @@ describe('test reduce', function() { {shape: [3, 2, 1], values: [100., 200., 300., 400., 500., 600.]}); }); - it('reduceMax negative axes keep dims', function() { - testReduce( - 'Max', {axes: [-1], keepDimensions: true}, { - shape: [3, 2, 2], - values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], - }, - {shape: [3, 2, 1], values: [100., 200., 300., 400., 500., 600.]}); - }); - it('reduceMean default', function() { testReduce( 'Mean', {}, { @@ -148,15 +130,6 @@ describe('test reduce', function() { {shape: [3, 2], values: [3., 11., 15.5, 21., 28., 31.]}); }); - it('reduceMean negative axes do not keep dims', function() { - testReduce( - 'Mean', {axes: [-1], keepDimensions: false}, { - shape: [3, 2, 2], - values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], - }, - {shape: [3, 2], values: [3., 11., 15.5, 21., 28., 31.]}); - }); - it('reduceMean axes0 keep dims', function() { testReduce( 'Mean', {axes: [0], keepDimensions: true}, { @@ -184,15 +157,6 @@ describe('test reduce', function() { {shape: [3, 2, 1], values: [3., 11., 15.5, 21., 28., 31.]}); }); - it('reduceMean negative axes keep dims', function() { - testReduce( - 'Mean', {axes: [-1], keepDimensions: true}, { - shape: [3, 2, 2], - values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], - }, - {shape: [3, 2, 1], values: [3., 11., 15.5, 21., 28., 31.]}); - }); - it('reduceMin default', function() { testReduce( 'Min', {}, { @@ -238,15 +202,6 @@ describe('test reduce', function() { {shape: [3, 2], values: [1., 2., 3., 4., 5., 6.]}); }); - it('reduceMin negative axes do not keep dims', function() { - testReduce( - 'Min', {axes: [-1], keepDimensions: false}, { - shape: [3, 2, 2], - values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], - }, - {shape: [3, 2], values: [1., 2., 3., 4., 5., 6.]}); - }); - it('reduceMin axes0 keep dims', function() { testReduce( 'Min', {axes: [0], keepDimensions: true}, { @@ -274,15 +229,6 @@ describe('test reduce', function() { {shape: [3, 2, 1], values: [1., 2., 3., 4., 5., 6.]}); }); - it('reduceMin negative axes keep dims', function() { - testReduce( - 'Min', {axes: [-1], keepDimensions: true}, { - shape: [3, 2, 2], - values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], - }, - {shape: [3, 2, 1], values: [1., 2., 3., 4., 5., 6.]}); - }); - it('reduceProduct default', function() { testReduce( 'Product', {}, { @@ -337,18 +283,6 @@ describe('test reduce', function() { }); }); - it('reduceProduct negative axes do not keep dims', function() { - testReduce( - 'Product', {axes: [-1], keepDimensions: false}, { - shape: [3, 2, 2], - values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], - }, - { - shape: [3, 2], - values: [0., 6., 20., 42., 72., 110.], - }); - }); - it('reduceProduct axes0 keep dims', function() { testReduce( 'Product', {axes: [0], keepDimensions: true}, { @@ -385,18 +319,6 @@ describe('test reduce', function() { }); }); - it('reduceProduct negative axes keep dims', function() { - testReduce( - 'Product', {axes: [-1], keepDimensions: true}, { - shape: [3, 2, 2], - values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], - }, - { - shape: [3, 2, 1], - values: [0., 6., 20., 42., 72., 110.], - }); - }); - it('reduceSum default', function() { testReduce( 'Sum', {}, { @@ -451,18 +373,6 @@ describe('test reduce', function() { }); }); - it('reduceSum negative axes do not keep dims', function() { - testReduce( - 'Sum', {axes: [-1], keepDimensions: false}, { - shape: [3, 2, 2], - values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], - }, - { - shape: [3, 2], - values: [1., 5., 9., 13., 17., 21.], - }); - }); - it('reduceSum axes0 keep dims', function() { testReduce( 'Sum', {axes: [0], keepDimensions: true}, { @@ -499,15 +409,727 @@ describe('test reduce', function() { }); }); - it('reduceSum negative axes keep dims', function() { + it('reduceSumSquare default', function() { testReduce( - 'Sum', {axes: [-1], keepDimensions: true}, { + 'SumSquare', {}, { shape: [3, 2, 2], - values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [506]}); + }); + + it('reduceSumSquare default axes keep dims', function() { + testReduce( + 'SumSquare', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [506]}); + }); + + it('reduceSumSquare axes0 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [80, 107, 140, 179], + }); + }); + + it('reduceSumSquare axes1 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [4, 10, 52, 74, 164, 202], + }); + }); + + it('reduceSumSquare axes2 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [1, 13, 41, 85, 145, 221], + }); + }); + + it('reduceSumSquare axes0 keep dims', function() { + testReduce( + 'SumSquare', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [80, 107, 140, 179], + }); + }); + + it('reduceSumSquare axes1 keep dims', function() { + testReduce( + 'SumSquare', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [4, 10, 52, 74, 164, 202], + }); + }); + + it('reduceSumSquare axes2 keep dims', function() { + testReduce( + 'SumSquare', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [1, 13, 41, 85, 145, 221], + }); + }); + + it('reduceL1 default', function() { + testReduce( + 'L1', {}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + {shape: [], values: [66.]}); + }); + + it('reduceL1 default axes keep dims', function() { + testReduce( + 'L1', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + {shape: [1, 1, 1], values: [66.]}); + }); + + it('reduceL1 axes0 do not keep dims', function() { + testReduce( + 'L1', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceL1 axes1 do not keep dims', function() { + testReduce( + 'L1', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceL1 axes2 do not keep dims', function() { + testReduce( + 'L1', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 2], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceL1 axes0 keep dims', function() { + testReduce( + 'L1', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [1, 2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceL1 axes1 keep dims', function() { + testReduce( + 'L1', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 1, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceL1 axes2 keep dims', function() { + testReduce( + 'L1', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], }, { shape: [3, 2, 1], values: [1., 5., 9., 13., 17., 21.], }); }); + + it('reduceL2 default', function() { + testReduce( + 'L2', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [22.494443758403985]}); + }); + + it('reduceL2 default axes keep dims', function() { + testReduce( + 'L2', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [22.494443758403985]}); + }); + + it('reduceL2 axes0 do not keep dims', function() { + testReduce( + 'L2', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 8.94427190999916, + 10.344080432788601, + 11.832159566199232, + 13.379088160259652, + ], + }); + }); + + it('reduceL2 axes1 do not keep dims', function() { + testReduce( + 'L2', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 2, + 3.1622776601683795, + 7.211102550927978, + 8.602325267042627, + 12.806248474865697, + 14.212670403551895, + ], + }); + }); + + it('reduceL2 axes2 do not keep dims', function() { + testReduce( + 'L2', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 1, + 3.605551275463989, + 6.4031242374328485, + 9.219544457292887, + 12.041594578792296, + 14.866068747318506, + ], + }); + }); + + it('reduceL2 axes0 keep dims', function() { + testReduce( + 'L2', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 8.94427190999916, + 10.344080432788601, + 11.832159566199232, + 13.379088160259652, + ], + }); + }); + + it('reduceL2 axes1 keep dims', function() { + testReduce( + 'L2', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 2, + 3.1622776601683795, + 7.211102550927978, + 8.602325267042627, + 12.806248474865697, + 14.212670403551895, + ], + }); + }); + + it('reduceL2 axes2 keep dims', function() { + testReduce( + 'L2', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 1, + 3.605551275463989, + 6.4031242374328485, + 9.219544457292887, + 12.041594578792296, + 14.866068747318506, + ], + }); + }); + + it('reduceLogSum default', function() { + testReduce( + 'LogSum', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [4.189654742026425]}); + }); + + it('reduceLogSum default axes keep dims', function() { + testReduce( + 'LogSum', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [4.189654742026425]}); + }); + + it('reduceLogSum axes0 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 2.4849066497880004, + 2.70805020110221, + 2.8903717578961645, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes1 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 0.6931471805599453, + 1.3862943611198906, + 2.302585092994046, + 2.4849066497880004, + 2.8903717578961645, + 2.995732273553991, + ], + }); + }); + + it('reduceLogSum axes2 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 0, + 1.6094379124341003, + 2.1972245773362196, + 2.5649493574615367, + 2.833213344056216, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes0 keep dims', function() { + testReduce( + 'LogSum', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 2.4849066497880004, + 2.70805020110221, + 2.8903717578961645, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes1 keep dims', function() { + testReduce( + 'LogSum', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 0.6931471805599453, + 1.3862943611198906, + 2.302585092994046, + 2.4849066497880004, + 2.8903717578961645, + 2.995732273553991, + ], + }); + }); + + it('reduceLogSum axes2 keep dims', function() { + testReduce( + 'LogSum', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 0, + 1.6094379124341003, + 2.1972245773362196, + 2.5649493574615367, + 2.833213344056216, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSumExp default', function() { + testReduce( + 'LogSumExp', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [11.458669001155853]}); + }); + + it('reduceLogSumExp default axes keep dims', function() { + testReduce( + 'LogSumExp', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [11.458669001155853]}); + }); + + it('reduceLogSumExp axes0 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 8.018479302594658, + 9.018479302594658, + 10.018479302594658, + 11.018479302594658, + ], + }); + }); + + it('reduceLogSumExp axes1 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 2.1269280110429727, + 3.1269280110429722, + 6.126928011042972, + 7.126928011042972, + 10.126928011042972, + 11.126928011042972, + ], + }); + }); + + it('reduceLogSumExp axes2 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 1.3132616875182228, + 3.313261687518223, + 5.313261687518223, + 7.313261687518223, + 9.313261687518223, + 11.313261687518223, + ], + }); + }); + + it('reduceLogSumExp axes0 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 8.018479302594658, + 9.018479302594658, + 10.018479302594658, + 11.018479302594658, + ], + }); + }); + + it('reduceLogSumExp axes1 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 2.1269280110429727, + 3.1269280110429722, + 6.126928011042972, + 7.126928011042972, + 10.126928011042972, + 11.126928011042972, + ], + }); + }); + + it('reduceLogSumExp axes2 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 1.3132616875182228, + 3.313261687518223, + 5.313261687518223, + 7.313261687518223, + 9.313261687518223, + 11.313261687518223, + ], + }); + }); }); diff --git a/test/reshape_test.js b/test/reshape_test.js index 640aea2..26a76ab 100644 --- a/test/reshape_test.js +++ b/test/reshape_test.js @@ -37,11 +37,11 @@ describe('test reshape', function() { testReshape([2, 3, 4], [24]); }); - it('reshape [2, 3, 4] to negative_dim [2, -1, 2]', function() { - testReshape([2, 3, 4], [2, -1, 2], [2, 6, 2]); + it('reshape [2, 3, 4] to [2, null, 2]', function() { + testReshape([2, 3, 4], [2, null, 2], [2, 6, 2]); }); - it('reshape [2, 3, 4] to negative_dim [-1, 2, 3, 4]', function() { - testReshape([2, 3, 4], [-1, 2, 3, 4], [1, 2, 3, 4]); + it('reshape [2, 3, 4] to [null, 2, 3, 4]', function() { + testReshape([2, 3, 4], [null, 2, 3, 4], [1, 2, 3, 4]); }); }); diff --git a/test/slice_test.js b/test/slice_test.js index 2d7b03f..619418a 100644 --- a/test/slice_test.js +++ b/test/slice_test.js @@ -123,43 +123,6 @@ describe('test slice', function() { inputShape, inputData, starts, sizes, axes, expectedShape, expected); }); - it('slice with negative axes', function() { - const inputShape = [3, 4, 5]; - const inputData = [ - 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, - -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, - 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, - -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, - 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, - 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, - 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, - 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, - -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, - 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, - -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, - 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, - -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, - -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, - -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, - ]; - const starts = [0, 1]; - const sizes = [2, 4]; - const axes = [-3, -1]; - const expectedShape = [2, 4, 4]; - const expected = [ - 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, - 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, - -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, - 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, - 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, - -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, - -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, - 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, - ]; - testSlice( - inputShape, inputData, starts, sizes, axes, expectedShape, expected); - }); - it('slice with -1 sizes', function() { const inputShape = [3, 4, 5]; const inputData = [ diff --git a/test/split_test.js b/test/split_test.js index 738271e..6a6aaa8 100644 --- a/test/split_test.js +++ b/test/split_test.js @@ -35,7 +35,7 @@ describe('test split', function() { {shape: [2], value: [3, 4]}, {shape: [2], value: [5, 6]}, ], - 3, -1); + 3, 0); testSplit( [6], [1, 2, 3, 4, 5, 6], [{shape: [2], value: [1, 2]}, {shape: [4], value: [3, 4, 5, 6]}], @@ -43,7 +43,7 @@ describe('test split', function() { testSplit( [6], [1, 2, 3, 4, 5, 6], [{shape: [2], value: [1, 2]}, {shape: [4], value: [3, 4, 5, 6]}], - [2, 4], -1); + [2, 4], 0); }); it('split 2d', function() { @@ -54,13 +54,6 @@ describe('test split', function() { {shape: [2, 3], value: [4, 5, 6, 10, 11, 12]}, ], 2, 1); - testSplit( - [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], - [ - {shape: [2, 3], value: [1, 2, 3, 7, 8, 9]}, - {shape: [2, 3], value: [4, 5, 6, 10, 11, 12]}, - ], - 2, -1); testSplit( [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [ @@ -68,12 +61,5 @@ describe('test split', function() { {shape: [2, 4], value: [3, 4, 5, 6, 9, 10, 11, 12]}, ], [2, 4], 1); - testSplit( - [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], - [ - {shape: [2, 2], value: [1, 2, 7, 8]}, - {shape: [2, 4], value: [3, 4, 5, 6, 9, 10, 11, 12]}, - ], - [2, 4], -1); }); });