
Simplify lq.math implementations #677

Open · wants to merge 2 commits into base: main
Conversation

@lgeiger (Member) commented Jun 4, 2021

This PR simplifies the implementations in lq.math to make them more readable. Currently this would break Compute Engine, since we rely on the implementation details there. We could either fix this with @tf.function(experimental_implements=...), which would require dropping support for older TensorFlow versions, or update the patterns in Compute Engine to handle the new implementation, which shouldn't be hard.
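For context, here is a hedged NumPy sketch of the kind of sign implementation discussed later in this thread. The actual PR changes TensorFlow code in lq.math; this is only an illustration of the tf.where-based pattern, not the PR's exact code.

```python
import numpy as np

# Illustrative only: mirrors the tf.where-based sign discussed in this thread,
# written in NumPy so it can run standalone.
def sign(x):
    x = np.asarray(x)
    # map x >= 0 to +1 and x < 0 to -1, keeping the input dtype
    return np.where(x >= 0, np.ones_like(x), -np.ones_like(x))

print(sign([-2.0, 0.0, 3.5]))  # [-1.  1.  1.]
```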

I briefly evaluated the performance impact of this change on a T4 GPU in this notebook (although only on one input size). It shows performance improvements in some cases and regressions in others, though I am not sure whether this would be noticeable in a real-world model. I am happy to run benchmarks if you think this change might affect full-model training performance.

This PR is still WIP for now to explore this further.

@lgeiger lgeiger added the breaking-change Changes that will break user code label Jun 4, 2021
@lgeiger lgeiger requested a review from AdamHillier June 4, 2021 12:51
@Tombana (Contributor) commented Jun 7, 2021

We could either fix this using @tf.function(experimental_implements=...) which would require us dropping support for older TensorFlow version, or we could update the patterns in Compute Engine ..

Minor comment: @tf.function(experimental_implements=...) doesn't make the converter handle it; it only adds a tag that the converter can read. You'd still need to add the pattern to Compute Engine, so you might as well make it a proper pattern.
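The point about the decorator being only a tag can be shown with a pure-Python analogy (this is not TensorFlow's implementation, just an illustration of the idea): the decorator attaches a name the converter can look up, but performs no rewriting on its own.

```python
# Pure-Python analogy of what @tf.function(experimental_implements=...) does:
# it only attaches a name for a converter to read; nothing is rewritten here.
def implements(name):
    def decorator(fn):
        fn._implements = name  # a converter must still match and replace this
        return fn
    return decorator

@implements("larq.math.sign")
def sign(x):
    return 1.0 if x >= 0 else -1.0

print(sign._implements)  # larq.math.sign
print(sign(-0.5))        # -1.0
```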

@lgeiger (Member, Author) commented Jun 7, 2021

Minor comment: @tf.function(experimental_implements=...) doesn't make it convert, it only adds a tag that the converter can read, you'd still need to add the pattern to compute engine, so you might as well make it a proper pattern.

Yes, you are right. The idea was that we would then be independent of the implementation details, but I agree a proper pattern is much easier, and we can still keep it backwards compatible.

@AdamHillier (Contributor) commented
I agree that matching directly against the pattern in LCE would be good. If it doesn't cause any issues, it might be nice to still keep the experimental_implements so that if we want to change the converter implementation to use it at some point in the future then that 'tag' will already be there in existing versions of Larq.

How do you think we should manage the scenario where somebody uses a new version of Larq but an old version of the converter? I guess on init we could try and detect if LCE is installed, detect the version, and print some kind of warning message if it's <= 0.6?

@lgeiger (Member, Author) commented Jun 17, 2021

How do you think we should manage the scenario where somebody uses a new version of Larq but an old version of the converter? I guess on init we could try and detect if LCE is installed, detect the version, and print some kind of warning message if it's <= 0.6?

Yes, I think that should work well. That's how TensorFlow Addons handles it as well.

@lgeiger (Member, Author) commented Jul 13, 2021

I added an LCE version check to __init__ in 3f16474 now that LCE 0.6.1 has support for this.

I'd be happy to move this check to setup.py, though, in case a warning during import is too verbose.
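A minimal sketch of such an import-time check, assuming a simple dotted version string. The function name, threshold handling, and warning text here are illustrative, not the actual code from commit 3f16474.

```python
import warnings

# Hypothetical sketch of an import-time LCE version check; the helper name
# and message are assumptions, only the 0.6.1 threshold comes from the thread.
def check_lce_version(installed_version, minimum="0.6.1"):
    """Warn and return False if the installed LCE is too old for the new ops."""
    def as_tuple(version):
        return tuple(int(part) for part in version.split(".")[:3])

    if as_tuple(installed_version) < as_tuple(minimum):
        warnings.warn(
            f"larq-compute-engine {installed_version} cannot convert the new "
            f"lq.math implementations; please upgrade to >= {minimum}."
        )
        return False
    return True
```

Running this from __init__ keeps the warning visible at import time; moving it to setup.py would surface it only at install time instead.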

Review threads on larq/math.py and larq/utils.py (resolved)
@AdamHillier (Contributor) left a review comment
Awesome :)

I guess we should run a quick sanity check model conversion with LCE 0.6.1 before merging, but this looks great.

@lgeiger (Member, Author) commented Jul 13, 2021

I guess we should run a quick sanity check model conversion with LCE 0.6.1 before merging, but this looks great.

I will run a sanity check later today or tomorrow to be safe!

@simonmaurer commented Jul 28, 2021

@lgeiger @AdamHillier @Tombana: just a suggestion, since I'm dealing with larq.quantizers.SteSign and larq.quantizers.SteHeaviside model conversions: larq.math.sign and larq.math.heaviside could use

tf.where(x >= 0., 1., -1.)

instead of

tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))

This would result in fewer ops.

@lgeiger (Member, Author) commented Jul 28, 2021

@simonmaurer Did you do any performance profiling of your proposed solution?

One problem with using tf.where(x >= 0., 1., -1.) is that this might change the datatype (e.g. it will always return float32, even if x is float16 or int8). This means we would need to use something like tf.where(x >= 0, tf.constant(1, dtype=x.dtype), tf.constant(-1, dtype=x.dtype)) or do some casting which wouldn't reduce the number of ops.
I briefly profiled some of the options on a T4 GPU in this notebook, but please let me know if I am missing something here.
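The dtype concern can be demonstrated outside TensorFlow. This NumPy sketch is only an analogy (NumPy promotes scalar branch values to float64 where TensorFlow would produce float32), but the mechanism is the same: branch values built with ones_like follow the input dtype, while scalar branch values do not.

```python
import numpy as np

# NumPy analogy of the tf.where dtype issue: ones_like branches preserve the
# input dtype, scalar branches promote to a default float type.
x = np.array([-2.0, 0.0, 3.0], dtype=np.float16)

preserves_dtype = np.where(x >= 0, np.ones_like(x), -np.ones_like(x))
promotes_dtype = np.where(x >= 0.0, 1.0, -1.0)

print(preserves_dtype.dtype)  # float16
print(promotes_dtype.dtype)   # float64 (float32 in TensorFlow)
```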

@simonmaurer commented Jul 28, 2021

@lgeiger Ah I see, that is indeed good reasoning. I didn't do any performance tests yet.
I implicitly assumed you were working with tf.float32 tensors, since that is the data type LceDequantize outputs in the end.
But thinking about it again, LceDequantize also handles dequantization to tf.int8, otherwise there would be no need for the additional Dequantize op at the model output. So definitely a good point.
