From e9f927d0b9bfae10e3d80e1a69ab59f02f2902d6 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Thu, 23 Mar 2023 17:03:18 +0800 Subject: [PATCH] Implement remaining reduction operations reduceL1 / reduceL2 reduceLogSum / reduceLogSumExp reduceSumSquare fixes #17 --- src/reduce.js | 61 +++- test/reduce_test.js | 724 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 784 insertions(+), 1 deletion(-) diff --git a/src/reduce.js b/src/reduce.js index 952dba1..390fa47 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -1,7 +1,9 @@ 'use strict'; +import {pow} from './binary.js'; import {squeeze} from './squeeze.js'; -import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {abs, exp, log} from './unary.js'; +import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; import {validateReduceParams} from './lib/validate-input.js'; /** @@ -120,3 +122,60 @@ export function reduceSum(input, options = {}) { return reduce(input, (previousValue, currentValue) => previousValue + currentValue, options); } + +/** + * Compute the sum of the square of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceSumSquare(input, options = {}) { + return reduceSum(pow(input, new Scalar(2)), options); +} + +/** + * Compute the L1 norm of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceL1(input, options = {}) { + return reduceSum(abs(input), options); +} + +/** + * Compute the L2 norm of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceL2(input, options = {}) { + const intermediateResult = reduceSumSquare(input, options); + if (intermediateResult.rank === 0) { + return new Tensor( + [], + [Math.pow(intermediateResult.getValueByIndex(0), 0.5)]); + } else { + return pow(intermediateResult, new Scalar(0.5)); + } +} + +/** + * Compute the log value of the sum of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceLogSum(input, options = {}) { + return log(reduceSum(input, options)); +} + +/** + * Compute the log value of the sum of the exponent of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceLogSumExp(input, options = {}) { + return log(reduceSum(exp(input), options)); +} diff --git a/test/reduce_test.js b/test/reduce_test.js index 3d794a2..27c8a87 100644 --- a/test/reduce_test.js +++ b/test/reduce_test.js @@ -408,4 +408,728 @@ describe('test reduce', function() { values: [1., 5., 9., 13., 17., 21.], }); }); + + it('reduceSumSquare default', function() { + testReduce( + 'SumSquare', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [506]}); + }); + + it('reduceSumSquare default axes keep dims', function() { + testReduce( + 'SumSquare', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [506]}); + }); + + it('reduceSumSquare axes0 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [80, 107, 140, 179], + }); + }); + + it('reduceSumSquare axes1 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [4, 10, 52, 74, 164, 202], + }); + }); + + it('reduceSumSquare axes2 do not keep dims', function() { + testReduce( + 'SumSquare', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [1, 13, 41, 85, 145, 221], + }); + }); + + it('reduceSumSquare axes0 keep dims', function() { + testReduce( + 'SumSquare', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [80, 107, 140, 179], + }); + }); + + it('reduceSumSquare axes1 keep dims', function() { + testReduce( + 'SumSquare', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [4, 10, 52, 74, 164, 202], + }); + }); + + it('reduceSumSquare axes2 keep dims', function() { + testReduce( + 'SumSquare', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [1, 13, 41, 85, 145, 221], + }); + }); + + it('reduceL1 default', function() { + testReduce( + 'L1', {}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + {shape: [], values: [66.]}); + }); + + it('reduceL1 default axes keep dims', function() { + testReduce( + 'L1', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + {shape: [1, 1, 1], values: [66.]}); + }); + + it('reduceL1 axes0 do not keep dims', function() { + testReduce( + 'L1', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceL1 axes1 do not keep dims', function() { + testReduce( + 'L1', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceL1 axes2 do not keep dims', function() { + testReduce( + 'L1', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 2], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceL1 axes0 keep dims', function() { + testReduce( + 'L1', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [1, 2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceL1 axes1 keep dims', function() { + testReduce( + 'L1', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 1, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceL1 axes2 keep dims', function() { + testReduce( + 'L1', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., -1., 2., -3., + 4., -5., 6., -7., + 8., -9., 10., -11., + ], + }, + { + shape: [3, 2, 1], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceL2 default', function() { + testReduce( + 'L2', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [22.494443758403985]}); + }); + + it('reduceL2 default axes keep dims', function() { + testReduce( + 'L2', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [22.494443758403985]}); + }); + + it('reduceL2 axes0 do not keep dims', function() { + testReduce( + 'L2', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 8.94427190999916, + 10.344080432788601, + 11.832159566199232, + 13.379088160259652, + ], + }); + }); + + it('reduceL2 axes1 do not keep dims', function() { + testReduce( + 'L2', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 2, + 3.1622776601683795, + 7.211102550927978, + 8.602325267042627, + 12.806248474865697, + 14.212670403551895, + ], + }); + }); + + it('reduceL2 axes2 do not keep dims', function() { + testReduce( + 'L2', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 1, + 3.605551275463989, + 6.4031242374328485, + 9.219544457292887, + 12.041594578792296, + 14.866068747318506, + ], + }); + }); + + it('reduceL2 axes0 keep dims', function() { + testReduce( + 'L2', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 8.94427190999916, + 10.344080432788601, + 11.832159566199232, + 13.379088160259652, + ], + }); + }); + + it('reduceL2 axes1 keep dims', function() { + testReduce( + 'L2', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 2, + 3.1622776601683795, + 7.211102550927978, + 8.602325267042627, + 12.806248474865697, + 14.212670403551895, + ], + }); + }); + + it('reduceL2 axes2 keep dims', function() { + testReduce( + 'L2', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 1, + 3.605551275463989, + 6.4031242374328485, + 9.219544457292887, + 12.041594578792296, + 14.866068747318506, + ], + }); + }); + + it('reduceLogSum default', function() { + testReduce( + 'LogSum', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [4.189654742026425]}); + }); + + it('reduceLogSum default axes keep dims', function() { + testReduce( + 'LogSum', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [4.189654742026425]}); + }); + + it('reduceLogSum axes0 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 2.4849066497880004, + 2.70805020110221, + 2.8903717578961645, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes1 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 0.6931471805599453, + 1.3862943611198906, + 2.302585092994046, + 2.4849066497880004, + 2.8903717578961645, + 2.995732273553991, + ], + }); + }); + + it('reduceLogSum axes2 do not keep dims', function() { + testReduce( + 'LogSum', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 0, + 1.6094379124341003, + 2.1972245773362196, + 2.5649493574615367, + 2.833213344056216, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes0 keep dims', function() { + testReduce( + 'LogSum', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 2.4849066497880004, + 2.70805020110221, + 2.8903717578961645, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSum axes1 keep dims', function() { + testReduce( + 'LogSum', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 0.6931471805599453, + 1.3862943611198906, + 2.302585092994046, + 2.4849066497880004, + 2.8903717578961645, + 2.995732273553991, + ], + }); + }); + + it('reduceLogSum axes2 keep dims', function() { + testReduce( + 'LogSum', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 0, + 1.6094379124341003, + 2.1972245773362196, + 2.5649493574615367, + 2.833213344056216, + 3.044522437723423, + ], + }); + }); + + it('reduceLogSumExp default', function() { + testReduce( + 'LogSumExp', {}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [], values: [11.458669001155853]}); + }); + + it('reduceLogSumExp default axes keep dims', function() { + testReduce( + 'LogSumExp', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + {shape: [1, 1, 1], values: [11.458669001155853]}); + }); + + it('reduceLogSumExp axes0 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [2, 2], + values: [ + 8.018479302594658, + 9.018479302594658, + 10.018479302594658, + 11.018479302594658, + ], + }); + }); + + it('reduceLogSumExp axes1 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 2.1269280110429727, + 3.1269280110429722, + 6.126928011042972, + 7.126928011042972, + 10.126928011042972, + 11.126928011042972, + ], + }); + }); + + it('reduceLogSumExp axes2 do not keep dims', function() { + testReduce( + 'LogSumExp', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2], + values: [ + 1.3132616875182228, + 3.313261687518223, + 5.313261687518223, + 7.313261687518223, + 9.313261687518223, + 11.313261687518223, + ], + }); + }); + + it('reduceLogSumExp axes0 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [1, 2, 2], + values: [ + 8.018479302594658, + 9.018479302594658, + 10.018479302594658, + 11.018479302594658, + ], + }); + }); + + it('reduceLogSumExp axes1 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 1, 2], + values: [ + 2.1269280110429727, + 3.1269280110429722, + 6.126928011042972, + 7.126928011042972, + 10.126928011042972, + 11.126928011042972, + ], + }); + }); + + it('reduceLogSumExp axes2 keep dims', function() { + testReduce( + 'LogSumExp', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [ + 0., 1., 2., 3., + 4., 5., 6., 7., + 8., 9., 10., 11., + ], + }, + { + shape: [3, 2, 1], + values: [ + 1.3132616875182228, + 3.313261687518223, + 5.313261687518223, + 7.313261687518223, + 9.313261687518223, + 11.313261687518223, + ], + }); + }); });