Adding a control dependency on the gradients to the gradient optimizers.
This improves determinism and makes the gradients compute correctly, though the reason for the latter is not yet clear. (#520)

Co-authored-by: Nicolas Feybesse ([email protected])
Craigacp authored and karllessard committed Feb 23, 2024
1 parent af6ee7c commit e204dbe
Showing 12 changed files with 90 additions and 82 deletions.
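
To illustrate the mechanism: the base optimizer now collects the gradient outputs, wraps them in an Ops instance created with withControlDependencies, and builds every variable update through that instance, so no update can run before all gradients have been computed. The standalone sketch below is not part of the commit; it applies the same pattern to a single variable and assumes TF-Java's graph-mode gradient API (Graph.addGradients) and the raw ApplyGradientDescent op. The class name and the x^2 loss are illustrative only.

import java.util.Arrays;

import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.TFloat32;

public class ControlDependencySketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // A scalar variable and a trivial loss = x^2.
      Variable<TFloat32> x = tf.variable(Shape.scalar(), TFloat32.class);
      Op init = tf.assign(x, tf.constant(3.0f));
      Output<TFloat32> loss = tf.math.square(x).asOutput();

      // Symbolic gradient d(loss)/dx added to the graph.
      Output<?>[] grads = graph.addGradients(loss, new Output<?>[] {x.asOutput()});

      // The commit's pattern: an Ops instance whose operations carry a control
      // dependency on every gradient. Update ops built from it cannot run
      // before the gradients have been computed.
      Ops deps = tf.withControlDependencies(Arrays.<Op>asList(grads[0]));

      @SuppressWarnings("unchecked")
      Output<TFloat32> grad = (Output<TFloat32>) grads[0];
      Op update =
          deps.train.applyGradientDescent(
              x,
              deps.constant(0.1f),
              grad,
              ApplyGradientDescent.useLocking(true));
    }
  }
}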
AdaDelta.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdadelta;
import org.tensorflow.types.family.TType;
@@ -150,16 +151,16 @@ private <T extends TType> void createAdaDeltaSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get();
-return tf.train.applyAdadelta(
+return deps.train.applyAdadelta(
variable,
accumSlot,
accumUpdateSlot,
-tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
-tf.dtypes.cast(tf.constant(rho), gradient.type()),
-tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
+deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
+deps.dtypes.cast(deps.constant(rho), gradient.type()),
+deps.dtypes.cast(deps.constant(epsilon), gradient.type()),
gradient,
ApplyAdadelta.useLocking(true));
}
AdaGrad.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdagrad;
import org.tensorflow.types.family.TType;
@@ -140,10 +141,10 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
-return tf.train.applyAdagrad(
-variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, opts);
+return deps.train.applyAdagrad(
+variable, slot, deps.dtypes.cast(deps.constant(learningRate), gradient.type()), gradient, opts);
}

/** {@inheritDoc} */
AdaGradDA.java
@@ -22,6 +22,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdagradDa;
import org.tensorflow.types.TInt64;
@@ -209,17 +210,17 @@ private <T extends TType> void createAdaGradDASlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> gradSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get();
-return tf.train.applyAdagradDa(
+return deps.train.applyAdagradDa(
variable,
gradSlot,
gradSquaredSlot,
gradient,
-tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
-tf.dtypes.cast(tf.constant(l1Strength), gradient.type()),
-tf.dtypes.cast(tf.constant(l2Strength), gradient.type()),
+deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
+deps.dtypes.cast(deps.constant(l1Strength), gradient.type()),
+deps.dtypes.cast(deps.constant(l2Strength), gradient.type()),
globalStep,
ApplyAdagradDa.useLocking(true));
}
Adam.java
@@ -22,6 +22,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
@@ -223,19 +224,19 @@ private <T extends TType> void createAdamSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
-return tf.train.applyAdam(
+return deps.train.applyAdam(
variable,
firstMomentSlot,
secondMomentSlot,
-tf.dtypes.cast(betaOnePower, gradient.type()),
-tf.dtypes.cast(betaTwoPower, gradient.type()),
-tf.dtypes.cast(learningRateConst, gradient.type()),
-tf.dtypes.cast(betaOneConst, gradient.type()),
-tf.dtypes.cast(betaTwoConst, gradient.type()),
-tf.dtypes.cast(epsilonConst, gradient.type()),
+deps.dtypes.cast(betaOnePower, gradient.type()),
+deps.dtypes.cast(betaTwoPower, gradient.type()),
+deps.dtypes.cast(learningRateConst, gradient.type()),
+deps.dtypes.cast(betaOneConst, gradient.type()),
+deps.dtypes.cast(betaTwoConst, gradient.type()),
+deps.dtypes.cast(epsilonConst, gradient.type()),
gradient,
ApplyAdam.useLocking(true));
}
Adamax.java
@@ -7,6 +7,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdaMax;
@@ -155,19 +156,19 @@ private <T extends TType> void createAdamaxSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
return ApplyAdaMax.create(
-this.tf.scope(),
+deps.scope(),
variable,
firstMomentSlot,
secondMomentSlot,
-tf.dtypes.cast(betaOnePower, gradient.type()),
-tf.dtypes.cast(learningRateConst, gradient.type()),
-tf.dtypes.cast(betaOneConst, gradient.type()),
-tf.dtypes.cast(betaTwoConst, gradient.type()),
-tf.dtypes.cast(epsilonConst, gradient.type()),
+deps.dtypes.cast(betaOnePower, gradient.type()),
+deps.dtypes.cast(learningRateConst, gradient.type()),
+deps.dtypes.cast(betaOneConst, gradient.type()),
+deps.dtypes.cast(betaTwoConst, gradient.type()),
+deps.dtypes.cast(epsilonConst, gradient.type()),
gradient,
ApplyAdaMax.useLocking(true));
}
Ftrl.java
@@ -5,6 +5,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyFtrl;
import org.tensorflow.types.family.TType;
@@ -238,21 +239,21 @@ private <T extends TType> void createFtrlSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> linearSlot = getSlot(variable, LINEAR_ACCUMULATOR).get();
ApplyFtrl.Options options = ApplyFtrl.useLocking(true);
-return this.tf.train.applyFtrl(
+return deps.train.applyFtrl(
variable,
accumSlot, // accum
linearSlot, // linear
gradient, // gradient
-tf.dtypes.cast(tf.constant(learningRate), gradient.type()), // lr
-tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.type()), // l1
-tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.type()), // l2
-tf.dtypes.cast(
-tf.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
-tf.dtypes.cast(tf.constant(learningRatePower), gradient.type()), // lrPower
+deps.dtypes.cast(deps.constant(learningRate), gradient.type()), // lr
+deps.dtypes.cast(deps.constant(l1RegularizationStrength), gradient.type()), // l1
+deps.dtypes.cast(deps.constant(l2RegularizationStrength), gradient.type()), // l2
+deps.dtypes.cast(
+deps.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
+deps.dtypes.cast(deps.constant(learningRatePower), gradient.type()), // lrPower
options);
}

GradientDescent.java
@@ -18,6 +18,7 @@
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.family.TType;

@@ -65,10 +66,10 @@ public GradientDescent(Graph graph, String name, float learningRate) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
-return tf.train.applyGradientDescent(
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
+return deps.train.applyGradientDescent(
variable,
-tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
+deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
gradient,
ApplyGradientDescent.useLocking(true));
}
Momentum.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyMomentum;
import org.tensorflow.types.family.TType;
@@ -130,14 +131,14 @@ private <T extends TType> void createMomentumSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> slot = getSlot(variable, MOMENTUM).get();
-return tf.train.applyMomentum(
+return deps.train.applyMomentum(
variable,
slot,
-tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
+deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
gradient,
-tf.dtypes.cast(tf.constant(momentum), gradient.type()),
+deps.dtypes.cast(deps.constant(momentum), gradient.type()),
ApplyMomentum.useNesterov(useNesterov),
ApplyMomentum.useLocking(true));
}
Nadam.java
@@ -7,6 +7,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
@@ -224,53 +225,53 @@ protected Optional<Op> prepare(String scopeName) {

/** {@inheritDoc} */
@Override
-protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
+protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Class<T> type = gradient.type();
Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment
Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment

// gPrime = grad / coefficients['oneMinusMScheduleNew']
-Operand<T> gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type));
+Operand<T> gPrime = deps.math.div(gradient, deps.dtypes.cast(oneMinusMScheduleNew, type));
// mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad)
Operand<T> mT =
-tf.math.add(
-tf.math.mul(tf.dtypes.cast(betaOneConst, type), m),
-tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient));
+deps.math.add(
+deps.math.mul(deps.dtypes.cast(betaOneConst, type), m),
+deps.math.mul(deps.dtypes.cast(oneMinusBeta1, type), gradient));
// mT = state_ops.assign(m, mT, use_locking=self._use_locking)
// update m
-mT = tf.assign(m, mT, Assign.useLocking(true));
+mT = deps.assign(m, mT, Assign.useLocking(true));

// mTPrime = mT / coefficients['oneMinusMScheduleNext']
-Operand<T> mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type));
+Operand<T> mTPrime = deps.math.div(mT, deps.dtypes.cast(oneMinusMScheduleNext, type));

// vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] *
// math_ops.square(grad))
Operand<T> vT =
-tf.math.add(
-tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v),
-tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient)));
+deps.math.add(
+deps.math.mul(deps.dtypes.cast(betaTwoConst, type), v),
+deps.math.mul(deps.dtypes.cast(oneMinusBeta2, type), deps.math.square(gradient)));
// vT = state_ops.assign(v, vT, use_locking=self._use_locking)
// update v
-vT = tf.assign(v, vT, Assign.useLocking(true));
+vT = deps.assign(v, vT, Assign.useLocking(true));

// vTPrime = vT / coefficients['vTPrimeDenominator']
-Operand<T> vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type));
+Operand<T> vTPrime = deps.math.div(vT, deps.dtypes.cast(vTPrimeDenominator, type));

// m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime)
Operand<T> m_t_bar =
-tf.math.add(
-tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime),
-tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime));
+deps.math.add(
+deps.math.mul(deps.dtypes.cast(oneMinusMT, type), gPrime),
+deps.math.mul(deps.dtypes.cast(mT1, type), mTPrime));
// varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) +
// coefficients['epsilon'])
Operand<T> varT =
-tf.math.sub(
+deps.math.sub(
variable,
-tf.math.div(
-tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar),
-tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type))));
+deps.math.div(
+deps.math.mul(deps.dtypes.cast(learningRateConst, type), m_t_bar),
+deps.math.add(deps.math.sqrt(vTPrime), deps.dtypes.cast(epsilonConst, type))));

-return tf.assign(variable, varT, Assign.useLocking(true));
+return deps.assign(variable, varT, Assign.useLocking(true));
}

/**
Optimizer.java
@@ -168,14 +168,16 @@ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String
gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList());

createSlots(variables);
+List<Op> gradients = gradsAndVars.stream().map(GradAndVar::getGradient).filter(g -> !g.isClosed()).collect(Collectors.toList());
+Ops tfOpsGrads = tf.withControlDependencies(gradients);

Optional<Op> prepOp = prepare(name + "/prepare");

List<Op> updateOps = new ArrayList<>();
prepOp.ifPresent(updateOps::add);
for (GradAndVar<? extends TType> pair : gradsAndVars) {
if (!pair.gradient.isClosed()) {
-updateOps.add(applyDense(pair));
+updateOps.add(applyDense(tfOpsGrads, pair));
}
}

@@ -261,8 +263,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {}
* @param <T> the datatype of the gradients and variables.
* @return An operand which applies the desired optimizer update to the variable.
*/
-private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
-return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable());
+private <T extends TType> Op applyDense(Ops opDependencies, GradAndVar<T> gradVarPair) {
+return applyDense(opDependencies, gradVarPair.getGradient(), gradVarPair.getVariable());
}

/**
@@ -273,7 +275,7 @@ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
* @param <T> The type of the variable.
* @return An operand which applies the desired optimizer update to the variable.
*/
-protected abstract <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable);
+protected abstract <T extends TType> Op applyDense(Ops opDependencies, Output<T> gradient, Output<T> variable);

/**
* Gathers up the update operations into a single op that can be used as a run target.
(Diffs for the remaining 2 changed files are not shown.)
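
For context on how these optimizers are driven end to end, here is a minimal usage sketch; it is not part of the commit and assumes the org.tensorflow.framework.optimizers package plus the GradientDescent(Graph, String, float) constructor visible in the hunk above. minimize builds the gradients, creates any slots, and, with this change, constructs the per-variable updates under a control dependency on those gradients. The class name and the x^2 loss are illustrative.

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

public class OptimizerUsageSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // One scalar variable and the loss x^2.
      Variable<TFloat32> x = tf.variable(Shape.scalar(), TFloat32.class);
      Op init = tf.assign(x, tf.constant(5.0f));
      Operand<TFloat32> loss = tf.math.square(x);

      // minimize() computes the gradients, creates any slots, and applies the
      // per-variable updates; with this commit those updates are built under a
      // control dependency on the gradients.
      Optimizer optimizer = new GradientDescent(graph, "sgd", 0.1f);
      Op train = optimizer.minimize(loss);

      try (Session session = new Session(graph)) {
        session.runner().addTarget(init).run();
        for (int step = 0; step < 10; step++) {
          session.runner().addTarget(train).run();
        }
      }
    }
  }
}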
