From 05b2fd885bc502a7908a4596ac164998b18dbace Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Sat, 30 Jul 2016 17:29:31 -0400 Subject: [PATCH] More TDOT conversion --- test/test_blas.py | 28 ++++++++++++++++++---------- test/test_input.py | 28 +++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/test/test_blas.py b/test/test_blas.py index c8e6589..9d7fd7d 100644 --- a/test/test_blas.py +++ b/test/test_blas.py @@ -33,14 +33,24 @@ ((['jli', 'kjl'], 'ki', set('jl')), 'GEMM'), # GEMM T T Tensor # Tensor Dot (requires copy), lets not deal with this for now - ((['ilj', 'jlk'], 'ik', set('jl')), False), # FT GEMM N N Tensor - ((['ijl', 'ljk'], 'ik', set('jl')), False), # ST GEMM N N Tensor - ((['ilj', 'kjl'], 'ik', set('jl')), False), # FT GEMM N T Tensor - ((['ijl', 'klj'], 'ik', set('jl')), False), # ST GEMM N T Tensor - ((['lji', 'jlk'], 'ik', set('jl')), False), # FT GEMM T N Tensor - ((['jli', 'ljk'], 'ik', set('jl')), False), # ST GEMM T N Tensor - ((['lji', 'jlk'], 'ik', set('jl')), False), # FT GEMM T N Tensor - ((['jli', 'ljk'], 'ik', set('jl')), False), # ST GEMM T N Tensor + ((['ilj', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM N N Tensor + ((['ijl', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM N N Tensor + ((['ilj', 'kjl'], 'ik', set('jl')), 'TDOT'), # FT GEMM N T Tensor + ((['ijl', 'klj'], 'ik', set('jl')), 'TDOT'), # ST GEMM N T Tensor + ((['lji', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM T N Tensor + ((['jli', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM T N Tensor + ((['lji', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM T N Tensor + ((['jli', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM T N Tensor + + # Tensor Dot (requires copy), lets not deal with this for now with transpose + ((['ilj', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM N N Tensor + ((['ijl', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM N N Tensor + ((['ilj', 'kjl'], 'ik', set('lj')), 'TDOT'), # FT GEMM N T Tensor + ((['ijl', 'klj'], 'ik', set('lj')), 'TDOT'), # ST GEMM N T Tensor + ((['lji', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM T N Tensor + ((['jli', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM T N Tensor + ((['lji', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM T N Tensor + ((['jli', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM T N Tensor # Other ((['ijk', 'ikj'], '', set('ijk')), False), # Transpose DOT @@ -79,6 +89,4 @@ def test_tensor_blas(inp, benchmark): view_right, tensor_strs[1], output, reduced_idx) - print(einsum_result) - print(blas_result) assert np.allclose(einsum_result, blas_result) diff --git a/test/test_input.py b/test/test_input.py index f658d41..2ea9389 100644 --- a/test/test_input.py +++ b/test/test_input.py @@ -48,6 +48,31 @@ def test_type_errors(): with pytest.raises(TypeError): contract(*(None,)*63) + # Cannot have two -> + with pytest.raises(ValueError): + contract("->,->", 0, 5) + + # Undefined symbol lhs + with pytest.raises(ValueError): + contract("&,a->", 0, 5) + + # Undefined symbol rhs + with pytest.raises(ValueError): + contract("a,a->&", 0, 5) + + with pytest.raises(ValueError): + contract("a,a->&", 0, 5) + + # Catch ellipsis errors + string = '...a->...a' + views = build_views(string) + + with pytest.raises(TypeError): + contract(views[0], [Ellipsis, 'a'], [Ellipsis, 0]) + + with pytest.raises(TypeError): + contract(views[0], [Ellipsis, 0], [Ellipsis, 'a']) + def test_value_errors(): with pytest.raises(ValueError): @@ -176,9 +201,6 @@ def test_ellipse_input4(): views = build_views(string) ein = contract(string, *views, optimize=False) - print(ein) - print('----------') opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis]) - print(opt) assert np.allclose(ein, opt)