Skip to content

Commit

Permalink
factorize the code
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbrauner-da committed Nov 13, 2024
1 parent 5800a5e commit c95b04e
Showing 1 changed file with 31 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2394,6 +2394,7 @@ private[lf] object SBuiltin {
coid: V.ContractId,
)(f: (Option[Ref.PackageName], SValue) => Control[Question.Update]): Control[Question.Update] = {

// Checks that the metadata of [original] and [recomputed] are the same, fails with a [Control.Error] if not.
def checkContractUpgradable(
original: ContractInfo,
recomputed: ContractInfo,
Expand Down Expand Up @@ -2433,14 +2434,29 @@ private[lf] object SBuiltin {
)
}

def importLocalContract(contract: ContractInfo) = {
val coinstArg = contract.arg
val srcTmplId = contract.templateId
type ContractInfoChecker =
ContractInfo => (() => Control[Question.Update]) => Control[Question.Update]

def localContractInfoUpgradeChecker(
originalContractInfo: ContractInfo
): ContractInfoChecker =
(upgradedContractInfo: ContractInfo) =>
checkContractUpgradable(originalContractInfo, upgradedContractInfo)

def globalContractInfoUpgradeChecker(
coid: V.ContractId,
srcTemplateId: Ref.Identifier,
): ContractInfoChecker =
(upgradedContractInfo: ContractInfo) =>
validateContractInfo(machine, coid, srcTemplateId, upgradedContractInfo)

def importContract(
coinst: V.ContractInstance,
contractInfoUpgradeChecker: ContractInfoChecker,
) = {
val V.ContractInstance(_, srcTmplId, coinstArg) = coinst
val (upgradingIsEnabled, dstTmplId) = optTargetTemplateId match {
case Some(tycon)
if V
.ContractInstance(contract.packageName, contract.templateId, contract.arg)
.upgradable =>
case Some(tycon) if coinst.upgradable =>
(true, tycon)
case _ =>
(false, srcTmplId) // upgrading not enabled; import at source type
Expand Down Expand Up @@ -2478,8 +2494,7 @@ private[lf] object SBuiltin {
upgradingIsEnabled && (srcTmplId.packageId != dstTmplId.packageId)
}
if (needValidationCall) {

checkContractUpgradable(contract, upgradedContract) { () =>
contractInfoUpgradeChecker(upgradedContract) { () =>
f(upgradedContract.packageName, upgradedContract.any)
}
} else {
Expand All @@ -2491,60 +2506,6 @@ private[lf] object SBuiltin {
}
}

def importGlobalContract(coinst: V.ContractInstance) = {
val V.ContractInstance(_, srcTmplId, coinstArg) = coinst
val (upgradingIsEnabled, dstTmplId) = optTargetTemplateId match {
case Some(tycon) if coinst.upgradable =>
(true, tycon)
case _ =>
(false, srcTmplId) // upgrading not enabled; import at source type
}
if (srcTmplId.qualifiedName != dstTmplId.qualifiedName) {
Control.Error(
IE.WronglyTypedContract(coid, dstTmplId, srcTmplId)
)
} else
machine.ensurePackageIsLoaded(
dstTmplId.packageId,
language.Reference.Template(dstTmplId),
) { () =>
importValue(machine, dstTmplId, coinstArg) { templateArg =>
getContractInfo(
machine,
coid,
dstTmplId,
templateArg,
allowCatchingContractInfoErrors = false,
) { contract =>
ensureContractActive(machine, coid, contract.templateId) {

machine.checkContractVisibility(coid, contract)
machine.enforceLimitAddInputContract()
machine.enforceLimitSignatoriesAndObservers(coid, contract)

// In Validation mode, we always call validateContractInfo
// In Submission mode, we only call validateContractInfo when src != dest
val needValidationCall: Boolean =
if (machine.validating) {
upgradingIsEnabled
} else {
// we already check qualified names match
upgradingIsEnabled && (srcTmplId.packageId != dstTmplId.packageId)
}
if (needValidationCall) {

validateContractInfo(machine, coid, srcTmplId, contract) { () =>
f(contract.packageName, contract.any)
}
} else {
f(contract.packageName, contract.any)
}
}
}
}
}
}

machine.getLocalContract(coid) match {
case Some((templateId, templateArg)) =>
ensureContractActive(machine, coid, templateId) {
Expand All @@ -2560,12 +2521,17 @@ private[lf] object SBuiltin {
// import its value and validate its contract info again.
f(contract.packageName, SValue.SAnyContract(templateId, templateArg))
} else {
importLocalContract(contract)
importContract(
V.ContractInstance(contract.packageName, templateId, contract.arg),
localContractInfoUpgradeChecker(contract),
)
}
}
}
case None =>
machine.lookupGlobalContract(coid)(importGlobalContract)
machine.lookupGlobalContract(coid)(coinst =>
importContract(coinst, globalContractInfoUpgradeChecker(coid, coinst.template))
)
}
}

Expand Down

0 comments on commit c95b04e

Please sign in to comment.