Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved generic instantiation behavior and added type precedence to … #1428

Merged
merged 8 commits into from
Nov 5, 2024
2 changes: 1 addition & 1 deletion Src/grammar/cql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,14 @@ expression
| expression 'properly'? 'between' expressionTerm 'and' expressionTerm #betweenExpression
| ('duration' 'in')? pluralDateTimePrecision 'between' expressionTerm 'and' expressionTerm #durationBetweenExpression
| 'difference' 'in' pluralDateTimePrecision 'between' expressionTerm 'and' expressionTerm #differenceBetweenExpression
| expression ('|' | 'union' | 'intersect' | 'except') expression #inFixSetExpression
| expression ('<=' | '<' | '>' | '>=') expression #inequalityExpression
| expression intervalOperatorPhrase expression #timingExpression
| expression ('=' | '!=' | '~' | '!~') expression #equalityExpression
| expression ('in' | 'contains') dateTimePrecisionSpecifier? expression #membershipExpression
| expression 'and' expression #andExpression
| expression ('or' | 'xor') expression #orExpression
| expression 'implies' expression #impliesExpression
| expression ('|' | 'union' | 'intersect' | 'except') expression #inFixSetExpression
;

dateTimePrecision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,20 +728,32 @@ public Object visitListSelector(cqlParser.ListSelectorContext ctx) {

DataType elementType = elementTypeSpecifier != null ? elementTypeSpecifier.getResultType() : null;
DataType inferredElementType = null;
DataType initialInferredElementType = null;

List<Expression> elements = new ArrayList<>();
for (cqlParser.ExpressionContext elementContext : ctx.expression()) {
Expression element = parseExpression(elementContext);

if (element == null) {
throw new RuntimeException("Element failed to parse");
}

if (elementType != null) {
libraryBuilder.verifyType(element.getResultType(), elementType);
} else {
if (inferredElementType == null) {
inferredElementType = element.getResultType();
if (initialInferredElementType == null) {
initialInferredElementType = element.getResultType();
inferredElementType = initialInferredElementType;
} else {
// Once a list type is inferred as Any, keep it that way
// The only potential exception to this is if the element responsible for the inferred type of Any
// is a null
DataType compatibleType =
libraryBuilder.findCompatibleType(inferredElementType, element.getResultType());
if (compatibleType != null) {
if (compatibleType != null
&& (!inferredElementType.equals(libraryBuilder.resolveTypeName("System", "Any"))
|| initialInferredElementType.equals(
libraryBuilder.resolveTypeName("System", "Any")))) {
inferredElementType = compatibleType;
} else {
inferredElementType = libraryBuilder.resolveTypeName("System", "Any");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1098,12 +1098,49 @@ public Invocation resolveProperContainsInvocation(
return resolveBinaryInvocation("System", "ProperContains", properContains);
}

private int getTypeScore(OperatorResolution resolution) {
int typeScore = ConversionMap.ConversionScore.ExactMatch.score();
for (DataType operand : resolution.getOperator().getSignature().getOperandTypes()) {
typeScore += ConversionMap.getTypePrecedenceScore(operand);
}

return typeScore;
}

private Expression lowestScoringInvocation(Invocation primary, Invocation secondary) {
if (primary != null) {
if (secondary != null) {
if (secondary.getResolution().getScore()
< primary.getResolution().getScore()) {
return secondary.getExpression();
} else if (primary.getResolution().getScore()
< secondary.getResolution().getScore()) {
return primary.getExpression();
}
if (primary.getResolution().getScore()
== secondary.getResolution().getScore()) {
int primaryTypeScore = getTypeScore(primary.getResolution());
int secondaryTypeScore = getTypeScore(secondary.getResolution());

if (secondaryTypeScore < primaryTypeScore) {
return secondary.getExpression();
} else if (primaryTypeScore < secondaryTypeScore) {
return primary.getExpression();
} else {
// ERROR:
StringBuilder message = new StringBuilder("Call to operator ")
.append(primary.getResolution().getOperator().getName())
.append("/")
.append(secondary.getResolution().getOperator().getName())
.append(" is ambiguous with: ")
.append("\n - ")
.append(primary.getResolution().getOperator().getName())
.append(primary.getResolution().getOperator().getSignature())
.append("\n - ")
.append(secondary.getResolution().getOperator().getName())
.append(secondary.getResolution().getOperator().getSignature());
throw new IllegalArgumentException(message.toString());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@
import org.hl7.cql.model.*;

public class ConversionMap {
public enum TypePrecedenceScore {
Simple(1),
Tuple(2),
Class(3),
Interval(4),
List(5),
Choice(6),
Other(7);

private final int score;

public int score() {
return score;
}

TypePrecedenceScore(int score) {
this.score = score;
}
}

public enum ConversionScore {
ExactMatch(0),
SubType(1),
Expand All @@ -30,6 +50,25 @@ public int score() {
}
}

public static int getTypePrecedenceScore(DataType operand) {
switch (operand.getClass().getSimpleName()) {
case "SimpleType":
return ConversionMap.TypePrecedenceScore.Simple.score();
case "TupleType":
return ConversionMap.TypePrecedenceScore.Tuple.score();
case "ClassType":
return ConversionMap.TypePrecedenceScore.Class.score();
case "IntervalType":
return ConversionMap.TypePrecedenceScore.Interval.score();
case "ListType":
return ConversionMap.TypePrecedenceScore.List.score();
case "ChoiceType":
return ConversionMap.TypePrecedenceScore.Choice.score();
default:
return ConversionMap.TypePrecedenceScore.Other.score();
}
}

public static int getConversionScore(DataType callOperand, DataType operand, Conversion conversion) {
if (operand.equals(callOperand)) {
return ConversionMap.ConversionScore.ExactMatch.score();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,44 +41,136 @@ public Signature getSignature() {
return operator.getSignature();
}

/*
The invocation signature is the call signature with arguments of type Any set to the operand types
*/
private Signature getInvocationSignature(Signature callSignature, Signature operatorSignature) {
if (callSignature.getSize() == operatorSignature.getSize()) {
DataType[] invocationTypes = new DataType[callSignature.getSize()];
Iterator<DataType> callTypes = callSignature.getOperandTypes().iterator();
Iterator<DataType> operatorTypes =
operatorSignature.getOperandTypes().iterator();
boolean isResolved = false;
for (int i = 0; i < invocationTypes.length; i++) {
DataType callType = callTypes.next();
DataType operatorType = operatorTypes.next();
if (callType.equals(DataType.ANY) && !operatorType.equals(DataType.ANY)) {
isResolved = true;
invocationTypes[i] = operatorType;
} else {
invocationTypes[i] = callType;
}
}
if (isResolved) {
return new Signature(invocationTypes);
}
}
return callSignature;
}

private OperatorResolution getOperatorResolution(
Operator operator,
Signature callSignature,
Signature invocationSignature,
ConversionMap conversionMap,
OperatorMap operatorMap,
boolean allowPromotionAndDemotion,
boolean requireConversions) {
Conversion[] conversions = getConversions(
callSignature, operator.getSignature(), conversionMap, operatorMap, allowPromotionAndDemotion);
OperatorResolution result = new OperatorResolution(operator, conversions);
if (requireConversions && conversions == null) {
return null;
}
return result;
}

public List<OperatorResolution> resolve(
CallContext callContext, ConversionMap conversionMap, OperatorMap operatorMap) {
List<OperatorResolution> results = null;
if (operator.getSignature().equals(callContext.getSignature())) {
results = new ArrayList<>();
results.add(new OperatorResolution(operator));
return results;
Signature invocationSignature = getInvocationSignature(callContext.getSignature(), operator.getSignature());

// Attempt exact match against this signature
if (operator.getSignature().equals(invocationSignature)) {
OperatorResolution result = getOperatorResolution(
operator,
callContext.getSignature(),
invocationSignature,
conversionMap,
operatorMap,
callContext.getAllowPromotionAndDemotion(),
false);
if (result != null) {
results = new ArrayList<>();
results.add(result);
return results;
}
}

// Attempt to resolve against sub signatures
results = subSignatures.resolve(callContext, conversionMap, operatorMap);
if (results == null && operator.getSignature().isSuperTypeOf(callContext.getSignature())) {
results = new ArrayList<>();
results.add(new OperatorResolution(operator));

// If no subsignatures match, attempt subType match against this signature
if (results == null && operator.getSignature().isSuperTypeOf(invocationSignature)) {
OperatorResolution result = getOperatorResolution(
operator,
callContext.getSignature(),
invocationSignature,
conversionMap,
operatorMap,
callContext.getAllowPromotionAndDemotion(),
false);
if (result != null) {
results = new ArrayList<>();
results.add(result);
return results;
}
}

if (results == null && conversionMap != null) {
// Attempt to find a conversion path from the call signature to the target signature
Conversion[] conversions =
new Conversion[operator.getSignature().getSize()];
boolean isConvertible = callContext
.getSignature()
.isConvertibleTo(
operator.getSignature(),
conversionMap,
operatorMap,
callContext.getAllowPromotionAndDemotion(),
conversions);
if (isConvertible) {
OperatorResolution resolution = new OperatorResolution(operator);
resolution.setConversions(conversions);
results = new ArrayList<>();
results.add(resolution);
OperatorResolution result = getOperatorResolution(
operator,
callContext.getSignature(),
invocationSignature,
conversionMap,
operatorMap,
callContext.getAllowPromotionAndDemotion(),
true);
if (result != null) {
if (results == null) {
results = new ArrayList<>();
}
results.add(result);
}
}

return results;
}

private Conversion[] getConversions(
Signature callSignature,
Signature operatorSignature,
ConversionMap conversionMap,
OperatorMap operatorMap,
boolean allowPromotionAndDemotion) {
if (callSignature == null
|| operatorSignature == null
|| callSignature.getSize() != operatorSignature.getSize()) {
return null;
}

Conversion[] conversions = new Conversion[callSignature.getSize()];
boolean isConvertible = callSignature.isConvertibleTo(
operatorSignature, conversionMap, operatorMap, allowPromotionAndDemotion, conversions);

if (isConvertible) {
return conversions;
}

return null;
}

private SignatureNodes subSignatures = new SignatureNodes();

public boolean hasSubSignatures() {
Expand Down Expand Up @@ -277,60 +369,42 @@ public List<OperatorResolution> resolve(
throw new IllegalArgumentException("callContext is null");
}

List<OperatorResolution> results = signatures.resolve(callContext, conversionMap, operatorMap);

// If there is no resolution, or all resolutions require conversion, attempt to instantiate a generic signature
if (results == null || allResultsUseConversion(results)) {
// If the callContext signature contains choices, attempt instantiation with all possible combinations of
// the call signature (ouch, this could really hurt...)
boolean signaturesInstantiated = false;
List<Signature> callSignatures = expandChoices(callContext.getSignature());
for (Signature callSignature : callSignatures) {
Operator result = instantiate(
callSignature, operatorMap, conversionMap, callContext.getAllowPromotionAndDemotion());
if (result != null && !signatures.contains(result)) {
// If the generic signature was instantiated, store it as an actual signature.
signatures.add(new SignatureNode(result));
signaturesInstantiated = true;
// Attempt to instantiate any generic signatures
// If the callContext signature contains choices, attempt instantiation with all possible combinations of
// the call signature (ouch, this could really hurt...)
boolean signaturesInstantiated = false;
List<Signature> callSignatures = expandChoices(callContext.getSignature());
for (Signature callSignature : callSignatures) {
List<Operator> instantiations =
instantiate(callSignature, operatorMap, conversionMap, callContext.getAllowPromotionAndDemotion());
for (Operator instantiation : instantiations) {
// If the generic signature was instantiated, store it as an actual signature.
if (!signatures.contains(instantiation)) {
signatures.add(new SignatureNode(instantiation));
}
}

// re-attempt the resolution with the instantiated signature registered
if (signaturesInstantiated) {
results = signatures.resolve(callContext, conversionMap, operatorMap);
}
}

List<OperatorResolution> results = signatures.resolve(callContext, conversionMap, operatorMap);

return results;
}

private Operator instantiate(
private List<Operator> instantiate(
Signature signature,
OperatorMap operatorMap,
ConversionMap conversionMap,
boolean allowPromotionAndDemotion) {
List<Operator> instantiations = new ArrayList<Operator>();
int lowestConversionScore = Integer.MAX_VALUE;
Operator instantiation = null;

for (GenericOperator genericOperator : genericOperators.values()) {
InstantiationResult instantiationResult =
genericOperator.instantiate(signature, operatorMap, conversionMap, allowPromotionAndDemotion);
if (instantiationResult.getOperator() != null) {
if (instantiationResult.getConversionScore() <= lowestConversionScore) {
if (instantiation == null || instantiationResult.getConversionScore() < lowestConversionScore) {
instantiation = instantiationResult.getOperator();
lowestConversionScore = instantiationResult.getConversionScore();
} else {
throw new IllegalArgumentException(String.format(
"Ambiguous generic instantiation of operator %s between signature %s and %s.",
this.name,
instantiation.getSignature().toString(),
instantiationResult.getOperator().getSignature().toString()));
}
}
instantiations.add(instantiationResult.getOperator());
}
}

return instantiation;
return instantiations;
}
}
Loading
Loading