Skip to content

Commit

Permalink
More TDOT conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed Jul 30, 2016
1 parent 2b55765 commit 05b2fd8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
28 changes: 18 additions & 10 deletions test/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,24 @@
((['jli', 'kjl'], 'ki', set('jl')), 'GEMM'), # GEMM T T Tensor

# Tensor Dot (requires copy), lets not deal with this for now
((['ilj', 'jlk'], 'ik', set('jl')), False), # FT GEMM N N Tensor
((['ijl', 'ljk'], 'ik', set('jl')), False), # ST GEMM N N Tensor
((['ilj', 'kjl'], 'ik', set('jl')), False), # FT GEMM N T Tensor
((['ijl', 'klj'], 'ik', set('jl')), False), # ST GEMM N T Tensor
((['lji', 'jlk'], 'ik', set('jl')), False), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('jl')), False), # ST GEMM T N Tensor
((['lji', 'jlk'], 'ik', set('jl')), False), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('jl')), False), # ST GEMM T N Tensor
((['ilj', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM N N Tensor
((['ijl', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM N N Tensor
((['ilj', 'kjl'], 'ik', set('jl')), 'TDOT'), # FT GEMM N T Tensor
((['ijl', 'klj'], 'ik', set('jl')), 'TDOT'), # ST GEMM N T Tensor
((['lji', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM T N Tensor
((['lji', 'jlk'], 'ik', set('jl')), 'TDOT'), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('jl')), 'TDOT'), # ST GEMM T N Tensor

# Tensor Dot (requires copy), lets not deal with this for now with transpose
((['ilj', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM N N Tensor
((['ijl', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM N N Tensor
((['ilj', 'kjl'], 'ik', set('lj')), 'TDOT'), # FT GEMM N T Tensor
((['ijl', 'klj'], 'ik', set('lj')), 'TDOT'), # ST GEMM N T Tensor
((['lji', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM T N Tensor
((['lji', 'jlk'], 'ik', set('lj')), 'TDOT'), # FT GEMM T N Tensor
((['jli', 'ljk'], 'ik', set('lj')), 'TDOT'), # ST GEMM T N Tensor

# Other
((['ijk', 'ikj'], '', set('ijk')), False), # Transpose DOT
Expand Down Expand Up @@ -79,6 +89,4 @@ def test_tensor_blas(inp, benchmark):
view_right, tensor_strs[1],
output, reduced_idx)

print(einsum_result)
print(blas_result)
assert np.allclose(einsum_result, blas_result)
28 changes: 25 additions & 3 deletions test/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ def test_type_errors():
with pytest.raises(TypeError):
contract(*(None,)*63)

# Cannot have two ->
with pytest.raises(ValueError):
contract("->,->", 0, 5)

# Undefined symbol lhs
with pytest.raises(ValueError):
contract("&,a->", 0, 5)

# Undefined symbol rhs
with pytest.raises(ValueError):
contract("a,a->&", 0, 5)

with pytest.raises(ValueError):
contract("a,a->&", 0, 5)

# Catch ellipsis errors
string = '...a->...a'
views = build_views(string)

with pytest.raises(TypeError):
contract(views[0], [Ellipsis, 'a'], [Ellipsis, 0])

with pytest.raises(TypeError):
contract(views[0], [Ellipsis, 0], [Ellipsis, 'a'])


def test_value_errors():
with pytest.raises(ValueError):
Expand Down Expand Up @@ -176,9 +201,6 @@ def test_ellipse_input4():
views = build_views(string)

ein = contract(string, *views, optimize=False)
print(ein)
print('----------')
opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis])
print(opt)
assert np.allclose(ein, opt)

0 comments on commit 05b2fd8

Please sign in to comment.