Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify lq.math implementations #677

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions larq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
utils,
)

utils._check_lce_version()

try:
from importlib import metadata # type: ignore
except ImportError:
Expand Down
5 changes: 3 additions & 2 deletions larq/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def sign(x):
# Returns
A Tensor with same type as `x`.
"""
return tf.sign(tf.sign(x) + 0.1)
ones = tf.ones_like(x)
return tf.where(x >= 0, ones, -ones)
AdamHillier marked this conversation as resolved.
Show resolved Hide resolved


def heaviside(x):
Expand All @@ -41,4 +42,4 @@ def heaviside(x):
# Returns
A Tensor with same type as `x`.
"""
return tf.sign(tf.nn.relu(x))
return tf.where(x > 0, tf.ones_like(x), tf.zeros_like(x))
20 changes: 20 additions & 0 deletions larq/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import warnings
from contextlib import contextmanager

from packaging import version
from tensorflow.keras.utils import get_custom_objects

try:
from importlib import metadata # type: ignore
except ImportError:
# Running on pre-3.8 Python; use importlib-metadata package
import importlib_metadata as metadata # type: ignore


def memory_as_readable_str(num_bits: int) -> str:
"""Generate a human-readable string for the memory size.
Expand Down Expand Up @@ -68,3 +76,15 @@ def patch_object(object, name, value):
setattr(object, name, value)
yield
setattr(object, name, old_value)


def _check_lce_version():
try:
lce_version = metadata.version("larq-compute-engine")
if version.parse(lce_version) < version.parse("0.6.1"):
warnings.warn(
f"larq-compute-engine={lce_version} is not supported by this version of larq.\n"
"Please upgrade larq-compute-engine to v0.6.1 or newer."
)
except metadata.PackageNotFoundError:
pass
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def readme():
license="Apache 2.0",
install_requires=[
"numpy >= 1.15.4, < 2.0",
"packaging>=19.2,<22.0",
"terminaltables>=3.1.0",
"dataclasses ; python_version<'3.7'",
"importlib-metadata >= 2.0, < 4.0 ; python_version<'3.8'",
Expand All @@ -31,7 +32,6 @@ def readme():
"black==21.6b0",
"flake8>=3.7.9,<3.10.0",
"isort==5.9.2",
"packaging>=19.2,<22.0",
"pytest>=5.2.4,<6.3.0",
"pytest-cov>=2.8.1,<2.13.0",
"pytest-xdist>=1.30,<2.4",
Expand Down