diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index f7bb5372..667925c8 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -25,7 +25,7 @@ jobs: run: | brew install pkg-config glfw3 python3 cmark fmt coreutils make brew link --overwrite python3 - python3 -m pip install --upgrade Pillow + brew install pillow - name: make run: | gmake -j$(nproc) BUILDTYPE=osx diff --git a/Makefile b/Makefile index 42c2f5ed..dd2f4a66 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ ifeq ($(BUILDTYPE),windows) EXT = .exe else export CC = gcc - export CXX = g++ + export CXX = g++ -std=c++20 export CFLAGS export CXXFLAGS export LDFLAGS @@ -77,4 +77,4 @@ bin/wordgrinder-minimal-dependencies-for-debian.tar.xz: wordgrinder.man \ xwordgrinder.man -include build/ab.mk \ No newline at end of file +include build/ab.mk diff --git a/src/c/filesystem.cc b/src/c/filesystem.cc index 0b210696..f2e0b76f 100644 --- a/src/c/filesystem.cc +++ b/src/c/filesystem.cc @@ -280,7 +280,7 @@ static int readfile_cb(lua_State* L) if (i < 0) goto error; - luaL_addlstring(&buffer, b, i, -1); + luaL_addlstring(&buffer, b, i); } fclose(fp); diff --git a/third_party/luau/Analysis/include/Luau/AnyTypeSummary.h b/third_party/luau/Analysis/include/Luau/AnyTypeSummary.h new file mode 100644 index 00000000..73d6f851 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/AnyTypeSummary.h @@ -0,0 +1,147 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AstQuery.h" +#include "Luau/Config.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Scope.h" +#include "Luau/Variant.h" +#include "Luau/Normalize.h" +#include "Luau/TypePack.h" +#include "Luau/TypeArena.h" + +#include +#include +#include +#include + +namespace Luau +{ + +class AstStat; +class ParseError; +struct TypeError; +struct LintWarning; +struct GlobalTypes; +struct ModuleResolver; +struct ParseResult; +struct DcrLogger; + +struct TelemetryTypePair +{ + std::string annotatedType; + std::string inferredType; +}; + +struct AnyTypeSummary +{ + TypeArena arena; + + AstStatBlock* rootSrc = nullptr; + DenseHashSet seenTypeFamilyInstances{nullptr}; + + int recursionCount = 0; + + std::string root; + int strictCount = 0; + + DenseHashMap seen{nullptr}; + + AnyTypeSummary(); + + void traverse(const Module* module, AstStat* src, NotNull builtinTypes); + + std::pair checkForAnyCast(const Scope* scope, AstExprTypeAssertion* expr); + + bool containsAny(TypePackId typ); + bool containsAny(TypeId typ); + + bool isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes); + bool isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes); + + bool hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + bool hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + bool hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes); + + TypeId checkForFamilyInhabitance(const TypeId instance, Location location); + TypeId lookupType(const AstExpr* expr, const Module* module, NotNull builtinTypes); + TypePackId reconstructTypePack(const AstArray exprs, const Module* module, NotNull builtinTypes); + + DenseHashSet seenTypeFunctionInstances{nullptr}; + TypeId lookupAnnotation(AstType* annotation, const Module* module, NotNull builtintypes); + std::optional lookupPackAnnotation(AstTypePack* annotation, const Module* module); + TypeId checkForTypeFunctionInhabitance(const TypeId instance, const Location location); + + enum Pattern: uint64_t + { + Casts, + FuncArg, + FuncRet, + FuncApp, + VarAnnot, + VarAny, + TableProp, + Alias, + Assign + }; + + struct TypeInfo + { + Pattern code; + std::string node; + TelemetryTypePair type; + + explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type); + }; + + struct FindReturnAncestry final : public AstVisitor + { + AstNode* currNode{nullptr}; + AstNode* stat{nullptr}; + Position rootEnd; + bool found = false; + + explicit FindReturnAncestry(AstNode* stat, Position rootEnd); + + bool visit(AstType* node) override; + bool visit(AstNode* node) override; + bool visit(AstStatFunction* node) override; + bool visit(AstStatLocalFunction* node) override; + }; + + std::vector typeInfo; + + /** + * Fabricates a scope that is a child of another scope. + * @param node the lexical node that the scope belongs to. + * @param parent the parent scope of the new scope. Must not be null. + */ + const Scope* childScope(const AstNode* node, const Scope* parent); + + std::optional matchRequire(const AstExprCall& call); + AstNode* getNode(AstStatBlock* root, AstNode* node); + const Scope* findInnerMostScope(const Location location, const Module* module); + const AstNode* findAstAncestryAtLocation(const AstStatBlock* root, AstNode* node); + + void visit(const Scope* scope, AstStat* stat, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull builtinTypes); + void visit(const Scope* scope, AstStatError* error, const Module* module, NotNull builtinTypes); +}; + +} // namespace Luau \ No newline at end of file diff --git a/third_party/luau/Analysis/include/Luau/Anyification.h b/third_party/luau/Analysis/include/Luau/Anyification.h index 7b6f7171..4b9c8ee9 100644 --- a/third_party/luau/Analysis/include/Luau/Anyification.h +++ b/third_party/luau/Analysis/include/Luau/Anyification.h @@ -4,7 +4,7 @@ #include "Luau/NotNull.h" #include "Luau/Substitution.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include @@ -19,10 +19,22 @@ using ScopePtr = std::shared_ptr; // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, - TypePackId anyTypePack); - Anyification(TypeArena* arena, const ScopePtr& scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, - TypePackId anyTypePack); + Anyification( + TypeArena* arena, + NotNull scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack + ); + Anyification( + TypeArena* arena, + const ScopePtr& scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack + ); NotNull scope; NotNull builtinTypes; InternalErrorReporter* iceHandler; @@ -39,4 +51,4 @@ struct Anyification : Substitution bool ignoreChildren(TypePackId ty) override; }; -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/ApplyTypeFunction.h b/third_party/luau/Analysis/include/Luau/ApplyTypeFunction.h index 3f5f47fd..71430b28 100644 --- a/third_party/luau/Analysis/include/Luau/ApplyTypeFunction.h +++ b/third_party/luau/Analysis/include/Luau/ApplyTypeFunction.h @@ -3,7 +3,7 @@ #include "Luau/Substitution.h" #include "Luau/TxnLog.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" namespace Luau { diff --git a/third_party/luau/Analysis/include/Luau/AstQuery.h b/third_party/luau/Analysis/include/Luau/AstQuery.h index e7a018c0..633d6faf 100644 --- a/third_party/luau/Analysis/include/Luau/AstQuery.h +++ b/third_party/luau/Analysis/include/Luau/AstQuery.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" #include "Luau/Documentation.h" +#include "Luau/TypeFwd.h" #include @@ -13,9 +14,6 @@ struct Binding; struct SourceModule; struct Module; -struct Type; -using TypeId = const Type*; - using ScopePtr = std::shared_ptr; struct ExprOrLocal @@ -63,6 +61,22 @@ struct ExprOrLocal AstLocal* local = nullptr; }; +struct FindFullAncestry final : public AstVisitor +{ + std::vector nodes; + Position pos; + Position documentEnd; + bool includeTypes = false; + + explicit FindFullAncestry(Position pos, Position documentEnd, bool includeTypes = false); + + bool visit(AstType* type) override; + + bool visit(AstStatFunction* node) override; + + bool visit(AstNode* node) override; +}; + std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); std::vector findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes = false); diff --git a/third_party/luau/Analysis/include/Luau/Autocomplete.h b/third_party/luau/Analysis/include/Luau/Autocomplete.h index 61832577..bc709c7f 100644 --- a/third_party/luau/Analysis/include/Luau/Autocomplete.h +++ b/third_party/luau/Analysis/include/Luau/Autocomplete.h @@ -38,6 +38,7 @@ enum class AutocompleteEntryKind String, Type, Module, + GeneratedFunction, }; enum class ParenthesesRecommendation @@ -70,6 +71,10 @@ struct AutocompleteEntry std::optional documentationSymbol = std::nullopt; Tags tags; ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; }; using AutocompleteEntryMap = std::unordered_map; @@ -94,4 +99,6 @@ using StringCompletionCallback = AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Breadcrumb.h b/third_party/luau/Analysis/include/Luau/Breadcrumb.h deleted file mode 100644 index 59b293a0..00000000 --- a/third_party/luau/Analysis/include/Luau/Breadcrumb.h +++ /dev/null @@ -1,75 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Def.h" -#include "Luau/NotNull.h" -#include "Luau/Variant.h" - -#include -#include - -namespace Luau -{ - -using NullableBreadcrumbId = const struct Breadcrumb*; -using BreadcrumbId = NotNull; - -struct FieldMetadata -{ - std::string prop; -}; - -struct SubscriptMetadata -{ - BreadcrumbId key; -}; - -using Metadata = Variant; - -struct Breadcrumb -{ - NullableBreadcrumbId previous; - DefId def; - std::optional metadata; - std::vector children; -}; - -inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb) -{ - LUAU_ASSERT(breadcrumb); - return const_cast(breadcrumb); -} - -template -const T* getMetadata(NullableBreadcrumbId breadcrumb) -{ - if (!breadcrumb || !breadcrumb->metadata) - return nullptr; - - return get_if(&*breadcrumb->metadata); -} - -struct BreadcrumbArena -{ - TypedAllocator allocator; - - template - BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args) - { - Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward(args)...}); - if (previous) - asMutable(previous)->children.push_back(NotNull{bc}); - return NotNull{bc}; - } - - template - BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args) - { - Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward(args)...}}}); - if (previous) - asMutable(previous)->children.push_back(NotNull{bc}); - return NotNull{bc}; - } -}; - -} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/BuiltinDefinitions.h b/third_party/luau/Analysis/include/Luau/BuiltinDefinitions.h index d4457638..94b0d87f 100644 --- a/third_party/luau/Analysis/include/Luau/BuiltinDefinitions.h +++ b/third_party/luau/Analysis/include/Luau/BuiltinDefinitions.h @@ -14,8 +14,6 @@ struct GlobalTypes; struct TypeChecker; struct TypeArena; -void registerBuiltinTypes(GlobalTypes& globals); - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); @@ -27,19 +25,42 @@ TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t /** Small utility function for building up type definitions from C++. */ TypeId makeFunction( // Monomorphic - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Polymorphic - TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list paramTypes, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Monomorphic - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked = false +); TypeId makeFunction( // Polymorphic - TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked = false +); void attachMagicFunction(TypeId ty, MagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); diff --git a/third_party/luau/Analysis/include/Luau/Cancellation.h b/third_party/luau/Analysis/include/Luau/Cancellation.h new file mode 100644 index 00000000..44131863 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Cancellation.h @@ -0,0 +1,24 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct FrontendCancellationToken +{ + void cancel() + { + cancelled.store(true); + } + + bool requested() + { + return cancelled.load(); + } + + std::atomic cancelled; +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Clone.h b/third_party/luau/Analysis/include/Luau/Clone.h index b3cbe467..103b5bbd 100644 --- a/third_party/luau/Analysis/include/Luau/Clone.h +++ b/third_party/luau/Analysis/include/Luau/Clone.h @@ -16,10 +16,10 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + NotNull builtinTypes; + SeenTypes seenTypes; SeenTypePacks seenTypePacks; - - int recursionCount = 0; }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/third_party/luau/Analysis/include/Luau/Constraint.h b/third_party/luau/Analysis/include/Luau/Constraint.h index 2223c29e..61253732 100644 --- a/third_party/luau/Analysis/include/Luau/Constraint.h +++ b/third_party/luau/Analysis/include/Luau/Constraint.h @@ -4,8 +4,8 @@ #include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/DenseHash.h" #include "Luau/NotNull.h" -#include "Luau/Type.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" #include #include @@ -14,13 +14,15 @@ namespace Luau { +enum class ValueContext; struct Scope; -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; +// if resultType is a freeType, assignmentType <: freeType <: resultType bounds +struct EqualityConstraint +{ + TypeId resultType; + TypeId assignmentType; +}; // subType <: superType struct SubtypeConstraint @@ -34,6 +36,11 @@ struct PackSubtypeConstraint { TypePackId subPack; TypePackId superPack; + + // HACK!! TODO clip. + // We need to know which of `PackSubtypeConstraint` are emitted from `AstStatReturn` vs any others. + // Then we force these specific `PackSubtypeConstraint` to only dispatch in the order of the `return`s. + bool returns = false; }; // generalizedType ~ gen sourceType @@ -41,46 +48,19 @@ struct GeneralizationConstraint { TypeId generalizedType; TypeId sourceType; -}; - -// subType ~ inst superType -struct InstantiationConstraint -{ - TypeId subType; - TypeId superType; -}; -struct UnaryConstraint -{ - AstExprUnary::Op op; - TypeId operandType; - TypeId resultType; + std::vector interiorTypes; }; -// let L : leftType -// let R : rightType -// in -// L op R : resultType -struct BinaryConstraint -{ - AstExprBinary::Op op; - TypeId leftType; - TypeId rightType; - TypeId resultType; - - // When we dispatch this constraint, we update the key at this map to record - // the overload that we selected. - const AstNode* astFragment; - DenseHashMap* astOriginalCallTypes; - DenseHashMap* astOverloadResolvedTypes; -}; - -// iteratee is iterable -// iterators is the iteration types. +// variables ~ iterate iterator +// Unpack the iterator, figure out what types it iterates over, and bind those types to variables. struct IterableConstraint { TypePackId iterator; - TypePackId variables; + std::vector variables; + + const AstNode* nextAstFragment; + DenseHashMap* astForInNextTypes; }; // name(namedType) = name @@ -105,22 +85,52 @@ struct FunctionCallConstraint TypeId fn; TypePackId argsPack; TypePackId result; - class AstExprCall* callSite; + class AstExprCall* callSite = nullptr; std::vector> discriminantTypes; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + DenseHashMap* astOverloadResolvedTypes = nullptr; +}; + +// function_check fn argsPack +// +// If fn is a function type and argsPack is a partially solved +// pack of arguments to be supplied to the function, propagate the argument +// types of fn into the types of argsPack. This is used to implement +// bidirectional inference of lambda arguments. +struct FunctionCheckConstraint +{ + TypeId fn; + TypePackId argsPack; + + class AstExprCall* callSite = nullptr; + NotNull> astTypes; + NotNull> astExpectedTypes; }; -// result ~ prim ExpectedType SomeSingletonType MultitonType +// prim FreeType ExpectedType PrimitiveType // -// If ExpectedType is potentially a singleton (an actual singleton or a union -// that contains a singleton), then result ~ SomeSingletonType +// FreeType is bounded below by the singleton type and above by PrimitiveType +// initially. When this constraint is resolved, it will check that the bounds +// of the free type are well-formed by subtyping. // -// else result ~ MultitonType +// If they are not well-formed, then FreeType is replaced by its lower bound +// +// If they are well-formed and ExpectedType is potentially a singleton (an +// actual singleton or a union that contains a singleton), +// then FreeType is replaced by its lower bound +// +// else FreeType is replaced by PrimitiveType struct PrimitiveTypeConstraint { - TypeId resultType; - TypeId expectedType; - TypeId singletonType; - TypeId multitonType; + TypeId freeType; + + // potentially gets used to force the lower bound? + std::optional expectedType; + + // the primitive type to check against + TypeId primitiveType; }; // result ~ hasProp type "prop_name" @@ -139,63 +149,131 @@ struct HasPropConstraint TypeId resultType; TypeId subjectType; std::string prop; + ValueContext context; + + // We want to track if this `HasPropConstraint` comes from a conditional. + // If it does, we're going to change the behavior of property look-up a bit. + // In particular, we're going to return `unknownType` for property lookups + // on `table` or inexact table types where the property is not present. + // + // This allows us to refine table types to have additional properties + // without reporting errors in typechecking on the property tests. + bool inConditional = false; + + // HACK: We presently need types like true|false or string|"hello" when + // deciding whether a particular literal expression should have a singleton + // type. This boolean is set to true when extracting the property type of a + // value that may be a union of tables. + // + // For example, in the following code fragment, we want the lookup of the + // success property to yield true|false when extracting an expectedType in + // this expression: + // + // type Result = {success:true, result: T} | {success:false, error: E} + // + // local r: Result = {success=true, result=9} + // + // If we naively simplify the expectedType to boolean, we will erroneously + // compute the type boolean for the success property of the table literal. + // This causes type checking to fail. + bool suppressSimplification = false; }; -// result ~ setProp subjectType ["prop", "prop2", ...] propType -// -// If the subject is a table or table-like thing that already has the named -// property chain, we unify propType with that existing property type. +// resultType ~ hasIndexer subjectType indexType // -// If the subject is a free table, we augment it in place. +// If the subject type is a table or table-like thing that supports indexing, +// populate the type result with the result type of such an index operation. // -// If the subject is an unsealed table, result is an augmented table that -// includes that new prop. -struct SetPropConstraint +// If the subject is not indexable, resultType is bound to errorType. +struct HasIndexerConstraint { TypeId resultType; TypeId subjectType; - std::vector path; - TypeId propType; + TypeId indexType; }; -// result ~ setIndexer subjectType indexType propType +// assignProp lhsType propName rhsType // -// If the subject is a table or table-like thing that already has an indexer, -// unify its indexType and propType with those from this constraint. -// -// If the table is a free or unsealed table, we augment it with a new indexer. -struct SetIndexerConstraint +// Assign a value of type rhsType into the named property of lhsType. + +struct AssignPropConstraint { - TypeId resultType; - TypeId subjectType; - TypeId indexType; + TypeId lhsType; + std::string propName; + TypeId rhsType; + + /// If a new property is to be inserted into a table type, it will be + /// ascribed this location. + std::optional propLocation; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. TypeId propType; + + // When we generate constraints, we increment the remaining prop count on + // the table if we are able. This flag informs the solver as to whether or + // not it should in turn decrement the prop count when this constraint is + // dispatched. + bool decrementPropCount = false; }; -// if negation: -// result ~ if isSingleton D then ~D else unknown where D = discriminantType -// if not negation: -// result ~ if isSingleton D then D else unknown where D = discriminantType -struct SingletonOrTopTypeConstraint +struct AssignIndexConstraint { - TypeId resultType; - TypeId discriminantType; - bool negated; + TypeId lhsType; + TypeId indexType; + TypeId rhsType; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. + TypeId propType; }; -// resultType ~ unpack sourceTypePack +// resultTypes ~ unpack sourceTypePack // // Similar to PackSubtypeConstraint, but with one important difference: If the // sourcePack is blocked, this constraint blocks. struct UnpackConstraint { - TypePackId resultPack; + std::vector resultPack; TypePackId sourcePack; }; -using ConstraintV = Variant; +// ty ~ reduce ty +// +// Try to reduce ty, if it is a TypeFunctionInstanceType. Otherwise, do nothing. +struct ReduceConstraint +{ + TypeId ty; +}; + +// tp ~ reduce tp +// +// Analogous to ReduceConstraint, but for type packs. +struct ReducePackConstraint +{ + TypePackId tp; +}; + +using ConstraintV = Variant< + SubtypeConstraint, + PackSubtypeConstraint, + GeneralizationConstraint, + IterableConstraint, + NameConstraint, + TypeAliasExpansionConstraint, + FunctionCallConstraint, + FunctionCheckConstraint, + PrimitiveTypeConstraint, + HasPropConstraint, + HasIndexerConstraint, + AssignPropConstraint, + AssignIndexConstraint, + UnpackConstraint, + ReduceConstraint, + ReducePackConstraint, + EqualityConstraint>; struct Constraint { @@ -209,10 +287,14 @@ struct Constraint ConstraintV c; std::vector> dependencies; + + DenseHashSet getMaybeMutatedFreeTypes() const; }; using ConstraintPtr = std::unique_ptr; +bool isReferenceCountedType(const TypeId typ); + inline Constraint& asMutable(const Constraint& c) { return const_cast(c); diff --git a/third_party/luau/Analysis/include/Luau/ConstraintGraphBuilder.h b/third_party/luau/Analysis/include/Luau/ConstraintGenerator.h similarity index 66% rename from third_party/luau/Analysis/include/Luau/ConstraintGraphBuilder.h rename to third_party/luau/Analysis/include/Luau/ConstraintGenerator.h index 5800d146..eb6e18eb 100644 --- a/third_party/luau/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/third_party/luau/Analysis/include/Luau/ConstraintGenerator.h @@ -5,14 +5,17 @@ #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" +#include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" #include "Luau/NotNull.h" #include "Luau/Refinement.h" #include "Luau/Symbol.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" +#include "Luau/Normalize.h" #include #include @@ -54,7 +57,7 @@ struct InferencePack } }; -struct ConstraintGraphBuilder +struct ConstraintGenerator { // A list of all the scopes in the module. This vector holds ownership of the // scope pointers; the scopes themselves borrow pointers to other scopes to @@ -65,9 +68,26 @@ struct ConstraintGraphBuilder NotNull builtinTypes; const NotNull arena; // The root scope of the module we're generating constraints for. - // This is null when the CGB is initially constructed. + // This is null when the CG is initially constructed. Scope* rootScope; + TypeContext typeContext = TypeContext::Default; + + struct InferredBinding + { + Scope* scope; + Location location; + TypeIds types; + }; + + // Some locals have multiple type states. We wish for Scope::bindings to + // map each local name onto the union of every type that the local can have + // over its lifetime, so we use this map to accumulate the set of types it + // might have. + // + // See the functions recordInferredBinding and fillInInferredBindings. + DenseHashMap inferredBindings{{}}; + // Constraints that go straight to the solver. std::vector constraints; @@ -86,6 +106,8 @@ struct ConstraintGraphBuilder // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. std::vector errors; + // Needed to be able to enable error-suppression preservation for immediate refinements. + NotNull normalizer; // Needed to resolve modules to make 'require' import types properly. NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. @@ -94,12 +116,34 @@ struct ConstraintGraphBuilder ScopePtr globalScope; std::function prepareModuleScope; + std::vector requireCycles; + + DenseHashMap localTypes{nullptr}; DcrLogger* logger; - ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, - NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, - DcrLogger* logger, NotNull dfg); + ConstraintGenerator( + ModulePtr module, + NotNull normalizer, + NotNull moduleResolver, + NotNull builtinTypes, + NotNull ice, + const ScopePtr& globalScope, + std::function prepareModuleScope, + DcrLogger* logger, + NotNull dfg, + std::vector requireCycles + ); + + /** + * The entry point to the ConstraintGenerator. This will construct a set + * of scopes, constraints, and free types that can be solved later. + * @param block the root block to generate constraints for. + */ + void visitModuleRoot(AstStatBlock* block); + +private: + std::vector> interiorTypes; /** * Fabricates a new free type belonging to a given scope. @@ -113,6 +157,18 @@ struct ConstraintGraphBuilder */ TypePackId freshTypePack(const ScopePtr& scope); + /** + * Allocate a new TypePack with the given head and tail. + * + * Avoids allocating 0-length type packs: + * + * If the head is non-empty, allocate and return a type pack with the given + * head and tail. + * If the head is empty and tail is non-empty, return *tail. + * If both the head and tail are empty, return an empty type pack. + */ + TypePackId addTypePack(std::vector head, std::optional tail); + /** * Fabricates a scope that is a child of another scope. * @param node the lexical node that the scope belongs to. @@ -120,6 +176,8 @@ struct ConstraintGraphBuilder */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); + std::optional lookup(const ScopePtr& scope, Location location, DefId def, bool prototype = true); + /** * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. @@ -136,14 +194,34 @@ struct ConstraintGraphBuilder */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); - void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); + struct RefinementPartition + { + // Types that we want to intersect against the type of the expression. + std::vector discriminantTypes; - /** - * The entry point to the ConstraintGraphBuilder. This will construct a set - * of scopes, constraints, and free types that can be solved later. - * @param block the root block to generate constraints for. - */ - void visit(AstStatBlock* block); + // Sometimes the type we're discriminating against is implicitly nil. + bool shouldAppendNilType = false; + }; + + using RefinementContext = InsertionOrderedMap; + void unionRefinements( + const ScopePtr& scope, + Location location, + const RefinementContext& lhs, + const RefinementContext& rhs, + RefinementContext& dest, + std::vector* constraints + ); + void computeRefinement( + const ScopePtr& scope, + Location location, + RefinementId refinement, + RefinementContext* refis, + bool sense, + bool eq, + std::vector* constraints + ); + void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); @@ -161,13 +239,19 @@ struct ConstraintGraphBuilder ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); + ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function); ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); ControlFlow visit(const ScopePtr& scope, AstStatError* error); InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); - InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}); + InferencePack checkPack( + const ScopePtr& scope, + AstExpr* expr, + const std::vector>& expectedTypes = {}, + bool generalize = true + ); InferencePack checkPack(const ScopePtr& scope, AstExprCall* call); @@ -177,17 +261,25 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @param expectedType the type of the expression that is expected from its * surrounding context. Used to implement bidirectional type checking. + * @param generalize If true, generalize any lambdas that are encountered. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, ValueContext context = ValueContext::RValue, std::optional expectedType = {}, - bool forceSingleton = false); + Inference check( + const ScopePtr& scope, + AstExpr* expr, + std::optional expectedType = {}, + bool forceSingleton = false, + bool generalize = true + ); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); - Inference check(const ScopePtr& scope, AstExprLocal* local, ValueContext context); + Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); @@ -196,9 +288,11 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - std::vector checkLValues(const ScopePtr& scope, AstArray exprs); - - TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); + void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType); struct FunctionSignature { @@ -213,7 +307,12 @@ struct ConstraintGraphBuilder ScopePtr bodyScope; }; - FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}); + FunctionSignature checkFunctionSignature( + const ScopePtr& parent, + AstExprFunction* fn, + std::optional expectedType = {}, + std::optional originalName = {} + ); /** * Checks the body of a function expression. @@ -260,7 +359,11 @@ struct ConstraintGraphBuilder * privateTypeBindings map. **/ std::vector> createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true); + const ScopePtr& scope, + AstArray generics, + bool useCache = false, + bool addTypes = true + ); /** * Creates generic type packs given a list of AST definitions, resolving @@ -273,26 +376,51 @@ struct ConstraintGraphBuilder * privateTypePackBindings map. **/ std::vector> createGenericPacks( - const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true); + const ScopePtr& scope, + AstArray packs, + bool useCache = false, + bool addTypes = true + ); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); + // make a union type function of these two types + TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + // make an intersect type function of these two types + TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + /** Scan the program for global definitions. * - * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an * initial scan of the AST and note what globals are defined. */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + bool recordPropertyAssignment(TypeId ty); + + // Record the fact that a particular local has a particular type in at least + // one of its states. + void recordInferredBinding(AstLocal* local, TypeId ty); + + void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block); + /** Given a function type annotation, return a vector describing the expected types of the calls to the function * For example, calling a function with annotation ((number) -> string & ((string) -> number)) * yields a vector of size 1, with value: [number | string] */ std::vector> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); + + TypeId createTypeFunctionInstance( + const TypeFunction& function, + std::vector typeArguments, + std::vector packArguments, + const ScopePtr& scope, + Location location + ); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/third_party/luau/Analysis/include/Luau/ConstraintSolver.h b/third_party/luau/Analysis/include/Luau/ConstraintSolver.h index 6888e99c..c6b4a828 100644 --- a/third_party/luau/Analysis/include/Luau/ConstraintSolver.h +++ b/third_party/luau/Analysis/include/Luau/ConstraintSolver.h @@ -3,21 +3,30 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/Location.h" #include "Luau/Module.h" #include "Luau/Normalize.h" +#include "Luau/Substitution.h" #include "Luau/ToString.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFwd.h" #include "Luau/Variant.h" +#include #include namespace Luau { +enum class ValueContext; + struct DcrLogger; +class AstExpr; + // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. using BlockedConstraintId = Variant; @@ -71,9 +80,23 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>, HashBlockedConstraintId> blocked; + std::unordered_map, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; + // Breadcrumbs for where a free type's upper bound was expanded. We use + // these to provide more helpful error messages when a free type is solved + // as never unexpectedly. + DenseHashMap>> upperBoundContributors{nullptr}; + + // A mapping from free types to the number of unresolved constraints that mention them. + DenseHashMap unresolvedConstraints{{}}; + + // Irreducible/uninhabited type functions or type pack functions. + DenseHashSet uninhabitedTypeFunctions{{}}; + + // The set of types that will definitely be unchanged by generalization. + DenseHashSet generalizedTypes_{nullptr}; + const NotNull> generalizedTypes{&generalizedTypes_}; // Recorded errors that take place within the solver. ErrorVec errors; @@ -82,9 +105,20 @@ struct ConstraintSolver std::vector requireCycles; DcrLogger* logger; + TypeCheckLimits limits; + + DenseHashMap typeFunctionsToFinalize{nullptr}; - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + explicit ConstraintSolver( + NotNull normalizer, + NotNull rootScope, + std::vector> constraints, + ModuleName moduleName, + NotNull moduleResolver, + std::vector requireCycles, + DcrLogger* logger, + TypeCheckLimits limits + ); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -95,10 +129,35 @@ struct ConstraintSolver **/ void run(); + + /** + * Attempts to perform one final reduction on type functions after every constraint has been completed + * + **/ + void finalizeTypeFunctions(); + bool isDone(); - void finalizeModule(); +private: + /** + * Bind a type variable to another type. + * + * A constraint is required and will validate that blockedTy is owned by this + * constraint. This prevents one constraint from interfering with another's + * blocked types. + * + * Bind will also unblock the type variable for you. + */ + void bind(NotNull constraint, TypeId ty, TypeId boundTo); + void bind(NotNull constraint, TypePackId tp, TypePackId boundTo); + + template + void emplace(NotNull constraint, TypeId ty, Args&&... args); + template + void emplace(NotNull constraint, TypePackId tp, Args&&... args); + +public: /** Attempt to dispatch a constraint. Returns true if it was successful. If * tryDispatch() returns false, the constraint remains in the unsolved set * and will be retried later. @@ -108,30 +167,71 @@ struct ConstraintSolver bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + + + bool tryDispatchHasIndexer( + int& recursionDepth, + NotNull constraint, + TypeId subjectType, + TypeId indexType, + TypeId resultType, + Set& seen + ); + bool tryDispatch(const HasIndexerConstraint& c, NotNull constraint); + + bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); + bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); + bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force); // for a, ... in some_table do // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); + + std::pair, std::optional> lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional = false, + bool suppressSimplification = false + ); + std::pair, std::optional> lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification, + DenseHashSet& seen + ); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + /** + * Generate constraints to unpack the types of srcTypes and assign each + * value to the corresponding BlockedType in destTypes. + * + * This function also overwrites the owners of each BlockedType. This is + * okay because this function is only used to decompose IterableConstraint + * into an UnpackConstraint. + * + * @param destTypes A vector of types comprised of BlockedTypes. + * @param srcTypes A TypePack that represents rvalues to be assigned. + * @returns The underlying UnpackConstraint. There's a bit of code in + * iteration that needs to pass blocks on to this constraint. + */ + NotNull unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint); void block(NotNull target, NotNull constraint); /** @@ -141,6 +241,16 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + // Block on every target + template + bool block(const T& targets, NotNull constraint) + { + for (TypeId target : targets) + block(target, constraint); + + return false; + } + /** * For all constraints that are blocked on one constraint, make them block * on a new constraint. @@ -149,21 +259,21 @@ struct ConstraintSolver */ void inheritBlocks(NotNull source, NotNull addition); - // Traverse the type. If any blocked or pending types are found, block - // the constraint on them. + // Traverse the type. If any pending types are found, block the constraint + // on them. // // Returns false if a type blocks the constraint. // // FIXME: This use of a boolean for the return result is an appalling // interface. - bool recursiveBlock(TypeId target, NotNull constraint); - bool recursiveBlock(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypeId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId target, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - void unblock(const std::vector& types); - void unblock(const std::vector& packs); + void unblock(TypeId progressed, Location location); + void unblock(TypePackId progressed, Location location); + void unblock(const std::vector& types, Location location); + void unblock(const std::vector& packs, Location location); /** * @returns true if the TypeId is in a blocked state. @@ -181,22 +291,6 @@ struct ConstraintSolver */ bool isBlocked(NotNull constraint); - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result. - * @param subType the sub-type to unify. - * @param superType the super-type to unify. - */ - void unify(TypeId subType, TypeId superType, NotNull scope); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result. - * @param subPack the sub-type pack to unify. - * @param superPack the super-type pack to unify. - */ - void unify(TypePackId subPack, TypePackId superPack, NotNull scope); - /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. **/ @@ -216,20 +310,47 @@ struct ConstraintSolver void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); -private: + /** + * Shifts the count of references from `source` to `target`. This should be paired + * with any instance of binding a free type in order to maintain accurate refcounts. + * If `target` is not a free type, this is a noop. + * @param source the free type which is being bound + * @param target the type which the free type is being bound to + */ + void shiftReferences(TypeId source, TypeId target); - /** Helper used by tryDispatch(SubtypeConstraint) and - * tryDispatch(PackSubtypeConstraint) - * - * Attempts to unify subTy with superTy. If doing so would require unifying + /** + * Generalizes the given free type if the reference counting allows it. + * @param the scope to generalize in + * @param type the free type we want to generalize + * @returns a non-free type that generalizes the argument, or `std::nullopt` if one + * does not exist + */ + std::optional generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables = false); + + /** + * Checks the existing set of constraints to see if there exist any that contain + * the provided free type, indicating that it is not yet ready to be replaced by + * one of its bounds. + * @param ty the free type that to check for related constraints + * @returns whether or not it is unsafe to replace the free type by one of its bounds + */ + bool hasUnresolvedConstraints(TypeId ty); + + /** Attempts to unify subTy with superTy. If doing so would require unifying * BlockedTypes, fail and block the constraint on those BlockedTypes. * + * Note: TID can only be TypeId or TypePackId. + * * If unification fails, replace all free types with errorType. * * If unification succeeds, unblock every type changed by the unification. + * + * @returns true if the unification succeeded. False if the unification was + * too complex. */ - template - bool tryUnify(NotNull constraint, TID subTy, TID superTy); + template + bool unify(NotNull constraint, TID subTy, TID superTy); /** * Marks a constraint as being blocked on a type or type pack. The constraint @@ -238,7 +359,7 @@ struct ConstraintSolver * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, NotNull constraint); + bool block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The @@ -248,10 +369,20 @@ struct ConstraintSolver **/ void unblock_(BlockedConstraintId progressed); + /** + * Reproduces any constraints necessary for new types that are copied when applying a substitution. + * At the time of writing, this pertains only to type functions. + * @param subst the substitution that was applied + **/ + void reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst); + TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; - TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); + + void throwTimeLimitError(); + void throwUserCancelError(); ToStringOptions opts; }; diff --git a/third_party/luau/Analysis/include/Luau/ControlFlow.h b/third_party/luau/Analysis/include/Luau/ControlFlow.h index 566d77bd..82c0403c 100644 --- a/third_party/luau/Analysis/include/Luau/ControlFlow.h +++ b/third_party/luau/Analysis/include/Luau/ControlFlow.h @@ -14,8 +14,8 @@ enum class ControlFlow None = 0b00001, Returns = 0b00010, Throws = 0b00100, - Break = 0b01000, // Currently unused. - Continue = 0b10000, // Currently unused. + Breaks = 0b01000, + Continues = 0b10000, }; inline ControlFlow operator&(ControlFlow a, ControlFlow b) diff --git a/third_party/luau/Analysis/include/Luau/DataFlowGraph.h b/third_party/luau/Analysis/include/Luau/DataFlowGraph.h index ce4ecb04..2a894bc9 100644 --- a/third_party/luau/Analysis/include/Luau/DataFlowGraph.h +++ b/third_party/luau/Analysis/include/Luau/DataFlowGraph.h @@ -3,29 +3,47 @@ // Do not include LValue. It should never be used here. #include "Luau/Ast.h" -#include "Luau/Breadcrumb.h" +#include "Luau/ControlFlow.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" #include "Luau/Symbol.h" +#include "Luau/TypedAllocator.h" #include namespace Luau { +struct RefinementKey +{ + const RefinementKey* parent = nullptr; + DefId def; + std::optional propName; +}; + +struct RefinementKeyArena +{ + TypedAllocator allocator; + + const RefinementKey* leaf(DefId def); + const RefinementKey* node(const RefinementKey* parent, DefId def, const std::string& propName); +}; + struct DataFlowGraph { DataFlowGraph(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default; - NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const; + DefId getDef(const AstExpr* expr) const; + // Look up for the rvalue def for a compound assignment. + std::optional getRValueDefForCompoundAssign(const AstExpr* expr) const; - BreadcrumbId getBreadcrumb(const AstLocal* local) const; - BreadcrumbId getBreadcrumb(const AstExprLocal* local) const; - BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const; + DefId getDef(const AstLocal* local) const; - BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const; - BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const; + DefId getDef(const AstStatDeclareGlobal* global) const; + DefId getDef(const AstStatDeclareFunction* func) const; + + const RefinementKey* getRefinementKey(const AstExpr* expr) const; private: DataFlowGraph() = default; @@ -33,33 +51,60 @@ struct DataFlowGraph DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena defs; - BreadcrumbArena breadcrumbs; + DefArena defArena; + RefinementKeyArena keyArena; - DenseHashMap astBreadcrumbs{nullptr}; + DenseHashMap astDefs{nullptr}; // Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId. - DenseHashMap localBreadcrumbs{nullptr}; + DenseHashMap localDefs{nullptr}; // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. // All keys in this maps are really only statements that ambiently declares a symbol. - DenseHashMap declaredBreadcrumbs{nullptr}; + DenseHashMap declaredDefs{nullptr}; + + // Compound assignments are in a weird situation where the local being assigned to is also being used at its + // previous type implicitly in an rvalue position. This map provides the previous binding. + DenseHashMap compoundAssignDefs{nullptr}; + + DenseHashMap astRefinementKeys{nullptr}; friend struct DataFlowGraphBuilder; }; struct DfgScope { + enum ScopeType + { + Linear, + Loop, + Function, + }; + DfgScope* parent; - DenseHashMap bindings{Symbol{}}; - DenseHashMap> props{nullptr}; + ScopeType scopeType; + + using Bindings = DenseHashMap; + using Props = DenseHashMap>; + + Bindings bindings{Symbol{}}; + Props props{nullptr}; + + std::optional lookup(Symbol symbol) const; + std::optional lookup(DefId def, const std::string& key) const; - NullableBreadcrumbId lookup(Symbol symbol) const; - NullableBreadcrumbId lookup(DefId def, const std::string& key) const; + void inherit(const DfgScope* childScope); + + bool canUpdateDefinition(Symbol symbol) const; + bool canUpdateDefinition(DefId def, const std::string& key) const; +}; + +struct DataFlowResult +{ + DefId def; + const RefinementKey* parent = nullptr; }; -// Currently unsound. We do not presently track the control flow of the program. -// Additionally, we do not presently track assignments. struct DataFlowGraphBuilder { static DataFlowGraph build(AstStatBlock* root, NotNull handle); @@ -71,61 +116,80 @@ struct DataFlowGraphBuilder DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull defs{&graph.defs}; - NotNull breadcrumbs{&graph.breadcrumbs}; + NotNull defArena{&graph.defArena}; + NotNull keyArena{&graph.keyArena}; struct InternalErrorReporter* handle = nullptr; DfgScope* moduleScope = nullptr; std::vector> scopes; - DfgScope* childScope(DfgScope* scope); - - void visit(DfgScope* scope, AstStatBlock* b); - void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); - - void visit(DfgScope* scope, AstStat* s); - void visit(DfgScope* scope, AstStatIf* i); - void visit(DfgScope* scope, AstStatWhile* w); - void visit(DfgScope* scope, AstStatRepeat* r); - void visit(DfgScope* scope, AstStatBreak* b); - void visit(DfgScope* scope, AstStatContinue* c); - void visit(DfgScope* scope, AstStatReturn* r); - void visit(DfgScope* scope, AstStatExpr* e); - void visit(DfgScope* scope, AstStatLocal* l); - void visit(DfgScope* scope, AstStatFor* f); - void visit(DfgScope* scope, AstStatForIn* f); - void visit(DfgScope* scope, AstStatAssign* a); - void visit(DfgScope* scope, AstStatCompoundAssign* c); - void visit(DfgScope* scope, AstStatFunction* f); - void visit(DfgScope* scope, AstStatLocalFunction* l); - void visit(DfgScope* scope, AstStatTypeAlias* t); - void visit(DfgScope* scope, AstStatDeclareGlobal* d); - void visit(DfgScope* scope, AstStatDeclareFunction* d); - void visit(DfgScope* scope, AstStatDeclareClass* d); - void visit(DfgScope* scope, AstStatError* error); - - BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e); - BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l); - BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g); - BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f); - BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t); - BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u); - BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b); - BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i); - BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error); - - void visitLValue(DfgScope* scope, AstExpr* e); - void visitLValue(DfgScope* scope, AstExprLocal* l); - void visitLValue(DfgScope* scope, AstExprGlobal* g); - void visitLValue(DfgScope* scope, AstExprIndexName* i); - void visitLValue(DfgScope* scope, AstExprIndexExpr* i); - void visitLValue(DfgScope* scope, AstExprError* e); + struct FunctionCapture + { + std::vector captureDefs; + std::vector allVersions; + size_t versionOffset = 0; + }; + + DenseHashMap captures{Symbol{}}; + void resolveCaptures(); + + DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); + + void join(DfgScope* p, DfgScope* a, DfgScope* b); + void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); + void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); + + DefId lookup(DfgScope* scope, Symbol symbol); + DefId lookup(DfgScope* scope, DefId def, const std::string& key); + + ControlFlow visit(DfgScope* scope, AstStatBlock* b); + ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + + ControlFlow visit(DfgScope* scope, AstStat* s); + ControlFlow visit(DfgScope* scope, AstStatIf* i); + ControlFlow visit(DfgScope* scope, AstStatWhile* w); + ControlFlow visit(DfgScope* scope, AstStatRepeat* r); + ControlFlow visit(DfgScope* scope, AstStatBreak* b); + ControlFlow visit(DfgScope* scope, AstStatContinue* c); + ControlFlow visit(DfgScope* scope, AstStatReturn* r); + ControlFlow visit(DfgScope* scope, AstStatExpr* e); + ControlFlow visit(DfgScope* scope, AstStatLocal* l); + ControlFlow visit(DfgScope* scope, AstStatFor* f); + ControlFlow visit(DfgScope* scope, AstStatForIn* f); + ControlFlow visit(DfgScope* scope, AstStatAssign* a); + ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c); + ControlFlow visit(DfgScope* scope, AstStatFunction* f); + ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); + ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); + ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f); + ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); + ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); + ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); + ControlFlow visit(DfgScope* scope, AstStatError* error); + + DataFlowResult visitExpr(DfgScope* scope, AstExpr* e); + DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group); + DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l); + DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g); + DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c); + DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i); + DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i); + DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f); + DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t); + DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u); + DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b); + DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i); + DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i); + DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); + + void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment = false); + DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment); + DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment); + DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); + DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); + DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); void visitType(DfgScope* scope, AstType* t); void visitType(DfgScope* scope, AstTypeReference* r); diff --git a/third_party/luau/Analysis/include/Luau/DcrLogger.h b/third_party/luau/Analysis/include/Luau/DcrLogger.h index 1e170d5b..d650d9e0 100644 --- a/third_party/luau/Analysis/include/Luau/DcrLogger.h +++ b/third_party/luau/Analysis/include/Luau/DcrLogger.h @@ -126,7 +126,11 @@ struct DcrLogger void captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); StepSnapshot prepareStepSnapshot( - const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); + const Scope* rootScope, + NotNull current, + bool force, + const std::vector>& unsolvedConstraints + ); void commitStepSnapshot(StepSnapshot snapshot); void captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); diff --git a/third_party/luau/Analysis/include/Luau/Def.h b/third_party/luau/Analysis/include/Luau/Def.h index 10d81367..9627f998 100644 --- a/third_party/luau/Analysis/include/Luau/Def.h +++ b/third_party/luau/Analysis/include/Luau/Def.h @@ -23,6 +23,7 @@ using DefId = NotNull; */ struct Cell { + bool subscripted = false; }; /** @@ -71,13 +72,16 @@ const T* get(DefId def) return get_if(&def->v); } +bool containsSubscriptedDefinition(DefId def); +void collectOperands(DefId def, std::vector* operands); + struct DefArena { TypedAllocator allocator; - DefId freshCell(); - // TODO: implement once we have cases where we need to merge in definitions - // DefId phi(const std::vector& defs); + DefId freshCell(bool subscripted = false); + DefId phi(DefId a, DefId b); + DefId phi(const std::vector& defs); }; } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Differ.h b/third_party/luau/Analysis/include/Luau/Differ.h new file mode 100644 index 00000000..d9b78939 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Differ.h @@ -0,0 +1,208 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeFwd.h" +#include "Luau/UnifierSharedState.h" + +#include +#include +#include + +namespace Luau +{ +struct DiffPathNode +{ + // TODO: consider using Variants to simplify toString implementation + enum Kind + { + TableProperty, + FunctionArgument, + FunctionReturn, + Union, + Intersection, + Negation, + }; + Kind kind; + // non-null when TableProperty + std::optional tableProperty; + // non-null when FunctionArgument (unless variadic arg), FunctionReturn (unless variadic arg), Union, or Intersection (i.e. anonymous fields) + std::optional index; + + /** + * Do not use for leaf nodes + */ + DiffPathNode(Kind kind) + : kind(kind) + { + } + + DiffPathNode(Kind kind, std::optional tableProperty, std::optional index) + : kind(kind) + , tableProperty(tableProperty) + , index(index) + { + } + + std::string toString() const; + + static DiffPathNode constructWithTableProperty(Name tableProperty); + + static DiffPathNode constructWithKindAndIndex(Kind kind, size_t index); + + static DiffPathNode constructWithKind(Kind kind); +}; + +struct DiffPathNodeLeaf +{ + std::optional ty; + std::optional tableProperty; + std::optional minLength; + bool isVariadic; + // TODO: Rename to anonymousIndex, for both union and Intersection + std::optional unionIndex; + DiffPathNodeLeaf( + std::optional ty, + std::optional tableProperty, + std::optional minLength, + bool isVariadic, + std::optional unionIndex + ) + : ty(ty) + , tableProperty(tableProperty) + , minLength(minLength) + , isVariadic(isVariadic) + , unionIndex(unionIndex) + { + } + + static DiffPathNodeLeaf detailsNormal(TypeId ty); + + static DiffPathNodeLeaf detailsTableProperty(TypeId ty, Name tableProperty); + + static DiffPathNodeLeaf detailsUnionIndex(TypeId ty, size_t index); + + static DiffPathNodeLeaf detailsLength(int minLength, bool isVariadic); + + static DiffPathNodeLeaf nullopts(); +}; + +struct DiffPath +{ + std::vector path; + + std::string toString(bool prependDot) const; +}; +struct DiffError +{ + enum Kind + { + Normal, + MissingTableProperty, + MissingUnionMember, + MissingIntersectionMember, + IncompatibleGeneric, + LengthMismatchInFnArgs, + LengthMismatchInFnRets, + }; + Kind kind; + + DiffPath diffPath; + DiffPathNodeLeaf left; + DiffPathNodeLeaf right; + + std::string leftRootName; + std::string rightRootName; + + DiffError(Kind kind, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + DiffError(Kind kind, DiffPath diffPath, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName) + : kind(kind) + , diffPath(diffPath) + , left(left) + , right(right) + , leftRootName(leftRootName) + , rightRootName(rightRootName) + { + checkValidInitialization(left, right); + } + + std::string toString(bool multiLine = false) const; + +private: + std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const; + void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right); + void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const; +}; + +struct DifferResult +{ + std::optional diffError; + + DifferResult() {} + DifferResult(DiffError diffError) + : diffError(diffError) + { + } + + void wrapDiffPath(DiffPathNode node); +}; +struct DifferEnvironment +{ + TypeId rootLeft; + TypeId rootRight; + std::optional externalSymbolLeft; + std::optional externalSymbolRight; + DenseHashMap genericMatchedPairs; + DenseHashMap genericTpMatchedPairs; + + DifferEnvironment( + TypeId rootLeft, + TypeId rootRight, + std::optional externalSymbolLeft, + std::optional externalSymbolRight + ) + : rootLeft(rootLeft) + , rootRight(rootRight) + , externalSymbolLeft(externalSymbolLeft) + , externalSymbolRight(externalSymbolRight) + , genericMatchedPairs(nullptr) + , genericTpMatchedPairs(nullptr) + { + } + + bool isProvenEqual(TypeId left, TypeId right) const; + bool isAssumedEqual(TypeId left, TypeId right) const; + void recordProvenEqual(TypeId left, TypeId right); + void pushVisiting(TypeId left, TypeId right); + void popVisiting(); + std::vector>::const_reverse_iterator visitingBegin() const; + std::vector>::const_reverse_iterator visitingEnd() const; + std::string getDevFixFriendlyNameLeft() const; + std::string getDevFixFriendlyNameRight() const; + +private: + // TODO: consider using DenseHashSet + std::unordered_set, TypeIdPairHash> provenEqual; + // Ancestors of current types + std::unordered_set, TypeIdPairHash> visiting; + std::vector> visitingStack; +}; +DifferResult diff(TypeId ty1, TypeId ty2); +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2); + +/** + * True if ty is a "simple" type, i.e. cannot contain types. + * string, number, boolean are simple types. + * function and table are not simple types. + */ +bool isSimple(TypeId ty); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Error.h b/third_party/luau/Analysis/include/Luau/Error.h index 8571430b..baf3318c 100644 --- a/third_party/luau/Analysis/include/Luau/Error.h +++ b/third_party/luau/Analysis/include/Luau/Error.h @@ -2,8 +2,12 @@ #pragma once #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Type.h" #include "Luau/Variant.h" +#include "Luau/Ast.h" + +#include namespace Luau { @@ -190,6 +194,11 @@ struct InternalError bool operator==(const InternalError& rhs) const; }; +struct ConstraintSolvingIncompleteError +{ + bool operator==(const ConstraintSolvingIncompleteError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -318,6 +327,7 @@ struct TypePackMismatch { TypePackId wantedTp; TypePackId givenTp; + std::string reason; bool operator==(const TypePackMismatch& rhs) const; }; @@ -329,12 +339,164 @@ struct DynamicPropertyLookupOnClassesUnsafe bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const; }; -using TypeErrorData = Variant; +struct UninhabitedTypeFunction +{ + TypeId ty; + + bool operator==(const UninhabitedTypeFunction& rhs) const; +}; + +struct ExplicitFunctionAnnotationRecommended +{ + std::vector> recommendedArgs; + TypeId recommendedReturn; + bool operator==(const ExplicitFunctionAnnotationRecommended& rhs) const; +}; + +struct UninhabitedTypePackFunction +{ + TypePackId tp; + + bool operator==(const UninhabitedTypePackFunction& rhs) const; +}; + +struct WhereClauseNeeded +{ + TypeId ty; + + bool operator==(const WhereClauseNeeded& rhs) const; +}; + +struct PackWhereClauseNeeded +{ + TypePackId tp; + + bool operator==(const PackWhereClauseNeeded& rhs) const; +}; + +struct CheckedFunctionCallError +{ + TypeId expected; + TypeId passed; + std::string checkedFunctionName; + // TODO: make this a vector + size_t argumentIndex; + bool operator==(const CheckedFunctionCallError& rhs) const; +}; + +struct NonStrictFunctionDefinitionError +{ + std::string functionName; + std::string argument; + TypeId argumentType; + bool operator==(const NonStrictFunctionDefinitionError& rhs) const; +}; + +struct PropertyAccessViolation +{ + TypeId table; + Name key; + + enum + { + CannotRead, + CannotWrite + } context; + + bool operator==(const PropertyAccessViolation& rhs) const; +}; + +struct CheckedFunctionIncorrectArgs +{ + std::string functionName; + size_t expected; + size_t actual; + bool operator==(const CheckedFunctionIncorrectArgs& rhs) const; +}; + +struct CannotAssignToNever +{ + // type of the rvalue being assigned + TypeId rhsType; + + // Originating type. + std::vector cause; + + enum class Reason + { + // when assigning to a property in a union of tables, the properties type + // is narrowed to the intersection of its type in each variant. + PropertyNarrowed, + }; + + Reason reason; + + bool operator==(const CannotAssignToNever& rhs) const; +}; + +struct UnexpectedTypeInSubtyping +{ + TypeId ty; + + bool operator==(const UnexpectedTypeInSubtyping& rhs) const; +}; + +struct UnexpectedTypePackInSubtyping +{ + TypePackId tp; + + bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; +}; + +using TypeErrorData = Variant< + TypeMismatch, + UnknownSymbol, + UnknownProperty, + NotATable, + CannotExtendTable, + OnlyTablesCanHaveMethods, + DuplicateTypeDefinition, + CountMismatch, + FunctionDoesNotTakeSelf, + FunctionRequiresSelf, + OccursCheckFailed, + UnknownRequire, + IncorrectGenericParameterCount, + SyntaxError, + CodeTooComplex, + UnificationTooComplex, + UnknownPropButFoundLikeProp, + GenericError, + InternalError, + ConstraintSolvingIncompleteError, + CannotCallNonFunction, + ExtraInformation, + DeprecatedApiUsed, + ModuleHasCyclicDependency, + IllegalRequire, + FunctionExitsWithoutReturning, + DuplicateGenericParameter, + CannotAssignToNever, + CannotInferBinaryOperation, + MissingProperties, + SwappedGenericTypeParameter, + OptionalValueAccess, + MissingUnionProperty, + TypesAreUnrelated, + NormalizationTooComplex, + TypePackMismatch, + DynamicPropertyLookupOnClassesUnsafe, + UninhabitedTypeFunction, + UninhabitedTypePackFunction, + WhereClauseNeeded, + PackWhereClauseNeeded, + CheckedFunctionCallError, + NonStrictFunctionDefinitionError, + PropertyAccessViolation, + CheckedFunctionIncorrectArgs, + UnexpectedTypeInSubtyping, + UnexpectedTypePackInSubtyping, + ExplicitFunctionAnnotationRecommended>; struct TypeErrorSummary { @@ -403,7 +565,7 @@ std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); // Copy any types named in the error into destArena. -void copyErrors(ErrorVec& errors, struct TypeArena& destArena); +void copyErrors(ErrorVec& errors, struct TypeArena& destArena, NotNull builtinTypes); // Internal Compiler Error struct InternalErrorReporter diff --git a/third_party/luau/Analysis/include/Luau/Frontend.h b/third_party/luau/Analysis/include/Luau/Frontend.h index 67e840ee..f476b582 100644 --- a/third_party/luau/Analysis/include/Luau/Frontend.h +++ b/third_party/luau/Analysis/include/Luau/Frontend.h @@ -2,12 +2,14 @@ #pragma once #include "Luau/Config.h" +#include "Luau/GlobalTypes.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" -#include "Luau/TypeInfer.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" +#include "Luau/AnyTypeSummary.h" #include #include @@ -28,6 +30,9 @@ struct FileResolver; struct ModuleResolver; struct ParseResult; struct HotComment; +struct BuildQueueItem; +struct FrontendCancellationToken; +struct AnyTypeSummary; struct LoadDefinitionFileResult { @@ -68,7 +73,7 @@ struct SourceNode ModuleName name; std::string humanReadableName; - std::unordered_set requireSet; + DenseHashSet requireSet{{}}; std::vector> requireLocations; bool dirtySourceModule = true; bool dirtyModule = true; @@ -95,6 +100,14 @@ struct FrontendOptions std::optional randomizeConstraintResolutionSeed; std::optional enabledLintWarnings; + + std::shared_ptr cancellationToken; + + // Time limit for typechecking a single module + std::optional moduleTimeLimitSec; + + // When true, some internal complexity limits will be scaled down for modules that miss the limit set by moduleTimeLimitSec + bool applyInternalLimitScaling = false; }; struct CheckResult @@ -143,6 +156,10 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); + // Parse module graph and prepare SourceNode/SourceModule data, including required dependencies without running typechecking + void parse(const ModuleName& name); + + // Parse and typecheck module graph CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; @@ -168,24 +185,58 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, - bool captureComments, bool typeCheckForAutocomplete = false); + LoadDefinitionFileResult loadDefinitionFile( + GlobalTypes& globals, + ScopePtr targetScope, + std::string_view source, + const std::string& packageName, + bool captureComments, + bool typeCheckForAutocomplete = false + ); + + // Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult' + // If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete + void queueModuleCheck(const std::vector& names); + void queueModuleCheck(const ModuleName& name); + std::vector checkQueuedModules( + std::optional optionOverride = {}, + std::function task)> executeTask = {}, + std::function progress = {} + ); + + std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); private: - struct TypeCheckLimits - { - std::optional finishTime; - std::optional instantiationChildLimit; - std::optional unifierIterationLimit; - }; - - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, std::optional environmentScope, - bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits); + ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + std::vector requireCycles, + std::optional environmentScope, + bool forAutocomplete, + bool recordJsonLog, + TypeCheckLimits typeCheckLimits + ); std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete); + bool parseGraph( + std::vector& buildQueue, + const ModuleName& root, + bool forAutocomplete, + std::function canSkip = {} + ); + + void addBuildQueueItems( + std::vector& items, + std::vector& buildQueue, + bool cycleDetected, + DenseHashSet& seen, + const FrontendOptions& frontendOptions + ); + void checkBuildQueueItem(BuildQueueItem& item); + void checkBuildQueueItems(std::vector& items); + void recordItemResult(const BuildQueueItem& item); static LintResult classifyLints(const std::vector& warnings, const Config& config); @@ -211,21 +262,45 @@ struct Frontend FrontendOptions options; InternalErrorReporter iceHandler; std::function prepareModuleScope; + std::function writeJsonLog = {}; - std::unordered_map sourceNodes; - std::unordered_map sourceModules; + std::unordered_map> sourceNodes; + std::unordered_map> sourceModules; std::unordered_map requireTrace; Stats stats = {}; -}; -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options); + std::vector moduleQueue; +}; -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options, - bool recordJsonLog); +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits +); + +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& globalScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + bool recordJsonLog, + std::function writeJsonLog +); } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Generalization.h b/third_party/luau/Analysis/include/Luau/Generalization.h new file mode 100644 index 00000000..18d5b678 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Generalization.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Scope.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +std::optional generalize( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull> bakedTypes, + TypeId ty, + /* avoid sealing tables*/ bool avoidSealingTables = false +); +} diff --git a/third_party/luau/Analysis/include/Luau/GlobalTypes.h b/third_party/luau/Analysis/include/Luau/GlobalTypes.h new file mode 100644 index 00000000..55a6d6c7 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/GlobalTypes.h @@ -0,0 +1,25 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +struct GlobalTypes +{ + explicit GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/InsertionOrderedMap.h b/third_party/luau/Analysis/include/Luau/InsertionOrderedMap.h new file mode 100644 index 00000000..2937dcda --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/InsertionOrderedMap.h @@ -0,0 +1,134 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include +#include +#include + +namespace Luau +{ + +template +struct InsertionOrderedMap +{ + static_assert(std::is_trivially_copyable_v, "key must be trivially copyable"); + +private: + using vec = std::vector>; + +public: + using iterator = typename vec::iterator; + using const_iterator = typename vec::const_iterator; + + void insert(K k, V v) + { + if (indices.count(k) != 0) + return; + + pairs.push_back(std::make_pair(k, std::move(v))); + indices[k] = pairs.size() - 1; + } + + void clear() + { + pairs.clear(); + indices.clear(); + } + + size_t size() const + { + LUAU_ASSERT(pairs.size() == indices.size()); + return pairs.size(); + } + + bool contains(const K& k) const + { + return indices.count(k) > 0; + } + + const V* get(const K& k) const + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + V* get(const K& k) + { + auto it = indices.find(k); + if (it == indices.end()) + return nullptr; + else + return &pairs.at(it->second).second; + } + + const_iterator begin() const + { + return pairs.begin(); + } + + const_iterator end() const + { + return pairs.end(); + } + + iterator begin() + { + return pairs.begin(); + } + + iterator end() + { + return pairs.end(); + } + + const_iterator find(K k) const + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + iterator find(K k) + { + auto indicesIt = indices.find(k); + if (indicesIt == indices.end()) + return end(); + else + return begin() + indicesIt->second; + } + + void erase(iterator it) + { + if (it == pairs.end()) + return; + + K k = it->first; + auto indexIt = indices.find(k); + if (indexIt == indices.end()) + return; + + size_t removed = indexIt->second; + indices.erase(indexIt); + pairs.erase(it); + + for (auto& [_, index] : indices) + { + if (index > removed) + --index; + } + } + +private: + vec pairs; + std::unordered_map indices; +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Instantiation.h b/third_party/luau/Analysis/include/Luau/Instantiation.h index c916f953..0fd2817a 100644 --- a/third_party/luau/Analysis/include/Luau/Instantiation.h +++ b/third_party/luau/Analysis/include/Luau/Instantiation.h @@ -1,22 +1,33 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/NotNull.h" #include "Luau/Substitution.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/Unifiable.h" +#include "Luau/VisitType.h" namespace Luau { -struct TypeArena; struct TxnLog; +struct TypeArena; +struct TypeCheckLimits; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector& generics, - const std::vector& genericPacks) + ReplaceGenerics( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks + ) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) , generics(generics) @@ -24,10 +35,23 @@ struct ReplaceGenerics : Substitution { } + void resetState( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks + ); + + NotNull builtinTypes; + TypeLevel level; Scope* scope; std::vector generics; std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -38,15 +62,24 @@ struct ReplaceGenerics : Substitution // A substitution which replaces generic functions by monomorphic functions struct Instantiation : Substitution { - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope) + Instantiation(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) + , reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {}) { } + void resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope); + + NotNull builtinTypes; + TypeLevel level; Scope* scope; + + ReplaceGenerics reusableReplaceGenerics; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -54,4 +87,79 @@ struct Instantiation : Substitution TypePackId clean(TypePackId tp) override; }; +// Used to find if a FunctionType requires generic type cleanup during instantiation +struct GenericTypeFinder : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId ty) override + { + return !found; + } + + bool visit(TypePackId ty) override + { + return !found; + } + + bool visit(TypeId ty, const Luau::FunctionType& ftv) override + { + if (ftv.hasNoFreeOrGenericTypes) + return false; + + if (!ftv.generics.empty() || !ftv.genericPacks.empty()) + found = true; + + return !found; + } + + bool visit(TypeId ty, const Luau::TableType& ttv) override + { + if (ttv.state == Luau::TableState::Generic) + found = true; + + return !found; + } + + bool visit(TypeId ty, const Luau::GenericType&) override + { + found = true; + return false; + } + + bool visit(TypePackId ty, const Luau::GenericTypePack&) override + { + found = true; + return false; + } + + bool visit(TypeId ty, const Luau::ClassType&) override + { + // During function instantiation, classes are not traversed even if they have generics + return false; + } +}; + +/** Attempt to instantiate a type. Only used under local type inference. + * + * When given a generic function type, instantiate() will return a copy with the + * generics replaced by fresh types. Instantiation will return the same TypeId + * back if the function does not have any generics. + * + * All higher order generics are left as-is. For example, instantiation of + * ((Y) -> (X, Y)) -> (X, Y) is ((Y) -> ('x, Y)) -> ('x, Y) + * + * We substitute the generic X for the free 'x, but leave the generic Y alone. + * + * Instantiation fails only when processing the type causes internal recursion + * limits to be exceeded. + */ +std::optional instantiate( + NotNull builtinTypes, + NotNull arena, + NotNull limits, + NotNull scope, + TypeId ty +); + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Instantiation2.h b/third_party/luau/Analysis/include/Luau/Instantiation2.h new file mode 100644 index 00000000..c9215fad --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Instantiation2.h @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeFwd.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TypeCheckLimits; + +struct Replacer : Substitution +{ + DenseHashMap replacements; + DenseHashMap replacementPacks; + + Replacer(NotNull arena, DenseHashMap replacements, DenseHashMap replacementPacks) + : Substitution(TxnLog::empty(), arena) + , replacements(std::move(replacements)) + , replacementPacks(std::move(replacementPacks)) + { + } + + bool isDirty(TypeId ty) override + { + return replacements.find(ty) != nullptr; + } + + bool isDirty(TypePackId tp) override + { + return replacementPacks.find(tp) != nullptr; + } + + TypeId clean(TypeId ty) override + { + TypeId res = replacements[ty]; + LUAU_ASSERT(res); + dontTraverseInto(res); + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = replacementPacks[tp]; + LUAU_ASSERT(res); + dontTraverseInto(res); + return res; + } +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation2 : Substitution +{ + // Mapping from generic types to free types to be used in instantiation. + DenseHashMap genericSubstitutions{nullptr}; + // Mapping from generic type packs to `TypePack`s of free types to be used in instantiation. + DenseHashMap genericPackSubstitutions{nullptr}; + + Instantiation2(TypeArena* arena, DenseHashMap genericSubstitutions, DenseHashMap genericPackSubstitutions) + : Substitution(TxnLog::empty(), arena) + , genericSubstitutions(std::move(genericSubstitutions)) + , genericPackSubstitutions(std::move(genericPackSubstitutions)) + { + } + + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypeId ty +); +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypePackId tp +); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/IostreamHelpers.h b/third_party/luau/Analysis/include/Luau/IostreamHelpers.h index 42b362be..a16455df 100644 --- a/third_party/luau/Analysis/include/Luau/IostreamHelpers.h +++ b/third_party/luau/Analysis/include/Luau/IostreamHelpers.h @@ -5,6 +5,7 @@ #include "Luau/Location.h" #include "Luau/Type.h" #include "Luau/Ast.h" +#include "Luau/TypePath.h" #include @@ -48,4 +49,14 @@ std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv); std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted); +std::ostream& operator<<(std::ostream& lhs, TypeId ty); +std::ostream& operator<<(std::ostream& lhs, TypePackId tp); + +namespace TypePath +{ + +std::ostream& operator<<(std::ostream& lhs, const Path& path); + +}; // namespace TypePath + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/LValue.h b/third_party/luau/Analysis/include/Luau/LValue.h index 9a8b863b..e20d9901 100644 --- a/third_party/luau/Analysis/include/Luau/LValue.h +++ b/third_party/luau/Analysis/include/Luau/LValue.h @@ -3,6 +3,7 @@ #include "Luau/Variant.h" #include "Luau/Symbol.h" +#include "Luau/TypeFwd.h" #include #include @@ -10,9 +11,6 @@ namespace Luau { -struct Type; -using TypeId = const Type*; - struct Field; // Deprecated. Do not use in new work. diff --git a/third_party/luau/Analysis/include/Luau/Linter.h b/third_party/luau/Analysis/include/Luau/Linter.h index 6bbc3d66..f911a652 100644 --- a/third_party/luau/Analysis/include/Luau/Linter.h +++ b/third_party/luau/Analysis/include/Luau/Linter.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/LinterConfig.h" #include "Luau/Location.h" #include @@ -15,88 +16,23 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; -struct HotComment; using ScopePtr = std::shared_ptr; -struct LintWarning -{ - // Make sure any new lint codes are documented here: https://luau-lang.org/lint - // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints - enum Code - { - Code_Unknown = 0, - - Code_UnknownGlobal = 1, // superseded by type checker - Code_DeprecatedGlobal = 2, - Code_GlobalUsedAsLocal = 3, - Code_LocalShadow = 4, // disabled in Studio - Code_SameLineStatement = 5, // disabled in Studio - Code_MultiLineStatement = 6, - Code_LocalUnused = 7, // disabled in Studio - Code_FunctionUnused = 8, // disabled in Studio - Code_ImportUnused = 9, // disabled in Studio - Code_BuiltinGlobalWrite = 10, - Code_PlaceholderRead = 11, - Code_UnreachableCode = 12, - Code_UnknownType = 13, - Code_ForRange = 14, - Code_UnbalancedAssignment = 15, - Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode - Code_DuplicateLocal = 17, - Code_FormatString = 18, - Code_TableLiteral = 19, - Code_UninitializedLocal = 20, - Code_DuplicateFunction = 21, - Code_DeprecatedApi = 22, - Code_TableOperations = 23, - Code_DuplicateCondition = 24, - Code_MisleadingAndOr = 25, - Code_CommentDirective = 26, - Code_IntegerParsing = 27, - Code_ComparisonPrecedence = 28, - - Code__Count - }; - - Code code; - Location location; - std::string text; - - static const char* getName(Code code); - static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); -}; - struct LintResult { std::vector errors; std::vector warnings; }; -struct LintOptions -{ - uint64_t warningMask = 0; - - void enableWarning(LintWarning::Code code) - { - warningMask |= 1ull << code; - } - void disableWarning(LintWarning::Code code) - { - warningMask &= ~(1ull << code); - } - - bool isEnabled(LintWarning::Code code) const - { - return 0 != (warningMask & (1ull << code)); - } - - void setDefaults(); -}; - -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, - const std::vector& hotcomments, const LintOptions& options); +std::vector lint( + AstStat* root, + const AstNameTable& names, + const ScopePtr& env, + const Module* module, + const std::vector& hotcomments, + const LintOptions& options +); std::vector getDeprecatedGlobals(const AstNameTable& names); diff --git a/third_party/luau/Analysis/include/Luau/Metamethods.h b/third_party/luau/Analysis/include/Luau/Metamethods.h index 84b0092f..747b7201 100644 --- a/third_party/luau/Analysis/include/Luau/Metamethods.h +++ b/third_party/luau/Analysis/include/Luau/Metamethods.h @@ -19,6 +19,7 @@ static const std::unordered_map kBinaryOpMetamet {AstExprBinary::Op::Sub, "__sub"}, {AstExprBinary::Op::Mul, "__mul"}, {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::FloorDiv, "__idiv"}, {AstExprBinary::Op::Pow, "__pow"}, {AstExprBinary::Op::Mod, "__mod"}, {AstExprBinary::Op::Concat, "__concat"}, diff --git a/third_party/luau/Analysis/include/Luau/Module.h b/third_party/luau/Analysis/include/Luau/Module.h index b9be8205..f909deb8 100644 --- a/third_party/luau/Analysis/include/Luau/Module.h +++ b/third_party/luau/Analysis/include/Luau/Module.h @@ -8,6 +8,7 @@ #include "Luau/ParseResult.h" #include "Luau/Scope.h" #include "Luau/TypeArena.h" +#include "Luau/AnyTypeSummary.h" #include #include @@ -18,6 +19,7 @@ namespace Luau { struct Module; +struct AnyTypeSummary; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; @@ -71,6 +73,10 @@ struct Module TypeArena interfaceTypes; TypeArena internalTypes; + // Summary of Ast Nodes that either contain + // user annotated anys or typechecker inferred anys + AnyTypeSummary ats{}; + // Scopes and AST types refer to parse data, so we need to keep that alive std::shared_ptr allocator; std::shared_ptr names; @@ -81,24 +87,47 @@ struct Module DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; + // For AST nodes that are function calls, this map provides the + // unspecialized type of the function that was called. If a function call + // resolves to a __call metamethod application, this map will point at that + // metamethod. + // + // This is useful for type checking and Signature Help. DenseHashMap astOriginalCallTypes{nullptr}; + + // The specialization of a function that was selected. If the function is + // generic, those generic type parameters will be replaced with the actual + // types that were passed. If the function is an overload, this map will + // point at the specific overloads that were selected. DenseHashMap astOverloadResolvedTypes{nullptr}; + // Only used with for...in loops. The computed type of the next() function + // is kept here for type checking. + DenseHashMap astForInNextTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; - DenseHashMap astOriginalResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; - // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. - DenseHashMap astScopes{nullptr}; + // The computed result type of a compound assignment. (eg foo += 1) + // + // Type checking uses this to check that the result of such an operation is + // actually compatible with the left-side operand. + DenseHashMap astCompoundAssignResultTypes{nullptr}; + + DenseHashMap>> upperBoundContributors{nullptr}; - std::unique_ptr reduction; + // Map AST nodes to the scope they create. Cannot be NotNull because + // we need a sentinel value for the map. + DenseHashMap astScopes{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; LintResult lintResult; Mode mode; SourceCode::Type type; + double checkDurationSec = 0.0; bool timeout = false; + bool cancelled = false; TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; @@ -106,8 +135,9 @@ struct Module bool hasModuleScope() const; ScopePtr getModuleScope() const; - // Once a module has been typechecked, we clone its public interface into a separate arena. - // This helps us to force Type ownership into a DAG rather than a DCG. + // Once a module has been typechecked, we clone its public interface into a + // separate arena. This helps us to force Type ownership into a DAG rather + // than a DCG. void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice); }; diff --git a/third_party/luau/Analysis/include/Luau/NonStrictTypeChecker.h b/third_party/luau/Analysis/include/Luau/NonStrictTypeChecker.h new file mode 100644 index 00000000..8e80c762 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/NonStrictTypeChecker.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/DataFlowGraph.h" + +namespace Luau +{ + +struct BuiltinTypes; +struct UnifierSharedState; +struct TypeCheckLimits; + +void checkNonStrict( + NotNull builtinTypes, + NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + const SourceModule& sourceModule, + Module* module +); + + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Normalize.h b/third_party/luau/Analysis/include/Luau/Normalize.h index 6c808286..d844d211 100644 --- a/third_party/luau/Analysis/include/Luau/Normalize.h +++ b/third_party/luau/Analysis/include/Luau/Normalize.h @@ -2,10 +2,15 @@ #pragma once #include "Luau/NotNull.h" -#include "Luau/Type.h" +#include "Luau/Set.h" +#include "Luau/TypeFwd.h" #include "Luau/UnifierSharedState.h" +#include +#include #include +#include +#include namespace Luau { @@ -13,7 +18,6 @@ namespace Luau struct InternalErrorReporter; struct Module; struct Scope; -struct BuiltinTypes; using ModulePtr = std::shared_ptr; @@ -25,7 +29,7 @@ bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull sc class TypeIds { private: - std::unordered_set types; + DenseHashMap types{nullptr}; std::vector order; std::size_t hash = 0; @@ -33,10 +37,15 @@ class TypeIds using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - TypeIds(const TypeIds&) = default; - TypeIds(TypeIds&&) = default; TypeIds() = default; ~TypeIds() = default; + + TypeIds(std::initializer_list tys); + + TypeIds(const TypeIds&) = default; + TypeIds& operator=(const TypeIds&) = default; + + TypeIds(TypeIds&&) = default; TypeIds& operator=(TypeIds&&) = default; void insert(TypeId ty); @@ -50,6 +59,7 @@ class TypeIds const_iterator begin() const; const_iterator end() const; iterator erase(const_iterator it); + void erase(TypeId ty); size_t size() const; bool empty() const; @@ -64,6 +74,7 @@ class TypeIds bool operator==(const TypeIds& there) const; size_t getHash() const; + bool isNever() const; }; } // namespace Luau @@ -206,7 +217,19 @@ struct NormalizedFunctionType struct NormalizedType; using NormalizedTyvars = std::unordered_map>; -bool isInhabited_DEPRECATED(const NormalizedType& norm); +// Operations provided by `Normalizer` can have ternary results: +// 1. The operation returned true. +// 2. The operation returned false. +// 3. They can hit resource limitations, which invalidates _all normalized types_. +enum class NormalizationResult +{ + // The operation returned true or succeeded. + True, + // The operation returned false or failed. + False, + // Resource limits were hit, invalidating all normalized types. + HitLimits, +}; // A normalized type is either any, unknown, or one of the form P | T | F | G where // * P is a union of primitive types (including singletons, classes and the error type) @@ -226,10 +249,6 @@ struct NormalizedType NormalizedClassType classes; - // The class part of the type. - // Each element of this set is a class, and none of the classes are subclasses of each other. - TypeIds DEPRECATED_classes; - // The error part of the type. // This type is either never or the error type. TypeId errors; @@ -250,6 +269,10 @@ struct NormalizedType // This type is either never or thread. TypeId threads; + // The buffer part of the type. + // This type is either never or buffer. + TypeId buffers; + // The (meta)table part of the type. // Each element of this set is a (meta)table type, or the top `table` type. // An empty set denotes never. @@ -261,6 +284,11 @@ struct NormalizedType // The generic/free part of the type. NormalizedTyvars tyvars; + // Free types, blocked types, and certain other types change shape as type + // inference is done. If we were to cache the normalization of these types, + // we'd be reusing bad, stale data. + bool isCacheable = true; + NormalizedType(NotNull builtinTypes); NormalizedType() = delete; @@ -271,22 +299,63 @@ struct NormalizedType NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; + + // IsType functions + bool isUnknown() const; + /// Returns true if the type is exactly a number. Behaves like Type::isNumber() + bool isExactlyNumber() const; + + /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() + bool isSubtypeOfString() const; + + /// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean() + bool isSubtypeOfBooleans() const; + + /// Returns true if this type should result in error suppressing behavior. + bool shouldSuppressErrors() const; + + /// Returns true if this type contains the primitve top table type, `table`. + bool hasTopTable() const; + + // Helpers that improve readability of the above (they just say if the component is present) + bool hasTops() const; + bool hasBooleans() const; + bool hasClasses() const; + bool hasErrors() const; + bool hasNils() const; + bool hasNumbers() const; + bool hasStrings() const; + bool hasThreads() const; + bool hasBuffers() const; + bool hasTables() const; + bool hasFunctions() const; + bool hasTyvars() const; + + bool isFalsy() const; + bool isTruthy() const; }; + + class Normalizer { - std::unordered_map> cachedNormals; + std::unordered_map> cachedNormals; std::unordered_map cachedIntersections; std::unordered_map cachedUnions; std::unordered_map> cachedTypeIds; + + DenseHashMap cachedIsInhabited{nullptr}; + DenseHashMap, bool, TypeIdPairHash> cachedIsInhabitedIntersection{{nullptr, nullptr}}; + bool withinResourceLimits(); public: TypeArena* arena; NotNull builtinTypes; NotNull sharedState; + bool cacheInhabitance = false; - Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState); + Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance = false); Normalizer(const Normalizer&) = delete; Normalizer(Normalizer&&) = delete; Normalizer() = delete; @@ -295,7 +364,7 @@ class Normalizer Normalizer& operator=(Normalizer&) = delete; // If this returns null, the typechecker should emit a "too complex" error - const NormalizedType* normalize(TypeId ty); + std::shared_ptr normalize(TypeId ty); void clearNormal(NormalizedType& norm); // ------- Cached TypeIds @@ -320,8 +389,8 @@ class Normalizer void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTables(TypeIds& heres, const TypeIds& theres); - bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars = -1); // ------- Negations std::optional negateNormal(const NormalizedType& here); @@ -329,32 +398,35 @@ class Normalizer TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractSingleton(NormalizedType& here, TypeId ty); + NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect); // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); - void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres); - void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there); void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there); - void intersectTablesWithTable(TypeIds& heres, TypeId there); + std::optional intersectionOfTables(TypeId here, TypeId there, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); - bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); - bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - bool intersectNormalWithTy(NormalizedType& here, TypeId there); - bool normalizeIntersections(const std::vector& intersections, NormalizedType& outType); + NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes); + NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes); + NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet); // Check for inhabitance - bool isInhabited(TypeId ty, std::unordered_set seen = {}); - bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + NormalizationResult isInhabited(TypeId ty); + NormalizationResult isInhabited(TypeId ty, Set& seen); + NormalizationResult isInhabited(const NormalizedType* norm); + NormalizationResult isInhabited(const NormalizedType* norm, Set& seen); + // Check for intersections being inhabited - bool isIntersectionInhabited(TypeId left, TypeId right); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/third_party/luau/Analysis/include/Luau/OverloadResolution.h b/third_party/luau/Analysis/include/Luau/OverloadResolution.h new file mode 100644 index 00000000..9a2974a5 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/OverloadResolution.h @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" +#include "Luau/Location.h" +#include "Luau/Error.h" +#include "Luau/Subtyping.h" + +namespace Luau +{ + +struct BuiltinTypes; +struct TypeArena; +struct Scope; +struct InternalErrorReporter; +struct TypeCheckLimits; +struct Subtyping; + +class Normalizer; + +struct OverloadResolver +{ + enum Analysis + { + Ok, + TypeIsNotAFunction, + ArityMismatch, + OverloadIsNonviable, // Arguments were incompatible with the overloads parameters but were otherwise compatible by arity + }; + + OverloadResolver( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull scope, + NotNull reporter, + NotNull limits, + Location callLocation + ); + + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull scope; + NotNull ice; + NotNull limits; + Subtyping subtyping; + Location callLoc; + + // Resolver results + std::vector ok; + std::vector nonFunctions; + std::vector> arityMismatches; + std::vector> nonviableOverloads; + InsertionOrderedMap> resolution; + + + std::pair selectOverload(TypeId ty, TypePackId args); + void resolve(TypeId fnTy, const TypePack* args, AstExpr* selfExpr, const std::vector* argExprs); + +private: + std::optional testIsSubtype(const Location& location, TypeId subTy, TypeId superTy); + std::optional testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy); + std::pair checkOverload( + TypeId fnTy, + const TypePack* args, + AstExpr* fnLoc, + const std::vector* argExprs, + bool callMetamethodOk = true + ); + static bool isLiteral(AstExpr* expr); + LUAU_NOINLINE + std::pair checkOverload_( + TypeId fnTy, + const FunctionType* fn, + const TypePack* args, + AstExpr* fnExpr, + const std::vector* argExprs + ); + size_t indexof(Analysis analysis); + void add(Analysis analysis, TypeId ty, ErrorVec&& errors); +}; + +struct SolveResult +{ + enum OverloadCallResult + { + Ok, + CodeTooComplex, + OccursCheckFailed, + NoMatchingOverload, + }; + + OverloadCallResult result; + std::optional typePackId; // nullopt if result != Ok + + TypeId overloadToUse = nullptr; + TypeId inferredTy = nullptr; + DenseHashMap> expandedFreeTypes{nullptr}; +}; + +// Helper utility, presently used for binary operator type functions. +// +// Given a function and a set of arguments, select a suitable overload. +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Predicate.h b/third_party/luau/Analysis/include/Luau/Predicate.h index 50fd7edd..52ee1f29 100644 --- a/third_party/luau/Analysis/include/Luau/Predicate.h +++ b/third_party/luau/Analysis/include/Luau/Predicate.h @@ -4,15 +4,13 @@ #include "Luau/Location.h" #include "Luau/LValue.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" #include namespace Luau { -struct Type; -using TypeId = const Type*; - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/third_party/luau/Analysis/include/Luau/Quantify.h b/third_party/luau/Analysis/include/Luau/Quantify.h index c86512f1..bae3751d 100644 --- a/third_party/luau/Analysis/include/Luau/Quantify.h +++ b/third_party/luau/Analysis/include/Luau/Quantify.h @@ -1,7 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/DenseHash.h" +#include "Luau/Unifiable.h" + +#include +#include namespace Luau { @@ -10,6 +15,29 @@ struct TypeArena; struct Scope; void quantify(TypeId ty, TypeLevel level); -std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); + +// TODO: This is eerily similar to the pattern that NormalizedClassType +// implements. We could, and perhaps should, merge them together. +template +struct OrderedMap +{ + std::vector keys; + DenseHashMap pairings{nullptr}; + + void push(K k, V v) + { + keys.push_back(k); + pairings[k] = v; + } +}; + +struct QuantifierResult +{ + TypeId result; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; +}; + +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Refinement.h b/third_party/luau/Analysis/include/Luau/Refinement.h index fecf459a..3fea7868 100644 --- a/third_party/luau/Analysis/include/Luau/Refinement.h +++ b/third_party/luau/Analysis/include/Luau/Refinement.h @@ -4,14 +4,13 @@ #include "Luau/NotNull.h" #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" namespace Luau { -using BreadcrumbId = NotNull; - -struct Type; -using TypeId = const Type*; +struct RefinementKey; +using DefId = NotNull; struct Variadic; struct Negation; @@ -52,7 +51,7 @@ struct Equivalence struct Proposition { - BreadcrumbId breadcrumb; + const RefinementKey* key; TypeId discriminantTy; }; @@ -69,7 +68,7 @@ struct RefinementArena RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); - RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy); + RefinementId proposition(const RefinementKey* key, TypeId discriminantTy); private: TypedAllocator allocator; diff --git a/third_party/luau/Analysis/include/Luau/Scope.h b/third_party/luau/Analysis/include/Luau/Scope.h index c3038fac..0e6eff56 100644 --- a/third_party/luau/Analysis/include/Luau/Scope.h +++ b/third_party/luau/Analysis/include/Luau/Scope.h @@ -2,9 +2,13 @@ #pragma once #include "Luau/Def.h" +#include "Luau/LValue.h" #include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Type.h" +#include "Luau/DenseHash.h" +#include "Luau/Symbol.h" +#include "Luau/Unifiable.h" #include #include @@ -41,6 +45,8 @@ struct Scope TypeLevel level; + Location location; // the spanning location associated with this scope + std::unordered_map exportedTypeBindings; std::unordered_map privateTypeBindings; std::unordered_map typeAliasLocations; @@ -52,7 +58,9 @@ struct Scope void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); std::optional lookup(Symbol sym) const; + std::optional lookupUnrefinedType(DefId def) const; std::optional lookup(DefId def) const; + std::optional> lookupEx(DefId def); std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name) const; @@ -65,7 +73,16 @@ struct Scope std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; - DenseHashMap dcrRefinements{nullptr}; + + // This can be viewed as the "unrefined" type of each binding. + DenseHashMap lvalueTypes{nullptr}; + + // Luau values are routinely refined more narrowly than their actual + // inferred type through control flow statements. We retain those refined + // types here. + DenseHashMap rvalueRefinements{nullptr}; + + void inheritAssignments(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope); // For mutually recursive type aliases, it's important that @@ -85,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right); // outermost-possible scope. bool subsumes(Scope* left, Scope* right); +inline Scope* max(Scope* left, Scope* right) +{ + if (subsumes(left, right)) + return right; + else + return left; +} + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Set.h b/third_party/luau/Analysis/include/Luau/Set.h new file mode 100644 index 00000000..274375cf --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Set.h @@ -0,0 +1,194 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +namespace Luau +{ + +template +using SetHashDefault = std::conditional_t, DenseHashPointer, std::hash>; + +// This is an implementation of `unordered_set` using `DenseHashMap` to support erasure. +// This lets us work around `DenseHashSet` limitations and get a more traditional set interface. +template> +class Set +{ +private: + using Impl = DenseHashMap; + Impl mapping; + size_t entryCount = 0; + +public: + class const_iterator; + using iterator = const_iterator; + + Set(const T& empty_key) + : mapping{empty_key} + { + } + + bool insert(const T& element) + { + bool& entry = mapping[element]; + bool fresh = !entry; + + if (fresh) + { + entry = true; + entryCount++; + } + + return fresh; + } + + template + void insert(Iterator begin, Iterator end) + { + for (Iterator it = begin; it != end; ++it) + insert(*it); + } + + void erase(T&& element) + { + bool& entry = mapping[element]; + + if (entry) + { + entry = false; + entryCount--; + } + } + + void erase(const T& element) + { + bool& entry = mapping[element]; + + if (entry) + { + entry = false; + entryCount--; + } + } + + void clear() + { + mapping.clear(); + entryCount = 0; + } + + size_t size() const + { + return entryCount; + } + + bool empty() const + { + return entryCount == 0; + } + + size_t count(const T& element) const + { + const bool* entry = mapping.find(element); + return (entry && *entry) ? 1 : 0; + } + + bool contains(const T& element) const + { + return count(element) != 0; + } + + const_iterator begin() const + { + return const_iterator(mapping.begin(), mapping.end()); + } + + const_iterator end() const + { + return const_iterator(mapping.end(), mapping.end()); + } + + bool operator==(const Set& there) const + { + // if the sets are unequal sizes, then they cannot possibly be equal. + if (size() != there.size()) + return false; + + // otherwise, we'll need to check that every element we have here is in `there`. + for (auto [elem, present] : mapping) + { + // if it's not, we'll return `false` + if (present && there.contains(elem)) + return false; + } + + // otherwise, we've proven the two equal! + return true; + } + + class const_iterator + { + public: + using value_type = T; + using reference = T&; + using pointer = T*; + using difference_type = ptrdiff_t; + using iterator_category = std::forward_iterator_tag; + + const_iterator(typename Impl::const_iterator impl_, typename Impl::const_iterator end_) + : impl(impl_) + , end(end_) + { + while (impl != end && impl->second == false) + ++impl; + } + + const T& operator*() const + { + return impl->first; + } + + const T* operator->() const + { + return &impl->first; + } + + bool operator==(const const_iterator& other) const + { + return impl == other.impl; + } + + bool operator!=(const const_iterator& other) const + { + return impl != other.impl; + } + + + const_iterator& operator++() + { + do + { + impl++; + } while (impl != end && impl->second == false); + // keep iterating past pairs where the value is `false` + + return *this; + } + + const_iterator operator++(int) + { + const_iterator res = *this; + ++*this; + return res; + } + + private: + typename Impl::const_iterator impl; + typename Impl::const_iterator end; + }; +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Simplify.h b/third_party/luau/Analysis/include/Luau/Simplify.h new file mode 100644 index 00000000..5b363e96 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Simplify.h @@ -0,0 +1,38 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" +#include + +namespace Luau +{ + +struct TypeArena; + +struct SimplifyResult +{ + TypeId result; + + DenseHashSet blockedTypes; +}; + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); + +enum class Relation +{ + Disjoint, // No A is a B or vice versa + Coincident, // Every A is in B and vice versa + Intersects, // Some As are in B and some Bs are in A. ex (number | string) <-> (string | boolean) + Subset, // Every A is in B + Superset, // Every B is in A +}; + +Relation relate(TypeId left, TypeId right); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Substitution.h b/third_party/luau/Analysis/include/Luau/Substitution.h index 2efca2df..28ebc93d 100644 --- a/third_party/luau/Analysis/include/Luau/Substitution.h +++ b/third_party/luau/Analysis/include/Luau/Substitution.h @@ -2,8 +2,7 @@ #pragma once #include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/DenseHash.h" // We provide an implementation of substitution on types, @@ -69,24 +68,34 @@ struct TarjanWorklistVertex int lastEdge; }; +struct TarjanNode +{ + TypeId ty; + TypePackId tp; + + bool onStack; + bool dirty; + + // Tarjan calculates the lowlink for each vertex, + // which is the lowest ancestor index reachable from the vertex. + int lowlink; +}; + // Tarjan's algorithm for finding the SCCs in a cyclic structure. // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm struct Tarjan { + Tarjan(); + // Vertices (types and type packs) are indexed, using pre-order traversal. DenseHashMap typeToIndex{nullptr}; DenseHashMap packToIndex{nullptr}; - std::vector indexToType; - std::vector indexToPack; + + std::vector nodes; // Tarjan keeps a stack of vertices where we're still in the process // of finding their SCC. std::vector stack; - std::vector onStack; - - // Tarjan calculates the lowlink for each vertex, - // which is the lowest ancestor index reachable from the vertex. - std::vector lowlink; int childCount = 0; int childLimit = 0; @@ -98,6 +107,7 @@ struct Tarjan std::vector edgesTy; std::vector edgesTp; std::vector worklist; + // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); @@ -113,45 +123,53 @@ struct Tarjan void visitChild(TypeId ty); void visitChild(TypePackId ty); + template + void visitChild(std::optional ty) + { + if (ty) + visitChild(*ty); + } + // Visit the root vertex. TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); - // Each subclass gets called back once for each edge, - // and once for each SCC. - virtual void visitEdge(int index, int parentIndex) {} - virtual void visitSCC(int index) {} + // Used to reuse the object for a new operation + void clearTarjan(const TxnLog* log); + + // Get/set the dirty bit for an index (grows the vector if needed) + bool getDirty(int index); + void setDirty(int index, bool d); + + // Find all the dirty vertices reachable from `t`. + TarjanResult findDirty(TypeId t); + TarjanResult findDirty(TypePackId t); + + // We find dirty vertices using Tarjan + void visitEdge(int index, int parentIndex); + void visitSCC(int index); // Each subclass can decide to ignore some nodes. virtual bool ignoreChildren(TypeId ty) { return false; } + virtual bool ignoreChildren(TypePackId ty) { return false; } -}; - -// We use Tarjan to calculate dirty bits. We set `dirty[i]` true -// if the vertex with index `i` can reach a dirty vertex. -struct FindDirty : Tarjan -{ - std::vector dirty; - void clearTarjan(); - - // Get/set the dirty bit for an index (grows the vector if needed) - bool getDirty(int index); - void setDirty(int index, bool d); - - // Find all the dirty vertices reachable from `t`. - TarjanResult findDirty(TypeId t); - TarjanResult findDirty(TypePackId t); + // Some subclasses might ignore children visit, but not other actions like replacing the children + virtual bool ignoreChildrenVisit(TypeId ty) + { + return ignoreChildren(ty); + } - // We find dirty vertices using Tarjan - void visitEdge(int index, int parentIndex) override; - void visitSCC(int index) override; + virtual bool ignoreChildrenVisit(TypePackId ty) + { + return ignoreChildren(ty); + } // Subclasses should say which vertices are dirty, // and what to do with dirty vertices. @@ -163,16 +181,24 @@ struct FindDirty : Tarjan // And finally substitution, which finds all the reachable dirty vertices // and replaces them with clean ones. -struct Substitution : FindDirty +struct Substitution : Tarjan { protected: - Substitution(const TxnLog* log_, TypeArena* arena) - : arena(arena) - { - log = log_; - LUAU_ASSERT(log); - LUAU_ASSERT(arena); - } + Substitution(const TxnLog* log_, TypeArena* arena); + + /* + * By default, Substitution assumes that the types produced by clean() are + * freshly allocated types that are safe to mutate. + * + * If your clean() implementation produces a type that is not safe to + * mutate, you must call dontTraverseInto on this type (or type pack) to + * prevent Substitution from attempting to perform substitutions within the + * cleaned type. + * + * See the test weird_cyclic_instantiation for an example. + */ + void dontTraverseInto(TypeId ty); + void dontTraverseInto(TypePackId tp); public: TypeArena* arena; @@ -181,13 +207,20 @@ struct Substitution : FindDirty DenseHashSet replacedTypes{nullptr}; DenseHashSet replacedTypePacks{nullptr}; + DenseHashSet noTraverseTypes{nullptr}; + DenseHashSet noTraverseTypePacks{nullptr}; + std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); + void resetState(const TxnLog* log, TypeArena* arena); + TypeId replace(TypeId ty); TypePackId replace(TypePackId tp); + void replaceChildren(TypeId ty); void replaceChildren(TypePackId tp); + TypeId clone(TypeId ty); TypePackId clone(TypePackId tp); @@ -211,6 +244,16 @@ struct Substitution : FindDirty { return arena->addTypePack(TypePackVar{tp}); } + +private: + template + std::optional replace(std::optional ty) + { + if (ty) + return replace(*ty); + else + return std::nullopt; + } }; } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Subtyping.h b/third_party/luau/Analysis/include/Luau/Subtyping.h new file mode 100644 index 00000000..a3f199f3 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Subtyping.h @@ -0,0 +1,248 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Set.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePairHash.h" +#include "Luau/TypePath.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/DenseHash.h" + +#include +#include + +namespace Luau +{ + +template +struct TryPair; +struct InternalErrorReporter; + +class TypeIds; +class Normalizer; +struct NormalizedClassType; +struct NormalizedFunctionType; +struct NormalizedStringType; +struct NormalizedType; +struct Property; +struct Scope; +struct TableIndexer; +struct TypeArena; +struct TypeCheckLimits; + +enum class SubtypingVariance +{ + // Used for an empty key. Should never appear in actual code. + Invalid, + Covariant, + // This is used to identify cases where we have a covariant + a + // contravariant reason and we need to merge them. + Contravariant, + Invariant, +}; + +struct SubtypingReasoning +{ + // The path, relative to the _root subtype_, where subtyping failed. + Path subPath; + // The path, relative to the _root supertype_, where subtyping failed. + Path superPath; + SubtypingVariance variance = SubtypingVariance::Covariant; + + bool operator==(const SubtypingReasoning& other) const; +}; + +struct SubtypingReasoningHash +{ + size_t operator()(const SubtypingReasoning& r) const; +}; + +using SubtypingReasonings = DenseHashSet; +static const SubtypingReasoning kEmptyReasoning = SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invalid}; + +struct SubtypingResult +{ + bool isSubtype = false; + bool normalizationTooComplex = false; + bool isCacheable = true; + ErrorVec errors; + /// The reason for isSubtype to be false. May not be present even if + /// isSubtype is false, depending on the input types. + SubtypingReasonings reasoning{kEmptyReasoning}; + + SubtypingResult& andAlso(const SubtypingResult& other); + SubtypingResult& orElse(const SubtypingResult& other); + SubtypingResult& withBothComponent(TypePath::Component component); + SubtypingResult& withSuperComponent(TypePath::Component component); + SubtypingResult& withSubComponent(TypePath::Component component); + SubtypingResult& withBothPath(TypePath::Path path); + SubtypingResult& withSubPath(TypePath::Path path); + SubtypingResult& withSuperPath(TypePath::Path path); + SubtypingResult& withErrors(ErrorVec& err); + SubtypingResult& withError(TypeError err); + + // Only negates the `isSubtype`. + static SubtypingResult negate(const SubtypingResult& result); + static SubtypingResult all(const std::vector& results); + static SubtypingResult any(const std::vector& results); +}; + +struct SubtypingEnvironment +{ + struct GenericBounds + { + DenseHashSet lowerBound{nullptr}; + DenseHashSet upperBound{nullptr}; + }; + + /* + * When we encounter a generic over the course of a subtyping test, we need + * to tentatively map that generic onto a type on the other side. + */ + DenseHashMap mappedGenerics{nullptr}; + DenseHashMap mappedGenericPacks{nullptr}; + + /* + * See the test cyclic_tables_are_assumed_to_be_compatible_with_classes for + * details. + * + * An empty value is equivalent to a nonexistent key. + */ + DenseHashMap substitutions{nullptr}; + + DenseHashMap, SubtypingResult, TypePairHash> ephemeralCache{{}}; + + /// Applies `mappedGenerics` to the given type. + /// This is used specifically to substitute for generics in type function instances. + std::optional applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty); +}; + +struct Subtyping +{ + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull iceReporter; + + NotNull scope; + TypeCheckLimits limits; + + enum class Variance + { + Covariant, + Contravariant + }; + + Variance variance = Variance::Covariant; + + using SeenSet = Set, TypePairHash>; + + SeenSet seenTypes{{}}; + + Subtyping( + NotNull builtinTypes, + NotNull typeArena, + NotNull normalizer, + NotNull iceReporter, + NotNull scope + ); + + Subtyping(const Subtyping&) = delete; + Subtyping& operator=(const Subtyping&) = delete; + + Subtyping(Subtyping&&) = default; + Subtyping& operator=(Subtyping&&) = default; + + // Only used by unit tests to test that the cache works. + const DenseHashMap, SubtypingResult, TypePairHash>& peekCache() const + { + return resultCache; + } + + // TODO cache + // TODO cyclic types + // TODO recursion limits + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy); + SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy); + +private: + DenseHashMap, SubtypingResult, TypePairHash> resultCache{{}}; + + SubtypingResult cache(SubtypingEnvironment& env, SubtypingResult res, TypeId subTy, TypeId superTy); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypePackId subTy, TypePackId superTy); + + template + SubtypingResult isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy); + + template + SubtypingResult isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy); + + template + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TryPair& pair); + + template + SubtypingResult isContravariantWith(SubtypingEnvironment& env, const TryPair& pair); + + template + SubtypingResult isInvariantWith(SubtypingEnvironment& env, const TryPair& pair); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const UnionType* superUnion); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const UnionType* subUnion, TypeId superTy); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const IntersectionType* superIntersection); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const IntersectionType* subIntersection, TypeId superTy); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NegationType* subNegation, TypeId superTy); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const SingletonType* superSingleton); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const TableType* superTable); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const ClassType* superClass); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const ClassType* subClass, TypeId superTy, const TableType* superTable); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const FunctionType* subFunction, const FunctionType* superFunction); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const TableType* superTable); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const TableType* superTable); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name); + + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const std::shared_ptr& subNorm, + const std::shared_ptr& superNorm + ); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables); + SubtypingResult isCovariantWith( + SubtypingEnvironment& env, + const NormalizedFunctionType& subFunction, + const NormalizedFunctionType& superFunction + ); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes); + + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeFunctionInstanceType* subFunctionInstance, const TypeId superTy); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const TypeFunctionInstanceType* superFunctionInstance); + + bool bindGeneric(SubtypingEnvironment& env, TypeId subTp, TypeId superTp); + bool bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp); + + template + TypeId makeAggregateType(const Container& container, TypeId orElse); + + std::pair handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance); + + [[noreturn]] void unexpected(TypeId ty); + [[noreturn]] void unexpected(TypePackId tp); +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Symbol.h b/third_party/luau/Analysis/include/Luau/Symbol.h index b47554e0..337e2a9f 100644 --- a/third_party/luau/Analysis/include/Luau/Symbol.h +++ b/third_party/luau/Analysis/include/Luau/Symbol.h @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) - namespace Luau { @@ -42,17 +40,7 @@ struct Symbol return local != nullptr || global.value != nullptr; } - bool operator==(const Symbol& rhs) const - { - if (local) - return local == rhs.local; - else if (global.value) - return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::DebugLuauDeferredConstraintResolution) - return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. - else - return false; - } + bool operator==(const Symbol& rhs) const; bool operator!=(const Symbol& rhs) const { diff --git a/third_party/luau/Analysis/include/Luau/TableLiteralInference.h b/third_party/luau/Analysis/include/Luau/TableLiteralInference.h new file mode 100644 index 00000000..dd9ecf97 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TableLiteralInference.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +struct TypeArena; +struct BuiltinTypes; +struct Unifier2; +class AstExpr; + +TypeId matchLiteralType( + NotNull> astTypes, + NotNull> astExpectedTypes, + NotNull builtinTypes, + NotNull arena, + NotNull unifier, + TypeId expectedType, + TypeId exprType, + const AstExpr* expr, + std::vector& toBlock +); +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/ToDot.h b/third_party/luau/Analysis/include/Luau/ToDot.h index 1a9c2811..6fa99ec3 100644 --- a/third_party/luau/Analysis/include/Luau/ToDot.h +++ b/third_party/luau/Analysis/include/Luau/ToDot.h @@ -2,16 +2,12 @@ #pragma once #include "Luau/Common.h" +#include "Luau/TypeFwd.h" #include namespace Luau { -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; struct ToDotOptions { diff --git a/third_party/luau/Analysis/include/Luau/ToString.h b/third_party/luau/Analysis/include/Luau/ToString.h index 7758e8f9..f8001e08 100644 --- a/third_party/luau/Analysis/include/Luau/ToString.h +++ b/third_party/luau/Analysis/include/Luau/ToString.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/TypeFwd.h" #include #include @@ -19,13 +20,6 @@ class AstExpr; struct Scope; -struct Type; -using TypeId = const Type*; - -struct TypePackVar; -using TypePackId = const TypePackVar*; - -struct FunctionType; struct Constraint; struct Position; @@ -39,6 +33,11 @@ struct ToStringNameMap struct ToStringOptions { + ToStringOptions(bool exhaustive = false) + : exhaustive(exhaustive) + { + } + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. bool functionTypeArguments = false; // If true, output function type argument names when they are available @@ -47,6 +46,7 @@ struct ToStringOptions bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); + size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections ToStringNameMap nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden @@ -99,10 +99,7 @@ inline std::string toString(const Constraint& c, ToStringOptions&& opts) return toString(c, opts); } -inline std::string toString(const Constraint& c) -{ - return toString(c, ToStringOptions{}); -} +std::string toString(const Constraint& c); std::string toString(const Type& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); @@ -142,6 +139,16 @@ std::string dump(const std::shared_ptr& scope, const char* name); std::string generateName(size_t n); std::string toString(const Position& position); -std::string toString(const Location& location); +std::string toString(const Location& location, int offset = 0, bool useBegin = true); + +std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts); + +inline std::string toString(const TypeOrPack& tyOrTp) +{ + ToStringOptions opts{}; + return toString(tyOrTp, opts); +} + +std::string dump(const TypeOrPack& tyOrTp); } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TxnLog.h b/third_party/luau/Analysis/include/Luau/TxnLog.h index 0ed8a49a..951f89ee 100644 --- a/third_party/luau/Analysis/include/Luau/TxnLog.h +++ b/third_party/luau/Analysis/include/Luau/TxnLog.h @@ -19,6 +19,10 @@ struct PendingType // The pending Type state. Type pending; + // On very rare occasions, we need to delete an entry from the TxnLog. + // DenseHashMap does not afford that so we note its deadness here. + bool dead = false; + explicit PendingType(Type state) : pending(std::move(state)) { @@ -61,10 +65,11 @@ T* getMutable(PendingTypePack* pending) // Log of what TypeIds we are rebinding, to be committed later. struct TxnLog { - TxnLog() + explicit TxnLog(bool useScopes = false) : typeVarChanges(nullptr) , typePackChanges(nullptr) , ownedSeen() + , useScopes(useScopes) , sharedSeen(&ownedSeen) { } @@ -297,6 +302,18 @@ struct TxnLog void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: + // There is one spot in the code where TxnLog has to reconcile collisions + // between parallel logs. In that codepath, we have to work out which of two + // FreeTypes subsumes the other. If useScopes is false, the TypeLevel is + // used. Else we use the embedded Scope*. + bool useScopes = false; + + // It is sometimes the case under DCR that we speculatively rebind + // GenericTypes to other types as though they were free. We mark logs that + // contain these kinds of substitutions as radioactive so that we know that + // we must never commit one. + bool radioactive = false; + // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. std::vector>* sharedSeen; diff --git a/third_party/luau/Analysis/include/Luau/Type.h b/third_party/luau/Analysis/include/Luau/Type.h index 5d92cbd0..585c2493 100644 --- a/third_party/luau/Analysis/include/Luau/Type.h +++ b/third_party/luau/Analysis/include/Luau/Type.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/TypeFwd.h" + #include "Luau/Ast.h" #include "Luau/Common.h" #include "Luau/Refinement.h" @@ -9,16 +11,15 @@ #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" +#include "Luau/VecDeque.h" #include -#include #include #include #include #include #include #include -#include #include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) @@ -31,6 +32,9 @@ struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct TypeFunction; +struct Constraint; + /** * There are three kinds of type variables: * - `Free` variables are metavariables, which stand for unconstrained types. @@ -57,31 +61,17 @@ using ScopePtr = std::shared_ptr; * ``` */ -// So... why `const T*` here rather than `T*`? -// It's because we've had problems caused by the type graph being mutated -// in ways it shouldn't be, for example mutating types from other modules. -// To try to control this, we make the use of types immutable by default, -// then provide explicit mutable access via getMutable and asMutable. -// This means we can grep for all the places we're mutating the type graph, -// and it makes it possible to provide other APIs (e.g. the txn log) -// which control mutable access to the type graph. -struct TypePackVar; -using TypePackId = const TypePackVar*; - -struct Type; - -// Should never be null -using TypeId = const Type*; - using Name = std::string; -// A free type var is one whose exact shape has yet to be fully determined. +// A free type is one whose exact shape has yet to be fully determined. struct FreeType { explicit FreeType(TypeLevel level); explicit FreeType(Scope* scope); FreeType(Scope* scope, TypeLevel level); + FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + int index; TypeLevel level; Scope* scope = nullptr; @@ -90,6 +80,10 @@ struct FreeType // recursive type alias whose definitions haven't been // resolved yet. bool forwardedTypeAlias = false; + + // Only used under local type inference + TypeId lowerBound = nullptr; + TypeId upperBound = nullptr; }; struct GenericType @@ -134,7 +128,14 @@ struct BlockedType BlockedType(); int index; - static int DEPRECATED_nextIndex; + Constraint* getOwner() const; + void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); + +private: + // The constraint that is intended to unblock this type. Other constraints + // should block on this constraint if present. + Constraint* owner = nullptr; }; struct PrimitiveType @@ -148,6 +149,7 @@ struct PrimitiveType Thread, Function, Table, + Buffer, }; Type type; @@ -165,7 +167,7 @@ struct PrimitiveType } }; -// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Singleton types https://github.com/luau-lang/rfcs/blob/master/docs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton { @@ -238,22 +240,6 @@ const T* get(const SingletonType* stv) return nullptr; } -struct GenericTypeDefinition -{ - TypeId ty; - std::optional defaultValue; - - bool operator==(const GenericTypeDefinition& rhs) const; -}; - -struct GenericTypePackDefinition -{ - TypePackId tp; - std::optional defaultValue; - - bool operator==(const GenericTypePackDefinition& rhs) const; -}; - struct FunctionArgument { Name name; @@ -290,12 +276,13 @@ struct WithPredicate } }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; +using MagicFunction = std::function>(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct MagicFunctionCallContext { NotNull solver; + NotNull constraint; const class AstExprCall* callSite; TypePackId arguments; TypePackId result; @@ -318,19 +305,46 @@ struct FunctionType FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn = {}, bool hasSelf = false); + FunctionType( + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); // Local monomorphic function FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); FunctionType( - TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + TypeLevel level, + Scope* scope, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); // Local polymorphic function - FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn = {}, bool hasSelf = false); - FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionType( + TypeLevel level, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); + FunctionType( + TypeLevel level, + Scope* scope, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn = {}, + bool hasSelf = false + ); std::optional definition; /// These should all be generic @@ -346,7 +360,10 @@ struct FunctionType DcrMagicFunction dcrMagicFunction = nullptr; DcrMagicRefinement dcrMagicRefinement = nullptr; bool hasSelf; - bool hasNoGenerics = false; + // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. + // this flag is used as an optimization to exit early from procedures that manipulate free or generic types. + bool hasNoFreeOrGenericTypes = false; + bool isCheckedFunction = false; }; enum class TableState @@ -383,33 +400,54 @@ struct Property static Property writeonly(TypeId ty); static Property rw(TypeId ty); // Shared read-write type. static Property rw(TypeId read, TypeId write); // Separate read-write type. - static std::optional create(std::optional read, std::optional write); + + // Invariant: at least one of the two optionals are not nullopt! + // If the read type is not nullopt, but the write type is, then the property is readonly. + // If the read type is nullopt, but the write type is not, then the property is writeonly. + // If the read and write types are not nullopt, then the property is read and write. + // Otherwise, an assertion where read and write types are both nullopt will be tripped. + static Property create(std::optional read, std::optional write); bool deprecated = false; std::string deprecatedSuggestion; + + // If this property was inferred from an expression, this field will be + // populated with the source location of the corresponding table property. std::optional location = std::nullopt; + + // If this property was built from an explicit type annotation, this field + // will be populated with the source location of that table property. + std::optional typeLocation = std::nullopt; + Tags tags; std::optional documentationSymbol; // DEPRECATED // TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends. Property(); - Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional location = std::nullopt, - const Tags& tags = {}, const std::optional& documentationSymbol = std::nullopt); + Property( + TypeId readTy, + bool deprecated = false, + const std::string& deprecatedSuggestion = "", + std::optional location = std::nullopt, + const Tags& tags = {}, + const std::optional& documentationSymbol = std::nullopt, + std::optional typeLocation = std::nullopt + ); // DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt. // TODO: Kill once we don't have non-RWP. TypeId type() const; void setType(TypeId ty); - // Should only be called in RWP! - // We do not assert that `readTy` nor `writeTy` are nullopt or not. - // The invariant is that at least one of them mustn't be nullopt, which we do assert here. - // TODO: Kill this in favor of exposing `readTy`/`writeTy` directly? If we do, we'll lose the asserts which will be useful while debugging. - std::optional readType() const; - std::optional writeType() const; + // Sets the write type of this property to the read type. + void makeShared(); + + bool isShared() const; + bool isReadOnly() const; + bool isWriteOnly() const; + bool isReadWrite() const; -private: std::optional readTy; std::optional writeTy; }; @@ -449,6 +487,11 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. + // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. @@ -489,9 +532,41 @@ struct ClassType Tags tags; std::shared_ptr userData; ModuleName definitionModuleName; + std::optional definitionLocation; + std::optional indexer; + + ClassType( + Name name, + Props props, + std::optional parent, + std::optional metatable, + Tags tags, + std::shared_ptr userData, + ModuleName definitionModuleName, + std::optional definitionLocation + ) + : name(name) + , props(props) + , parent(parent) + , metatable(metatable) + , tags(tags) + , userData(userData) + , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) + { + } - ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, - std::shared_ptr userData, ModuleName definitionModuleName) + ClassType( + Name name, + Props props, + std::optional parent, + std::optional metatable, + Tags tags, + std::shared_ptr userData, + ModuleName definitionModuleName, + std::optional definitionLocation, + std::optional indexer + ) : name(name) , props(props) , parent(parent) @@ -499,44 +574,46 @@ struct ClassType , tags(tags) , userData(userData) , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) + , indexer(indexer) { } }; -struct TypeFun +/** + * An instance of a type function that has not yet been reduced to a more concrete + * type. The constraint solver receives a constraint to reduce each + * TypeFunctionInstanceType to a concrete type. A design detail is important to + * note here: the parameters for this instantiation of the type function are + * contained within this type, so that they can be substituted. + */ +struct TypeFunctionInstanceType { - // These should all be generic - std::vector typeParams; - std::vector typePackParams; + NotNull function; - /** The underlying type. - * - * WARNING! This is not safe to use as a type if typeParams is not empty!! - * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. - */ - TypeId type; - - TypeFun() = default; + std::vector typeArguments; + std::vector packArguments; - explicit TypeFun(TypeId ty) - : type(ty) + TypeFunctionInstanceType(NotNull function, std::vector typeArguments, std::vector packArguments) + : function(function) + , typeArguments(typeArguments) + , packArguments(packArguments) { } - TypeFun(std::vector typeParams, TypeId type) - : typeParams(std::move(typeParams)) - , type(type) + TypeFunctionInstanceType(const TypeFunction& function, std::vector typeArguments) + : function{&function} + , typeArguments(typeArguments) + , packArguments{} { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) - : typeParams(std::move(typeParams)) - , typePackParams(std::move(typePackParams)) - , type(type) + TypeFunctionInstanceType(const TypeFunction& function, std::vector typeArguments, std::vector packArguments) + : function{&function} + , typeArguments(typeArguments) + , packArguments(packArguments) { } - - bool operator==(const TypeFun& rhs) const; }; /** Represents a pending type alias instantiation. @@ -579,30 +656,26 @@ struct IntersectionType struct LazyType { LazyType() = default; - LazyType(std::function thunk_DEPRECATED, std::function unwrap) - : thunk_DEPRECATED(thunk_DEPRECATED) - , unwrap(unwrap) + LazyType(std::function unwrap) + : unwrap(unwrap) { } // std::atomic is sad and requires a manual copy LazyType(const LazyType& rhs) - : thunk_DEPRECATED(rhs.thunk_DEPRECATED) - , unwrap(rhs.unwrap) + : unwrap(rhs.unwrap) , unwrapped(rhs.unwrapped.load()) { } LazyType(LazyType&& rhs) noexcept - : thunk_DEPRECATED(std::move(rhs.thunk_DEPRECATED)) - , unwrap(std::move(rhs.unwrap)) + : unwrap(std::move(rhs.unwrap)) , unwrapped(rhs.unwrapped.load()) { } LazyType& operator=(const LazyType& rhs) { - thunk_DEPRECATED = rhs.thunk_DEPRECATED; unwrap = rhs.unwrap; unwrapped = rhs.unwrapped.load(); @@ -611,15 +684,12 @@ struct LazyType LazyType& operator=(LazyType&& rhs) noexcept { - thunk_DEPRECATED = std::move(rhs.thunk_DEPRECATED); unwrap = std::move(rhs.unwrap); unwrapped = rhs.unwrapped.load(); return *this; } - std::function thunk_DEPRECATED; - std::function unwrap; std::atomic unwrapped = nullptr; }; @@ -640,8 +710,26 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant< + TypeId, + FreeType, + GenericType, + PrimitiveType, + SingletonType, + BlockedType, + PendingExpansionType, + FunctionType, + TableType, + MetatableType, + ClassType, + AnyType, + UnionType, + IntersectionType, + LazyType, + UnknownType, + NeverType, + NegationType, + TypeFunctionInstanceType>; struct Type final { @@ -689,12 +777,72 @@ struct Type final Type& operator=(const Type& rhs); }; +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; + + bool operator==(const GenericTypeDefinition& rhs) const; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; + + bool operator==(const GenericTypePackDefinition& rhs) const; +}; + +struct TypeFun +{ + // These should all be generic + std::vector typeParams; + std::vector typePackParams; + + /** The underlying type. + * + * WARNING! This is not safe to use as a type if typeParams is not empty!! + * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. + */ + TypeId type; + + TypeFun() = default; + + explicit TypeFun(TypeId ty) + : type(ty) + { + } + + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } + + bool operator==(const TypeFun& rhs) const; +}; + using SeenSet = std::set>; bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); +enum class FollowOption +{ + Normal, + DisableLazyTypeThunks, +}; + // Follow BoundTypes until we get to something real TypeId follow(TypeId t); -TypeId follow(TypeId t, std::function mapper); +TypeId follow(TypeId t, FollowOption followOption); +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)); +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)); std::vector flattenIntersection(TypeId ty); @@ -704,8 +852,10 @@ bool isBoolean(TypeId ty); bool isNumber(TypeId ty); bool isString(TypeId ty); bool isThread(TypeId ty); +bool isBuffer(TypeId ty); bool isOptional(TypeId ty); bool isTableIntersection(TypeId ty); +bool isTableUnion(TypeId ty); bool isOverloadedFunction(TypeId ty); // True when string is a subtype of ty @@ -749,18 +899,20 @@ struct BuiltinTypes TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + friend TypeId makeStringMetatable(NotNull builtinTypes); + friend struct GlobalTypes; + private: std::unique_ptr arena; bool debugFreezeArena = false; - TypeId makeStringMetatable(); - public: const TypeId nilType; const TypeId numberType; const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId bufferType; const TypeId functionType; const TypeId classType; const TypeId tableType; @@ -777,7 +929,9 @@ struct BuiltinTypes const TypeId optionalNumberType; const TypeId optionalStringType; + const TypePackId emptyTypePack; const TypePackId anyTypePack; + const TypePackId unknownTypePack; const TypePackId neverTypePack; const TypePackId uninhabitableTypePack; const TypePackId errorTypePack; @@ -798,6 +952,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent); Type* asMutable(TypeId ty); +template +bool is(T&& tv) +{ + if (!tv) + return false; + + if constexpr (std::is_same_v && !(std::is_same_v || ...)) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + + return (get(tv) || ...); +} + template const T* get(TypeId tv) { @@ -915,8 +1081,8 @@ struct TypeIterator // (T* t, size_t currentIndex) using SavedIterInfo = std::pair; - std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + VecDeque stack; + DenseHashSet seen{nullptr}; // Only needed to protect the iterator from hanging the thread. void advance() { @@ -943,7 +1109,7 @@ struct TypeIterator { // If we're about to descend into a cyclic type, we should skip over this. // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(inner) != seen.end()) + if (seen.contains(inner)) advance(); else { @@ -959,6 +1125,8 @@ struct TypeIterator } }; +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope); + using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); @@ -993,4 +1161,7 @@ LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args) return &ty->ty.emplace(std::forward(args)...); } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceType(Type* ty, TypeId& tyArg); + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeArena.h b/third_party/luau/Analysis/include/Luau/TypeArena.h index 0e69bb4a..4f8aea87 100644 --- a/third_party/luau/Analysis/include/Luau/TypeArena.h +++ b/third_party/luau/Analysis/include/Luau/TypeArena.h @@ -9,12 +9,16 @@ namespace Luau { +struct Module; struct TypeArena { TypedAllocator types; TypedAllocator typePacks; + // Owning module, if any + Module* owningModule = nullptr; + void clear(); template @@ -44,6 +48,11 @@ struct TypeArena { return addTypePack(TypePackVar(std::move(tp))); } + + TypeId addTypeFunction(const TypeFunction& function, std::initializer_list types); + TypeId addTypeFunction(const TypeFunction& function, std::vector typeArguments, std::vector packArguments = {}); + TypePackId addTypePackFunction(const TypePackFunction& function, std::initializer_list types); + TypePackId addTypePackFunction(const TypePackFunction& function, std::vector typeArguments, std::vector packArguments = {}); }; void freeze(TypeArena& arena); diff --git a/third_party/luau/Analysis/include/Luau/TypeCheckLimits.h b/third_party/luau/Analysis/include/Luau/TypeCheckLimits.h new file mode 100644 index 00000000..9eabe0ff --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypeCheckLimits.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Cancellation.h" +#include "Luau/Error.h" + +#include +#include +#include + +namespace Luau +{ + +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + } +}; + +class UserCancelError : public InternalCompilerError +{ +public: + explicit UserCancelError(const std::string& moduleName) + : InternalCompilerError("Analysis has been cancelled by user", moduleName) + { + } +}; + +struct TypeCheckLimits +{ + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + + std::shared_ptr cancellationToken; +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeChecker2.h b/third_party/luau/Analysis/include/Luau/TypeChecker2.h index def00a44..981fdfe6 100644 --- a/third_party/luau/Analysis/include/Luau/TypeChecker2.h +++ b/third_party/luau/Analysis/include/Luau/TypeChecker2.h @@ -2,16 +2,25 @@ #pragma once -#include "Luau/Ast.h" -#include "Luau/Module.h" #include "Luau/NotNull.h" namespace Luau { -struct DcrLogger; struct BuiltinTypes; +struct DcrLogger; +struct TypeCheckLimits; +struct UnifierSharedState; +struct SourceModule; +struct Module; -void check(NotNull builtinTypes, NotNull sharedState, DcrLogger* logger, const SourceModule& sourceModule, Module* module); +void check( + NotNull builtinTypes, + NotNull sharedState, + NotNull limits, + DcrLogger* logger, + const SourceModule& sourceModule, + Module* module +); } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeFunction.h b/third_party/luau/Analysis/include/Luau/TypeFunction.h new file mode 100644 index 00000000..ad7b92ef --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypeFunction.h @@ -0,0 +1,195 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/ConstraintSolver.h" +#include "Luau/Error.h" +#include "Luau/NotNull.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFwd.h" + +#include +#include +#include + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; +class Normalizer; + +struct TypeFunctionContext +{ + NotNull arena; + NotNull builtins; + NotNull scope; + NotNull normalizer; + NotNull ice; + NotNull limits; + + // nullptr if the type function is being reduced outside of the constraint solver. + ConstraintSolver* solver; + // The constraint being reduced in this run of the reduction + const Constraint* constraint; + + TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) + : arena(cs->arena) + , builtins(cs->builtinTypes) + , scope(scope) + , normalizer(cs->normalizer) + , ice(NotNull{&cs->iceReporter}) + , limits(NotNull{&cs->limits}) + , solver(cs.get()) + , constraint(constraint.get()) + { + } + + TypeFunctionContext( + NotNull arena, + NotNull builtins, + NotNull scope, + NotNull normalizer, + NotNull ice, + NotNull limits + ) + : arena(arena) + , builtins(builtins) + , scope(scope) + , normalizer(normalizer) + , ice(ice) + , limits(limits) + , solver(nullptr) + , constraint(nullptr) + { + } + + NotNull pushConstraint(ConstraintV&& c); +}; + +/// Represents a reduction result, which may have successfully reduced the type, +/// may have concretely failed to reduce the type, or may simply be stuck +/// without more information. +template +struct TypeFunctionReductionResult +{ + /// The result of the reduction, if any. If this is nullopt, the type function + /// could not be reduced. + std::optional result; + /// Whether the result is uninhabited: whether we know, unambiguously and + /// permanently, whether this type function reduction results in an + /// uninhabitable type. This will trigger an error to be reported. + bool uninhabited; + /// Any types that need to be progressed or mutated before the reduction may + /// proceed. + std::vector blockedTypes; + /// Any type packs that need to be progressed or mutated before the + /// reduction may proceed. + std::vector blockedPacks; +}; + +template +using ReducerFunction = + std::function(T, const std::vector&, const std::vector&, NotNull)>; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type. +struct TypeFunction +{ + /// The human-readable name of the type function. Used to stringify instance + /// types. + std::string name; + + /// The reducer function for the type function. + ReducerFunction reducer; +}; + +/// Represents a type function that may be applied to map a series of types and +/// type packs to a single output type pack. +struct TypePackFunction +{ + /// The human-readable name of the type pack function. Used to stringify + /// instance packs. + std::string name; + + /// The reducer function for the type pack function. + ReducerFunction reducer; +}; + +struct FunctionGraphReductionResult +{ + ErrorVec errors; + DenseHashSet blockedTypes{nullptr}; + DenseHashSet blockedPacks{nullptr}; + DenseHashSet reducedTypes{nullptr}; + DenseHashSet reducedPacks{nullptr}; +}; + +/** + * Attempt to reduce all instances of any type or type pack functions in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param normalizer the normalizer to use when normalizing types + * @param ice the internal error reporter to use for ICEs + */ +FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext, bool force = false); + +/** + * Attempt to reduce all instances of any type or type pack functions in the type + * graph provided. + * + * @param entrypoint the entry point to the type graph. + * @param location the location the reduction is occurring at; used to populate + * type errors. + * @param arena an arena to allocate types into. + * @param builtins the built-in types. + * @param normalizer the normalizer to use when normalizing types + * @param ice the internal error reporter to use for ICEs + */ +FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext, bool force = false); + +struct BuiltinTypeFunctions +{ + BuiltinTypeFunctions(); + + TypeFunction notFunc; + TypeFunction lenFunc; + TypeFunction unmFunc; + + TypeFunction addFunc; + TypeFunction subFunc; + TypeFunction mulFunc; + TypeFunction divFunc; + TypeFunction idivFunc; + TypeFunction powFunc; + TypeFunction modFunc; + + TypeFunction concatFunc; + + TypeFunction andFunc; + TypeFunction orFunc; + + TypeFunction ltFunc; + TypeFunction leFunc; + TypeFunction eqFunc; + + TypeFunction refineFunc; + TypeFunction singletonFunc; + TypeFunction unionFunc; + TypeFunction intersectFunc; + + TypeFunction keyofFunc; + TypeFunction rawkeyofFunc; + TypeFunction indexFunc; + TypeFunction rawgetFunc; + + void addToScope(NotNull arena, NotNull scope) const; +}; + +const BuiltinTypeFunctions& builtinTypeFunctions(); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeFunctionReductionGuesser.h b/third_party/luau/Analysis/include/Luau/TypeFunctionReductionGuesser.h new file mode 100644 index 00000000..b6d4a74c --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypeFunctionReductionGuesser.h @@ -0,0 +1,85 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/VecDeque.h" +#include "Luau/DenseHash.h" +#include "Luau/TypeFunction.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Normalize.h" +#include "Luau/TypeFwd.h" +#include "Luau/VisitType.h" +#include "Luau/NotNull.h" +#include "TypeArena.h" + +namespace Luau +{ + +struct TypeFunctionReductionGuessResult +{ + std::vector> guessedFunctionAnnotations; + TypeId guessedReturnType; + bool shouldRecommendAnnotation = true; +}; + +// An Inference result for a type function is a list of types corresponding to the guessed argument types, followed by a type for the result +struct TypeFunctionInferenceResult +{ + std::vector operandInference; + TypeId functionResultInference; +}; + +struct TypeFunctionReductionGuesser +{ + // Tracks our hypothesis about what a type function reduces to + DenseHashMap functionReducesTo{nullptr}; + // Tracks our constraints on type function operands + DenseHashMap substitutable{nullptr}; + // List of instances to try progress + VecDeque toInfer; + DenseHashSet cyclicInstances{nullptr}; + + // Utilities + NotNull arena; + NotNull builtins; + NotNull normalizer; + + TypeFunctionReductionGuesser(NotNull arena, NotNull builtins, NotNull normalizer); + + std::optional guess(TypeId typ); + std::optional guess(TypePackId typ); + TypeFunctionReductionGuessResult guessTypeFunctionReductionForFunctionExpr(const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy); + +private: + std::optional guessType(TypeId arg); + void dumpGuesses(); + + bool isNumericBinopFunction(const TypeFunctionInstanceType& instance); + bool isComparisonFunction(const TypeFunctionInstanceType& instance); + bool isOrAndFunction(const TypeFunctionInstanceType& instance); + bool isNotFunction(const TypeFunctionInstanceType& instance); + bool isLenFunction(const TypeFunctionInstanceType& instance); + bool isUnaryMinus(const TypeFunctionInstanceType& instance); + + // Operand is assignable if it looks like a cyclic type function instance, or a generic type + bool operandIsAssignable(TypeId ty); + std::optional tryAssignOperandType(TypeId ty); + + std::shared_ptr normalize(TypeId ty); + void step(); + void infer(); + bool done(); + + bool isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& instanceArgs); + void inferTypeFunctionSubstitutions(TypeId ty, const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferNumericBinopFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferComparisonFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferOrAndFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferNotFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferLenFunction(const TypeFunctionInstanceType* instance); + TypeFunctionInferenceResult inferUnaryMinusFunction(const TypeFunctionInstanceType* instance); +}; +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeFwd.h b/third_party/luau/Analysis/include/Luau/TypeFwd.h new file mode 100644 index 00000000..42d582fe --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypeFwd.h @@ -0,0 +1,59 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +// So... why `const T*` here rather than `T*`? +// It's because we've had problems caused by the type graph being mutated +// in ways it shouldn't be, for example mutating types from other modules. +// To try to control this, we make the use of types immutable by default, +// then provide explicit mutable access via getMutable and asMutable. +// This means we can grep for all the places we're mutating the type graph, +// and it makes it possible to provide other APIs (e.g. the txn log) +// which control mutable access to the type graph. + +struct Type; +using TypeId = const Type*; + +struct FreeType; +struct GenericType; +struct PrimitiveType; +struct BlockedType; +struct PendingExpansionType; +struct SingletonType; +struct FunctionType; +struct TableType; +struct MetatableType; +struct ClassType; +struct AnyType; +struct UnionType; +struct IntersectionType; +struct LazyType; +struct UnknownType; +struct NeverType; +struct NegationType; +struct TypeFunctionInstanceType; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct FreeTypePack; +struct GenericTypePack; +struct TypePack; +struct VariadicTypePack; +struct BlockedTypePack; +struct TypeFunctionInstanceTypePack; + +using Name = std::string; +using ModuleName = std::string; + +struct BuiltinTypes; + +using TypeOrPack = Variant; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeInfer.h b/third_party/luau/Analysis/include/Luau/TypeInfer.h index cceff0db..7f2e29b5 100644 --- a/third_party/luau/Analysis/include/Luau/TypeInfer.h +++ b/third_party/luau/Analysis/include/Luau/TypeInfer.h @@ -4,13 +4,14 @@ #include "Luau/Anyification.h" #include "Luau/ControlFlow.h" #include "Luau/Error.h" +#include "Luau/Instantiation.h" #include "Luau/Module.h" #include "Luau/Predicate.h" #include "Luau/Substitution.h" #include "Luau/Symbol.h" #include "Luau/TxnLog.h" -#include "Luau/Type.h" -#include "Luau/TypePack.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeCheckLimits.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -19,18 +20,24 @@ #include #include -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { struct Scope; struct TypeChecker; struct ModuleResolver; +struct FrontendCancellationToken; using Name = std::string; using ScopePtr = std::shared_ptr; -using OverloadErrorEntry = std::tuple, std::vector, const FunctionType*>; + +struct OverloadErrorEntry +{ + TxnLog log; + ErrorVec errors; + std::vector arguments; + const FunctionType* fnTy; +}; bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); @@ -50,32 +57,16 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public InternalCompilerError -{ -public: - explicit TimeLimitError(const std::string& moduleName) - : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) - { - } -}; - -struct GlobalTypes -{ - GlobalTypes(NotNull builtinTypes); - - NotNull builtinTypes; // Global types are based on builtin types - - TypeArena globalTypes; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules -}; - // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { explicit TypeChecker( - const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + const ScopePtr& globalScope, + ModuleResolver* resolver, + NotNull builtinTypes, + InternalErrorReporter* iceHandler + ); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -98,6 +89,7 @@ struct TypeChecker ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias); + ControlFlow check(const ScopePtr& scope, const AstStatTypeFunction& typefunction); ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); @@ -109,7 +101,11 @@ struct TypeChecker void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); WithPredicate checkExpr( - const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); + const ScopePtr& scope, + const AstExpr& expr, + std::optional expectedType = std::nullopt, + bool forceSingleton = false + ); WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); @@ -120,17 +116,31 @@ struct TypeChecker WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates = {} + ); TypeId checkBinaryOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates = {} + ); WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); - TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, - std::optional expectedType); + TypeId checkExprTable( + const ScopePtr& scope, + const AstExprTable& expr, + const std::vector>& fieldTypes, + std::optional expectedType + ); // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); @@ -143,34 +153,79 @@ struct TypeChecker TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); - std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional selfType, std::optional expectedType); + std::pair checkFunctionSignature( + const ScopePtr& scope, + int subLevel, + const AstExprFunction& expr, + std::optional originalNameLoc, + std::optional selfType, + std::optional expectedType + ); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); - void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, - const std::vector& argLocations); + void checkArgumentList( + const ScopePtr& scope, + const AstExpr& funName, + Unifier& state, + TypePackId paramPack, + TypePackId argPack, + const std::vector& argLocations + ); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); WithPredicate checkExprPackHelper2( - const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); + const ScopePtr& scope, + const AstExprCall& expr, + TypeId selfType, + TypeId actualFunctionType, + TypeId functionType, + TypePackId retPack + ); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::unique_ptr> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); - bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, - const std::vector& errors); - void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, - const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors); - - WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, - bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, - const std::vector>& expectedTypes = {}); + std::unique_ptr> checkCallOverload( + const ScopePtr& scope, + const AstExprCall& expr, + TypeId fn, + TypePackId retPack, + TypePackId argPack, + TypePack* args, + const std::vector* argLocations, + const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, + std::vector& overloadsThatDont, + std::vector& errors + ); + bool handleSelfCallMismatch( + const ScopePtr& scope, + const AstExprCall& expr, + TypePack* args, + const std::vector& argLocations, + const std::vector& errors + ); + void reportOverloadResolutionError( + const ScopePtr& scope, + const AstExprCall& expr, + TypePackId retPack, + TypePackId argPack, + const std::vector& argLocations, + const std::vector& overloads, + const std::vector& overloadsThatMatchArgCount, + std::vector& errors + ); + + WithPredicate checkExprList( + const ScopePtr& scope, + const Location& location, + const AstArray& exprs, + bool substituteFreeForNil = false, + const std::vector& lhsAnnotations = {}, + const std::vector>& expectedTypes = {} + ); static std::optional matchRequire(const AstExprCall& call); TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location); @@ -188,8 +243,13 @@ struct TypeChecker */ bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); - bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, - CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify( + TypePackId subTy, + TypePackId superTy, + const ScopePtr& scope, + const Location& location, + CountMismatch::Context ctx = CountMismatch::Context::Arg + ); /** Attempt to unify the types. * If this fails, and the subTy type can be instantiated, do so and try unification again. @@ -257,6 +317,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); [[noreturn]] void throwTimeLimitError(); + [[noreturn]] void throwUserCancelError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); @@ -325,12 +386,23 @@ struct TypeChecker TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, - const std::vector& typePackParams, const Location& location); + TypeId instantiateTypeFun( + const ScopePtr& scope, + const TypeFun& tf, + const std::vector& typeParams, + const std::vector& typePackParams, + const Location& location + ); // Note: `scope` must be a fresh scope. - GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); + GenericTypeDefinitions createGenericTypes( + const ScopePtr& scope, + std::optional levelOpt, + const AstNode& node, + const AstArray& genericNames, + const AstArray& genericPackNames, + bool useCache = false + ); public: void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -375,6 +447,8 @@ struct TypeChecker UnifierSharedState unifierState; Normalizer normalizer; + Instantiation reusableInstantiation; + std::vector requireCycles; // Type inference limits @@ -382,12 +456,15 @@ struct TypeChecker std::optional instantiationChildLimit; std::optional unifierIterationLimit; + std::shared_ptr cancellationToken; + public: const TypeId nilType; const TypeId numberType; const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId bufferType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; diff --git a/third_party/luau/Analysis/include/Luau/TypeOrPack.h b/third_party/luau/Analysis/include/Luau/TypeOrPack.h new file mode 100644 index 00000000..87001910 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypeOrPack.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +const void* ptr(TypeOrPack ty); + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + return tyOrTp.get_if(); +} + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + if (const TypeId* ty = get(tyOrTp)) + return get(*ty); + else + return nullptr; +} + +template, bool> = true> +const T* get(const TypeOrPack& tyOrTp) +{ + if (const TypePackId* tp = get(tyOrTp)) + return get(*tp); + else + return nullptr; +} + +TypeOrPack follow(TypeOrPack ty); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypePack.h b/third_party/luau/Analysis/include/Luau/TypePack.h index 2ae56e5f..1065b947 100644 --- a/third_party/luau/Analysis/include/Luau/TypePack.h +++ b/third_party/luau/Analysis/include/Luau/TypePack.h @@ -1,25 +1,27 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Type.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" +#include "Luau/TypeFwd.h" +#include "Luau/NotNull.h" +#include "Luau/Common.h" #include #include +#include namespace Luau { struct TypeArena; +struct TypePackFunction; struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; - -struct TypePackVar; -using TypePackId = const TypePackVar*; +struct TypeFunctionInstanceTypePack; struct FreeTypePack { @@ -50,10 +52,10 @@ struct GenericTypePack }; using BoundTypePack = Unifiable::Bound; - using ErrorTypePack = Unifiable::Error; -using TypePackVariant = Unifiable::Variant; +using TypePackVariant = + Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more @@ -80,9 +82,22 @@ struct BlockedTypePack BlockedTypePack(); size_t index; + struct Constraint* owner = nullptr; + static size_t nextIndex; }; +/** + * Analogous to a TypeFunctionInstanceType. + */ +struct TypeFunctionInstanceTypePack +{ + NotNull function; + + std::vector typeArguments; + std::vector packArguments; +}; + struct TypePackVar { explicit TypePackVar(const TypePackVariant& ty); @@ -169,7 +184,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -TypePackId follow(TypePackId tp, std::function mapper); +TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId)); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); @@ -218,4 +233,18 @@ bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadic bool containsNever(TypePackId tp); +/* + * Use this to change the kind of a particular type pack. + * + * LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant. + */ +template +LUAU_NOINLINE T* emplaceTypePack(TypePackVar* ty, Args&&... args) +{ + return &ty->ty.emplace(std::forward(args)...); +} + +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceTypePack(TypePackVar* ty, TypePackId& tyArg); + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypePairHash.h b/third_party/luau/Analysis/include/Luau/TypePairHash.h new file mode 100644 index 00000000..591f20f1 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypePairHash.h @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeFwd.h" + +#include +#include + +namespace Luau +{ + +struct TypePairHash +{ + size_t hashOne(TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t hashOne(TypePackId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypePath.h b/third_party/luau/Analysis/include/Luau/TypePath.h new file mode 100644 index 00000000..50c75da4 --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/TypePath.h @@ -0,0 +1,237 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeFwd.h" +#include "Luau/Variant.h" +#include "Luau/NotNull.h" + +#include +#include +#include + +namespace Luau +{ + +namespace TypePath +{ + +/// Represents a property of a class, table, or anything else with a concept of +/// a named property. +struct Property +{ + /// The name of the property. + std::string name; + /// Whether to look at the read or the write type. + bool isRead = true; + + explicit Property(std::string name); + Property(std::string name, bool read) + : name(std::move(name)) + , isRead(read) + { + } + + static Property read(std::string name); + static Property write(std::string name); + + bool operator==(const Property& other) const; +}; + +/// Represents an index into a type or a pack. For a type, this indexes into a +/// union or intersection's list. For a pack, this indexes into the pack's nth +/// element. +struct Index +{ + /// The 0-based index to use for the lookup. + size_t index; + + bool operator==(const Index& other) const; +}; + +/// Represents fields of a type or pack that contain a type. +enum class TypeField +{ + /// The metatable of a type. This could be a metatable type, a primitive + /// type, a class type, or perhaps even a string singleton type. + Metatable, + /// The lower bound of this type, if one is present. + LowerBound, + /// The upper bound of this type, if present. + UpperBound, + /// The index type. + IndexLookup, + /// The indexer result type. + IndexResult, + /// The negated type, for negations. + Negated, + /// The variadic type for a type pack. + Variadic, +}; + +/// Represents fields of a type or type pack that contain a type pack. +enum class PackField +{ + /// What arguments this type accepts. + Arguments, + /// What this type returns when called. + Returns, + /// The tail of a type pack. + Tail, +}; + +/// Component that represents the result of a reduction +/// `resultType` is `never` if the reduction could not proceed +struct Reduction +{ + TypeId resultType; + + bool operator==(const Reduction& other) const; +}; + +/// A single component of a path, representing one inner type or type pack to +/// traverse into. +using Component = Luau::Variant; + +/// A path through a type or type pack accessing a particular type or type pack +/// contained within. +/// +/// Paths are always relative; to make use of a Path, you need to specify an +/// entry point. They are not canonicalized; two Paths may not compare equal but +/// may point to the same result, depending on the layout of the entry point. +/// +/// Paths always descend through an entry point. This doesn't mean that they +/// cannot reach "upwards" in the actual type hierarchy in some cases, but it +/// does mean that there is no equivalent to `../` in file system paths. This is +/// intentional and unavoidable, because types and type packs don't have a +/// concept of a parent - they are a directed cyclic graph, with no hierarchy +/// that actually holds in all cases. +struct Path +{ + /// The Components of this Path. + std::vector components; + + /// Creates a new empty Path. + Path() {} + + /// Creates a new Path from a list of components. + explicit Path(std::vector components) + : components(std::move(components)) + { + } + + /// Creates a new single-component Path. + explicit Path(Component component) + : components({component}) + { + } + + /// Creates a new Path by appending another Path to this one. + /// @param suffix the Path to append + /// @return a new Path representing `this + suffix` + Path append(const Path& suffix) const; + + /// Creates a new Path by appending a Component to this Path. + /// @param component the Component to append + /// @return a new Path with `component` appended to it. + Path push(Component component) const; + + /// Creates a new Path by prepending a Component to this Path. + /// @param component the Component to prepend + /// @return a new Path with `component` prepended to it. + Path push_front(Component component) const; + + /// Creates a new Path by removing the last Component of this Path. + /// If the Path is empty, this is a no-op. + /// @return a Path with the last component removed. + Path pop() const; + + /// Returns the last Component of this Path, if present. + std::optional last() const; + + /// Returns whether this Path is empty, meaning it has no components at all. + /// Traversing an empty Path results in the type you started with. + bool empty() const; + + bool operator==(const Path& other) const; + bool operator!=(const Path& other) const + { + return !(*this == other); + } +}; + +struct PathHash +{ + size_t operator()(const Property& prop) const; + size_t operator()(const Index& idx) const; + size_t operator()(const TypeField& field) const; + size_t operator()(const PackField& field) const; + size_t operator()(const Reduction& reduction) const; + size_t operator()(const Component& component) const; + size_t operator()(const Path& path) const; +}; + +/// The canonical "empty" Path, meaning a Path with no components. +static const Path kEmpty{}; + +struct PathBuilder +{ + std::vector components; + + Path build(); + + PathBuilder& readProp(std::string name); + PathBuilder& writeProp(std::string name); + PathBuilder& prop(std::string name); + PathBuilder& index(size_t i); + PathBuilder& mt(); + PathBuilder& lb(); + PathBuilder& ub(); + PathBuilder& indexKey(); + PathBuilder& indexValue(); + PathBuilder& negated(); + PathBuilder& variadic(); + PathBuilder& args(); + PathBuilder& rets(); + PathBuilder& tail(); +}; + +} // namespace TypePath + +using Path = TypePath::Path; + +/// Converts a Path to a string for debugging purposes. This output may not be +/// terribly clear to end users of the Luau type system. +std::string toString(const TypePath::Path& path, bool prefixDot = false); + +std::optional traverse(TypeId root, const Path& path, NotNull builtinTypes); +std::optional traverse(TypePackId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type to its end point, which must be a type. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypeId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForType(TypeId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type pack to its end point, which must be a type. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypeId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForType(TypePackId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type to its end point, which must be a type pack. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForPack(TypeId root, const Path& path, NotNull builtinTypes); + +/// Traverses a path from a type pack to its end point, which must be a type pack. +/// @param root the entry point of the traversal +/// @param path the path to traverse +/// @param builtinTypes the built-in types in use (used to acquire the string metatable) +/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed. +std::optional traverseForPack(TypePackId root, const Path& path, NotNull builtinTypes); + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeReduction.h b/third_party/luau/Analysis/include/Luau/TypeReduction.h deleted file mode 100644 index 3f64870a..00000000 --- a/third_party/luau/Analysis/include/Luau/TypeReduction.h +++ /dev/null @@ -1,85 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Type.h" -#include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Variant.h" - -namespace Luau -{ - -namespace detail -{ -template -struct ReductionEdge -{ - T type = nullptr; - bool irreducible = false; -}; - -struct TypeReductionMemoization -{ - TypeReductionMemoization() = default; - - TypeReductionMemoization(const TypeReductionMemoization&) = delete; - TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; - - TypeReductionMemoization(TypeReductionMemoization&&) = default; - TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; - - DenseHashMap> types{nullptr}; - DenseHashMap> typePacks{nullptr}; - - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - - // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. - // Because reduction should always be transitive, A should point to C if A points to B and B points to C. - std::optional> memoizedof(TypeId ty) const; - std::optional> memoizedof(TypePackId tp) const; -}; -} // namespace detail - -struct TypeReductionOptions -{ - /// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need - /// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true. - /// This is because TypeReduction caches the reduced type. - bool allowTypeReductionsFromOtherArenas = false; -}; - -struct TypeReduction -{ - explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, - const TypeReductionOptions& opts = {}); - - TypeReduction(const TypeReduction&) = delete; - TypeReduction& operator=(const TypeReduction&) = delete; - - TypeReduction(TypeReduction&&) = default; - TypeReduction& operator=(TypeReduction&&) = default; - - std::optional reduce(TypeId ty); - std::optional reduce(TypePackId tp); - std::optional reduce(const TypeFun& fun); - -private: - NotNull arena; - NotNull builtinTypes; - NotNull handle; - - TypeReductionOptions options; - detail::TypeReductionMemoization memoization; - - // Computes an *estimated length* of the cartesian product of the given type. - size_t cartesianProductSize(TypeId ty) const; - - bool hasExceededCartesianProductLimit(TypeId ty) const; - bool hasExceededCartesianProductLimit(TypePackId tp) const; -}; - -} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/TypeUtils.h b/third_party/luau/Analysis/include/Luau/TypeUtils.h index 86f20f38..92be19d1 100644 --- a/third_party/luau/Analysis/include/Luau/TypeUtils.h +++ b/third_party/luau/Analysis/include/Luau/TypeUtils.h @@ -14,6 +14,7 @@ namespace Luau struct TxnLog; struct TypeArena; +class Normalizer; enum class ValueContext { @@ -21,12 +22,71 @@ enum class ValueContext RValue }; +/// the current context of the type checker +enum class TypeContext +{ + /// the default context + Default, + /// inside of a condition + Condition, +}; + +bool inConditional(const TypeContext& context); + +// sets the given type context to `Condition` and restores it to its original +// value when the struct drops out of scope +struct InConditionalContext +{ + TypeContext* typeContext; + TypeContext oldValue; + + InConditionalContext(TypeContext* c) + : typeContext(c) + , oldValue(*c) + { + *typeContext = TypeContext::Condition; + } + + ~InConditionalContext() + { + *typeContext = oldValue; + } +}; + using ScopePtr = std::shared_ptr; +std::optional findTableProperty( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +); + std::optional findMetatableEntry( - NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); + NotNull builtinTypes, + ErrorVec& errors, + TypeId type, + const std::string& entry, + Location location +); +std::optional findTablePropertyRespectingMeta( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +); std::optional findTablePropertyRespectingMeta( - NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + ValueContext context, + Location location +); + +bool occursCheck(TypeId needle, TypeId haystack); // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); @@ -34,7 +94,12 @@ std::pair> getParameterExtents(const TxnLog* log, // Extend the provided pack to at least `length` types. // Returns a temporary TypePack that contains those types plus a tail. TypePack extendTypePack( - TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides = {}); + TypeArena& arena, + NotNull builtinTypes, + TypePackId pack, + size_t length, + std::vector> overrides = {} +); /** * Reduces a union by decomposing to the any/error type if it appears in the @@ -55,4 +120,132 @@ std::vector reduceUnion(const std::vector& types); */ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty); +struct ErrorSuppression +{ + enum Value + { + Suppress, + DoNotSuppress, + NormalizationFailed, + }; + + ErrorSuppression() = default; + constexpr ErrorSuppression(Value enumValue) + : value(enumValue) + { + } + + constexpr operator Value() const + { + return value; + } + explicit operator bool() const = delete; + + ErrorSuppression orElse(const ErrorSuppression& other) const + { + switch (value) + { + case DoNotSuppress: + return other; + default: + return *this; + } + } + +private: + Value value; +}; + +/** + * Normalizes the given type using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty the type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty); + +/** + * Flattens and normalizes the given typepack using the normalizer to determine if the type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp the typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp); + +/** + * Normalizes the two given type using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param ty1 the first type to check for error suppression + * @param ty2 the second type to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2); + +/** + * Flattens and normalizes the two given typepacks using the normalizer to determine if either type + * should suppress any errors that would be reported involving it. + * @param normalizer the normalizer to use + * @param tp1 the first typepack to check for error suppression + * @param tp2 the second typepack to check for error suppression + * @returns an enum indicating whether or not to suppress the error or to signal a normalization failure + */ +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2); + +// Similar to `std::optional>`, but whose `sizeof()` is the same as `std::pair` +// and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`. +template +struct TryPair +{ + A first; + B second; + + explicit operator bool() const + { + return bool(first) && bool(second); + } +}; + +template +TryPair get2(Ty one, Ty two) +{ + static_assert(std::is_pointer_v, "argument must be a pointer type"); + + const A* a = get(one); + const B* b = get(two); + if (a && b) + return {a, b}; + else + return {nullptr, nullptr}; +} + +template +const T* get(std::optional ty) +{ + if (ty) + return get(*ty); + else + return nullptr; +} + +template +T* getMutable(std::optional ty) +{ + if (ty) + return getMutable(*ty); + else + return nullptr; +} + +template +std::optional follow(std::optional ty) +{ + if (ty) + return follow(*ty); + else + return std::nullopt; +} + } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Unifier.h b/third_party/luau/Analysis/include/Luau/Unifier.h index e3b0a878..b0a855d3 100644 --- a/third_party/luau/Analysis/include/Luau/Unifier.h +++ b/third_party/luau/Analysis/include/Luau/Unifier.h @@ -9,7 +9,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" -#include "Normalize.h" +#include "Luau/Normalize.h" #include @@ -43,6 +43,21 @@ struct Widen : Substitution TypePackId operator()(TypePackId ty); }; +/** + * Normally, when we unify table properties, we must do so invariantly, but we + * can introduce a special exception: If the table property in the subtype + * position arises from a literal expression, it is safe to instead perform a + * covariant check. + * + * This is very useful for typechecking cases where table literals (and trees of + * table literals) are passed directly to functions. + * + * In this case, we know that the property has no other name referring to it and + * so it is perfectly safe for the function to mutate the table any way it + * wishes. + */ +using LiteralProperties = DenseHashSet; + // TODO: Use this more widely. struct UnifierOptions { @@ -54,7 +69,6 @@ struct Unifier TypeArena* const types; NotNull builtinTypes; NotNull normalizer; - Mode mode; NotNull scope; // const Scope maybe TxnLog log; @@ -64,9 +78,11 @@ struct Unifier Variance variance = Covariant; bool normalize = true; // Normalize unions and intersections if necessary bool checkInhabited = true; // Normalize types to check if they are inhabited - bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; + // If true, generics act as free types when unifying. + bool hideousFixMeGenericsAreActuallyFree = false; + UnifierSharedState& sharedState; // When the Unifier is forced to unify two blocked types (or packs), they @@ -75,8 +91,11 @@ struct Unifier std::vector blockedTypes; std::vector blockedTypePacks; - Unifier( - NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); + + // Configure the Unifier to test for scope subsumption via embedded Scope + // pointers rather than TypeLevels. + void enableNewSolver(); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -86,10 +105,22 @@ struct Unifier * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify( + TypeId subTy, + TypeId superTy, + bool isFunctionCall = false, + bool isIntersection = false, + const LiteralProperties* aliasableMap = nullptr + ); private: - void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify_( + TypeId subTy, + TypeId superTy, + bool isFunctionCall = false, + bool isIntersection = false, + const LiteralProperties* aliasableMap = nullptr + ); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); // Traverse the two types provided and block on any BlockedTypes we find. @@ -99,12 +130,18 @@ struct Unifier void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); - void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, - std::optional error = std::nullopt); + void tryUnifyNormalizedTypes( + TypeId subTy, + TypeId superTy, + const NormalizedType& subNorm, + const NormalizedType& superNorm, + std::string reason, + std::optional error = std::nullopt + ); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); @@ -148,7 +185,6 @@ struct Unifier LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: - bool isNonstrictMode() const; TypeMismatch::Context mismatchContext(); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); @@ -159,9 +195,15 @@ struct Unifier // Available after regular type pack unification errors std::optional firstPackErrorPos; + + // If true, we do a bunch of small things differently to work better with + // the new type inference engine. Most notably, we use the Scope hierarchy + // directly rather than using TypeLevels. + bool useNewSolver = false; }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); std::optional hasUnificationTooComplex(const ErrorVec& errors); +std::optional hasCountMismatch(const ErrorVec& errors); } // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/Unifier2.h b/third_party/luau/Analysis/include/Luau/Unifier2.h new file mode 100644 index 00000000..8734aeec --- /dev/null +++ b/third_party/luau/Analysis/include/Luau/Unifier2.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePairHash.h" + +#include +#include +#include + +namespace Luau +{ + +struct InternalErrorReporter; +struct Scope; +struct TypeArena; + +enum class OccursCheckResult +{ + Pass, + Fail +}; + +struct Unifier2 +{ + NotNull arena; + NotNull builtinTypes; + NotNull scope; + NotNull ice; + TypeCheckLimits limits; + + DenseHashSet, TypePairHash> seenTypePairings{{nullptr, nullptr}}; + DenseHashSet, TypePairHash> seenTypePackPairings{{nullptr, nullptr}}; + + DenseHashMap> expandedFreeTypes{nullptr}; + + // Mapping from generic types to free types to be used in instantiation. + DenseHashMap genericSubstitutions{nullptr}; + // Mapping from generic type packs to `TypePack`s of free types to be used in instantiation. + DenseHashMap genericPackSubstitutions{nullptr}; + + int recursionCount = 0; + int recursionLimit = 0; + + std::vector incompleteSubtypes; + // null if not in a constraint solving context + DenseHashSet* uninhabitedTypeFunctions; + + Unifier2(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull ice); + Unifier2( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull ice, + DenseHashSet* uninhabitedTypeFunctions + ); + + /** Attempt to commit the subtype relation subTy <: superTy to the type + * graph. + * + * @returns true if successful. + * + * Note that incoherent types can and will successfully be unified. We stop + * when we *cannot know* how to relate the provided types, not when doing so + * would narrow something down to never or broaden it to unknown. + * + * Presently, the only way unification can fail is if we attempt to bind one + * free TypePack to another and encounter an occurs check violation. + */ + bool unify(TypeId subTy, TypeId superTy); + bool unifyFreeWithType(TypeId subTy, TypeId superTy); + bool unify(TypeId subTy, const FunctionType* superFn); + bool unify(const UnionType* subUnion, TypeId superTy); + bool unify(TypeId subTy, const UnionType* superUnion); + bool unify(const IntersectionType* subIntersection, TypeId superTy); + bool unify(TypeId subTy, const IntersectionType* superIntersection); + bool unify(TableType* subTable, const TableType* superTable); + bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable); + + bool unify(const AnyType* subAny, const FunctionType* superFn); + bool unify(const FunctionType* subFn, const AnyType* superAny); + bool unify(const AnyType* subAny, const TableType* superTable); + bool unify(const TableType* subTable, const AnyType* superAny); + + // TODO think about this one carefully. We don't do unions or intersections of type packs + bool unify(TypePackId subTp, TypePackId superTp); + + std::optional generalize(TypeId ty); + +private: + /** + * @returns simplify(left | right) + */ + TypeId mkUnion(TypeId left, TypeId right); + + /** + * @returns simplify(left & right) + */ + TypeId mkIntersection(TypeId left, TypeId right); + + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic type result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); + + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic TypePack result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); +}; + +} // namespace Luau diff --git a/third_party/luau/Analysis/include/Luau/UnifierSharedState.h b/third_party/luau/Analysis/include/Luau/UnifierSharedState.h index ada56ec5..de69c17c 100644 --- a/third_party/luau/Analysis/include/Luau/UnifierSharedState.h +++ b/third_party/luau/Analysis/include/Luau/UnifierSharedState.h @@ -3,8 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Error.h" -#include "Luau/Type.h" -#include "Luau/TypePack.h" +#include "Luau/TypeFwd.h" #include diff --git a/third_party/luau/Analysis/include/Luau/VisitType.h b/third_party/luau/Analysis/include/Luau/VisitType.h index 663627d5..e588d06b 100644 --- a/third_party/luau/Analysis/include/Luau/VisitType.h +++ b/third_party/luau/Analysis/include/Luau/VisitType.h @@ -7,9 +7,11 @@ #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauBoundLazyTypes2) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { @@ -63,6 +65,9 @@ inline void unsee(DenseHashSet& seen, const void* tv) } // namespace visit_detail +// recursion counter is equivalent here, but we'd like a better name to express the intent. +using TypeFunctionDepthCounter = RecursionCounter; + template struct GenericTypeVisitor { @@ -71,6 +76,7 @@ struct GenericTypeVisitor Set seen; bool skipBoundTypes = false; int recursionCounter = 0; + int typeFunctionDepth = 0; GenericTypeVisitor() = default; @@ -159,6 +165,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -192,6 +202,10 @@ struct GenericTypeVisitor { return visit(tp); } + virtual bool visit(TypePackId tp, const TypeFunctionInstanceTypePack& tfitp) + { + return visit(tp); + } void traverse(TypeId ty) { @@ -211,7 +225,26 @@ struct GenericTypeVisitor traverse(btv->boundTo); } else if (auto ftv = get(ty)) - visit(ty, *ftv); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (visit(ty, *ftv)) + { + // TODO: Replace these if statements with assert()s when we + // delete FFlag::DebugLuauDeferredConstraintResolution. + // + // When the old solver is used, these pointers are always + // unused. When the new solver is used, they are never null. + if (ftv->lowerBound) + traverse(ftv->lowerBound); + + if (ftv->upperBound) + traverse(ftv->upperBound); + } + } + else + visit(ty, *ftv); + } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) @@ -242,7 +275,22 @@ struct GenericTypeVisitor else { for (auto& [_name, prop] : ttv->props) - traverse(prop.type()); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto ty = prop.readTy) + traverse(*ty); + + // In the case that the readType and the writeType + // are the same pointer, just traverse once. + // Traversing each property twice has pretty + // significant performance consequences. + if (auto ty = prop.writeTy; ty && !prop.isShared()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ttv->indexer) { @@ -265,13 +313,34 @@ struct GenericTypeVisitor if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) - traverse(prop.type()); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto ty = prop.readTy) + traverse(*ty); + + // In the case that the readType and the writeType are + // the same pointer, just traverse once. Traversing each + // property twice would have pretty significant + // performance consequences. + if (auto ty = prop.writeTy; ty && !prop.isShared()) + traverse(*ty); + } + else + traverse(prop.type()); + } if (ctv->parent) traverse(*ctv->parent); if (ctv->metatable) traverse(*ctv->metatable); + + if (ctv->indexer) + { + traverse(ctv->indexer->indexType); + traverse(ctv->indexer->indexResultType); + } } } else if (auto atv = get(ty)) @@ -280,25 +349,45 @@ struct GenericTypeVisitor { if (visit(ty, *utv)) { + bool unionChanged = false; for (TypeId optTy : utv->options) + { traverse(optTy); + if (!get(follow(ty))) + { + unionChanged = true; + break; + } + } + + if (unionChanged) + traverse(ty); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { + bool intersectionChanged = false; for (TypeId partTy : itv->parts) + { traverse(partTy); + if (!get(follow(ty))) + { + intersectionChanged = true; + break; + } + } + + if (intersectionChanged) + traverse(ty); } } else if (auto ltv = get(ty)) { - if (FFlag::LuauBoundLazyTypes2) - { - if (TypeId unwrapped = ltv->unwrapped) - traverse(unwrapped); - } + if (TypeId unwrapped = ltv->unwrapped) + traverse(unwrapped); + // Visiting into LazyType that hasn't been unwrapped may necessarily cause infinite expansion, so we don't do that on purpose. // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. @@ -327,6 +416,19 @@ struct GenericTypeVisitor if (visit(ty, *ntv)) traverse(ntv->ty); } + else if (auto tfit = get(ty)) + { + TypeFunctionDepthCounter tfdc{&typeFunctionDepth}; + + if (visit(ty, *tfit)) + { + for (TypeId p : tfit->typeArguments) + traverse(p); + + for (TypePackId p : tfit->packArguments) + traverse(p); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); @@ -376,6 +478,19 @@ struct GenericTypeVisitor } else if (auto btp = get(tp)) visit(tp, *btp); + else if (auto tfitp = get(tp)) + { + TypeFunctionDepthCounter tfdc{&typeFunctionDepth}; + + if (visit(tp, *tfitp)) + { + for (TypeId t : tfitp->typeArguments) + traverse(t); + + for (TypePackId t : tfitp->packArguments) + traverse(t); + } + } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/third_party/luau/Analysis/src/AnyTypeSummary.cpp b/third_party/luau/Analysis/src/AnyTypeSummary.cpp new file mode 100644 index 00000000..c5e8e8c6 --- /dev/null +++ b/third_party/luau/Analysis/src/AnyTypeSummary.cpp @@ -0,0 +1,883 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AnyTypeSummary.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/Config.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/DcrLogger.h" +#include "Luau/Module.h" +#include "Luau/Parser.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Transpiler.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeChecker2.h" +#include "Luau/NonStrictTypeChecker.h" +#include "Luau/TypeInfer.h" +#include "Luau/Variant.h" +#include "Luau/VisitType.h" +#include "Luau/TypePack.h" +#include "Luau/TypeOrPack.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +LUAU_FASTFLAGVARIABLE(StudioReportLuauAny, false); +LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300); + +LUAU_FASTFLAG(DebugLuauMagicTypes); + +namespace Luau +{ + +void AnyTypeSummary::traverse(const Module* module, AstStat* src, NotNull builtinTypes) +{ + visit(findInnerMostScope(src->location, module), src, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStat* stat, const Module* module, NotNull builtinTypes) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit}; + + if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto i = stat->as()) + return visit(scope, i, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto r = stat->as()) + return visit(scope, r, module, builtinTypes); + else if (auto e = stat->as()) + return visit(scope, e, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto f = stat->as()) + return visit(scope, f, module, builtinTypes); + else if (auto f = stat->as()) + return visit(scope, f, module, builtinTypes); + else if (auto a = stat->as()) + return visit(scope, a, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); + else if (auto s = stat->as()) + return visit(scope, s, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull builtinTypes) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauAnySummaryRecursionLimit) + return; // don't report + + for (AstStat* stat : block->body) + visit(scope, stat, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull builtinTypes) +{ + if (ifStatement->thenbody) + { + const Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module); + visit(thenScope, ifStatement->thenbody, module, builtinTypes); + } + + if (ifStatement->elsebody) + { + const Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module); + visit(elseScope, ifStatement->elsebody, module, builtinTypes); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull builtinTypes) +{ + const Scope* whileScope = findInnerMostScope(while_->location, module); + visit(whileScope, while_->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull builtinTypes) +{ + const Scope* repeatScope = findInnerMostScope(repeat->location, module); + visit(repeatScope, repeat->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull builtinTypes) +{ + const Scope* retScope = findInnerMostScope(ret->location, module); + + auto ctxNode = getNode(rootSrc, ret); + + for (auto val : ret->list) + { + if (isAnyCall(retScope, val, module, builtinTypes)) + { + TelemetryTypePair types; + types.inferredType = toString(lookupType(val, module, builtinTypes)); + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(retScope, val, module, builtinTypes)) + { + if (auto cast = val->as()) + { + TelemetryTypePair types; + + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(cast->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, local); + + TypePackId values = reconstructTypePack(local->values, module, builtinTypes); + auto [head, tail] = flatten(values); + + size_t posn = 0; + for (AstLocal* loc : local->vars) + { + if (local->vars.data[0] == loc && posn < local->values.size) + { + if (loc->annotation) + { + auto annot = lookupAnnotation(loc->annotation, module, builtinTypes); + if (containsAny(annot)) + { + TelemetryTypePair types; + + types.annotatedType = toString(annot); + types.inferredType = toString(lookupType(local->values.data[posn], module, builtinTypes)); + + TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + + const AstExprTypeAssertion* maybeRequire = local->values.data[posn]->as(); + if (!maybeRequire) + continue; + + if (isAnyCast(scope, local->values.data[posn], module, builtinTypes)) + { + TelemetryTypePair types; + + types.inferredType = toString(head[std::min(local->values.size - 1, posn)]); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + else + { + + if (std::min(local->values.size - 1, posn) < head.size()) + { + if (loc->annotation) + { + auto annot = lookupAnnotation(loc->annotation, module, builtinTypes); + if (containsAny(annot)) + { + TelemetryTypePair types; + + types.annotatedType = toString(annot); + types.inferredType = toString(head[std::min(local->values.size - 1, posn)]); + + TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + else + { + if (tail) + { + if (containsAny(*tail)) + { + TelemetryTypePair types; + + types.inferredType = toString(*tail); + + TypeInfo ti{Pattern::VarAny, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + } + + ++posn; + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull builtinTypes) +{ + const Scope* forScope = findInnerMostScope(for_->location, module); + visit(forScope, for_->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull builtinTypes) +{ + const Scope* loopScope = findInnerMostScope(forIn->location, module); + visit(loopScope, forIn->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, assign); + + TypePackId values = reconstructTypePack(assign->values, module, builtinTypes); + auto [head, tail] = flatten(values); + + size_t posn = 0; + for (AstExpr* var : assign->vars) + { + TypeId tp = lookupType(var, module, builtinTypes); + if (containsAny(tp)) + { + TelemetryTypePair types; + + types.annotatedType = toString(tp); + + auto loc = std::min(assign->vars.size - 1, posn); + if (head.size() >= assign->vars.size) + { + types.inferredType = toString(head[posn]); + } + else if (loc < head.size()) + types.inferredType = toString(head[loc]); + else + types.inferredType = toString(builtinTypes->nilType); + + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + ++posn; + } + + for (AstExpr* val : assign->values) + { + if (isAnyCall(scope, val, module, builtinTypes)) + { + TelemetryTypePair types; + + types.inferredType = toString(lookupType(val, module, builtinTypes)); + + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(scope, val, module, builtinTypes)) + { + if (auto cast = val->as()) + { + TelemetryTypePair types; + + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(val, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + } + + if (tail) + { + if (containsAny(*tail)) + { + TelemetryTypePair types; + + types.inferredType = toString(*tail); + + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, assign); + + TelemetryTypePair types; + + types.inferredType = toString(lookupType(assign->value, module, builtinTypes)); + types.annotatedType = toString(lookupType(assign->var, module, builtinTypes)); + + if (module->astTypes.contains(assign->var)) + { + if (containsAny(*module->astTypes.find(assign->var))) + { + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + else if (module->astTypePacks.contains(assign->var)) + { + if (containsAny(*module->astTypePacks.find(assign->var))) + { + TypeInfo ti{Pattern::Assign, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } + + if (isAnyCall(scope, assign->value, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + + if (isAnyCast(scope, assign->value, module, builtinTypes)) + { + if (auto cast = assign->value->as()) + { + types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); + types.inferredType = toString(lookupType(cast->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::Casts, toString(ctxNode), types}; + typeInfo.push_back(ti); + } + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull builtinTypes) +{ + TelemetryTypePair types; + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + + if (hasVariadicAnys(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::VarAny, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasArgAnys(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncArg, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasAnyReturns(scope, function->func, module, builtinTypes)) + { + TypeInfo ti{Pattern::FuncRet, toString(function), types}; + typeInfo.push_back(ti); + } + + if (function->func->body->body.size > 0) + visit(scope, function->func->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull builtinTypes) +{ + TelemetryTypePair types; + + if (hasVariadicAnys(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::VarAny, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasArgAnys(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::FuncArg, toString(function), types}; + typeInfo.push_back(ti); + } + + if (hasAnyReturns(scope, function->func, module, builtinTypes)) + { + types.inferredType = toString(lookupType(function->func, module, builtinTypes)); + TypeInfo ti{Pattern::FuncRet, toString(function), types}; + typeInfo.push_back(ti); + } + + if (function->func->body->body.size > 0) + visit(scope, function->func->body, module, builtinTypes); +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, alias); + + auto annot = lookupAnnotation(alias->type, module, builtinTypes); + if (containsAny(annot)) + { + // no expr => no inference for aliases + TelemetryTypePair types; + + types.annotatedType = toString(annot); + TypeInfo ti{Pattern::Alias, toString(ctxNode), types}; + typeInfo.push_back(ti); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull builtinTypes) +{ + auto ctxNode = getNode(rootSrc, expr); + + if (isAnyCall(scope, expr->expr, module, builtinTypes)) + { + TelemetryTypePair types; + types.inferredType = toString(lookupType(expr->expr, module, builtinTypes)); + + TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types}; + typeInfo.push_back(ti); + } +} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull builtinTypes) {} + +void AnyTypeSummary::visit(const Scope* scope, AstStatError* error, const Module* module, NotNull builtinTypes) {} + +TypeId AnyTypeSummary::checkForFamilyInhabitance(const TypeId instance, const Location location) +{ + if (seenTypeFamilyInstances.find(instance)) + return instance; + + seenTypeFamilyInstances.insert(instance); + return instance; +} + +TypeId AnyTypeSummary::lookupType(const AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + const TypeId* ty = module->astTypes.find(expr); + if (ty) + return checkForFamilyInhabitance(follow(*ty), expr->location); + + const TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + { + if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false)) + return checkForFamilyInhabitance(*fst, expr->location); + else if (finite(*tp) && size(*tp) == 0) + return checkForFamilyInhabitance(builtinTypes->nilType, expr->location); + } + + return builtinTypes->errorRecoveryType(); +} + +TypePackId AnyTypeSummary::reconstructTypePack(AstArray exprs, const Module* module, NotNull builtinTypes) +{ + if (exprs.size == 0) + return arena.addTypePack(TypePack{{}, std::nullopt}); + + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i], module, builtinTypes)); + } + + const TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]); + if (tail) + return arena.addTypePack(TypePack{std::move(head), follow(*tail)}); + else + return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()}); +} + +bool AnyTypeSummary::isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + if (auto call = expr->as()) + { + TypePackId args = reconstructTypePack(call->args, module, builtinTypes); + if (containsAny(args)) + return true; + + TypeId func = lookupType(call->func, module, builtinTypes); + if (containsAny(func)) + return true; + } + return false; +} + +bool AnyTypeSummary::hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (expr->vararg && expr->varargAnnotation) + { + auto annot = lookupPackAnnotation(expr->varargAnnotation, module); + if (annot && containsAny(*annot)) + { + return true; + } + } + return false; +} + +bool AnyTypeSummary::hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (expr->args.size > 0) + { + for (const AstLocal* arg : expr->args) + { + if (arg->annotation) + { + auto annot = lookupAnnotation(arg->annotation, module, builtinTypes); + if (containsAny(annot)) + { + return true; + } + } + } + } + return false; +} + +bool AnyTypeSummary::hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull builtinTypes) +{ + if (!expr->returnAnnotation) + { + return false; + } + + for (AstType* ret : expr->returnAnnotation->types) + { + if (containsAny(lookupAnnotation(ret, module, builtinTypes))) + { + return true; + } + } + + if (expr->returnAnnotation->tailType) + { + auto annot = lookupPackAnnotation(expr->returnAnnotation->tailType, module); + if (annot && containsAny(*annot)) + { + return true; + } + } + + return false; +} + +bool AnyTypeSummary::isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull builtinTypes) +{ + if (auto cast = expr->as()) + { + auto annot = lookupAnnotation(cast->annotation, module, builtinTypes); + if (containsAny(annot)) + { + return true; + } + } + return false; +} + +TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, const Module* module, NotNull builtintypes) +{ + if (FFlag::DebugLuauMagicTypes) + { + if (auto ref = annotation->as(); ref && ref->parameters.size > 0) + { + if (auto ann = ref->parameters.data[0].type) + { + TypeId argTy = lookupAnnotation(ref->parameters.data[0].type, module, builtintypes); + return follow(argTy); + } + } + } + + const TypeId* ty = module->astResolvedTypes.find(annotation); + if (ty) + return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); + else + return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location); +} + +TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(const TypeId instance, const Location location) +{ + if (seenTypeFunctionInstances.find(instance)) + return instance; + seenTypeFunctionInstances.insert(instance); + + return instance; +} + +std::optional AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, const Module* module) +{ + const TypePackId* tp = module->astResolvedTypePacks.find(annotation); + if (tp != nullptr) + return {follow(*tp)}; + return {}; +} + +bool AnyTypeSummary::containsAny(TypeId typ) +{ + typ = follow(typ); + + if (auto t = seen.find(typ); t && !*t) + { + return false; + } + + seen[typ] = false; + + RecursionCounter counter{&recursionCount}; + if (recursionCount >= FInt::LuauAnySummaryRecursionLimit) + { + return false; + } + + bool found = false; + + if (auto ty = get(typ)) + { + found = true; + } + else if (auto ty = get(typ)) + { + found = true; + } + else if (auto ty = get(typ)) + { + for (auto& [_name, prop] : ty->props) + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto newT = follow(prop.readTy)) + { + if (containsAny(*newT)) + found = true; + } + else if (auto newT = follow(prop.writeTy)) + { + if (containsAny(*newT)) + found = true; + } + } + else + { + if (containsAny(prop.type())) + found = true; + } + } + } + else if (auto ty = get(typ)) + { + for (auto part : ty->parts) + { + if (containsAny(part)) + { + found = true; + } + } + } + else if (auto ty = get(typ)) + { + for (auto option : ty->options) + { + if (containsAny(option)) + { + found = true; + } + } + } + else if (auto ty = get(typ)) + { + if (containsAny(ty->argTypes)) + found = true; + else if (containsAny(ty->retTypes)) + found = true; + } + + seen[typ] = found; + + return found; +} + +bool AnyTypeSummary::containsAny(TypePackId typ) +{ + typ = follow(typ); + + if (auto t = seen.find(typ); t && !*t) + { + return false; + } + + seen[typ] = false; + + auto [head, tail] = flatten(typ); + bool found = false; + + for (auto tp : head) + { + if (containsAny(tp)) + found = true; + } + + if (tail) + { + if (auto vtp = get(tail)) + { + if (auto ty = get(follow(vtp->ty))) + { + found = true; + } + } + else if (auto tftp = get(tail)) + { + + for (TypePackId tp : tftp->packArguments) + { + if (containsAny(tp)) + { + found = true; + } + } + + for (TypeId t : tftp->typeArguments) + { + if (containsAny(t)) + { + found = true; + } + } + } + } + + seen[typ] = found; + + return found; +} + +const Scope* AnyTypeSummary::findInnerMostScope(const Location location, const Module* module) +{ + const Scope* bestScope = module->getModuleScope().get(); + + bool didNarrow = false; + do + { + didNarrow = false; + for (auto scope : bestScope->children) + { + if (scope->location.encloses(location)) + { + bestScope = scope.get(); + didNarrow = true; + break; + } + } + } while (didNarrow && bestScope->children.size() > 0); + + return bestScope; +} + +std::optional AnyTypeSummary::matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +AstNode* AnyTypeSummary::getNode(AstStatBlock* root, AstNode* node) +{ + FindReturnAncestry finder(node, root->location.end); + root->visit(&finder); + + if (!finder.currNode) + finder.currNode = node; + + LUAU_ASSERT(finder.found && finder.currNode); + return finder.currNode; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstStatLocalFunction* node) +{ + currNode = node; + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstStatFunction* node) +{ + currNode = node; + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstType* node) +{ + return !found; +} + +bool AnyTypeSummary::FindReturnAncestry::visit(AstNode* node) +{ + if (node == stat) + { + found = true; + } + + if (node->location.end == rootEnd && stat->location.end >= rootEnd) + { + currNode = node; + found = true; + } + + return !found; +} + + +AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type) + : code(code) + , node(node) + , type(type) +{ +} + +AnyTypeSummary::FindReturnAncestry::FindReturnAncestry(AstNode* stat, Position rootEnd) + : stat(stat) + , rootEnd(rootEnd) +{ +} + +AnyTypeSummary::AnyTypeSummary() {} + +} // namespace Luau \ No newline at end of file diff --git a/third_party/luau/Analysis/src/Anyification.cpp b/third_party/luau/Analysis/src/Anyification.cpp index 15dd25cc..4bacec03 100644 --- a/third_party/luau/Analysis/src/Anyification.cpp +++ b/third_party/luau/Analysis/src/Anyification.cpp @@ -6,13 +6,17 @@ #include "Luau/Normalize.h" #include "Luau/TxnLog.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { -Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, - TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification( + TypeArena* arena, + NotNull scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack +) : Substitution(TxnLog::empty(), arena) , scope(scope) , builtinTypes(builtinTypes) @@ -22,8 +26,14 @@ Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull builtinTypes, InternalErrorReporter* iceHandler, - TypeId anyType, TypePackId anyTypePack) +Anyification::Anyification( + TypeArena* arena, + const ScopePtr& scope, + NotNull builtinTypes, + InternalErrorReporter* iceHandler, + TypeId anyType, + TypePackId anyTypePack +) : Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack) { } @@ -78,7 +88,7 @@ TypePackId Anyification::clean(TypePackId tp) bool Anyification::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; diff --git a/third_party/luau/Analysis/src/ApplyTypeFunction.cpp b/third_party/luau/Analysis/src/ApplyTypeFunction.cpp index fe8cc8ac..025e8f6d 100644 --- a/third_party/luau/Analysis/src/ApplyTypeFunction.cpp +++ b/third_party/luau/Analysis/src/ApplyTypeFunction.cpp @@ -2,8 +2,6 @@ #include "Luau/ApplyTypeFunction.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -33,7 +31,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) { if (get(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; diff --git a/third_party/luau/Analysis/src/AstJsonEncoder.cpp b/third_party/luau/Analysis/src/AstJsonEncoder.cpp index a964c785..ceeee73c 100644 --- a/third_party/luau/Analysis/src/AstJsonEncoder.cpp +++ b/third_party/luau/Analysis/src/AstJsonEncoder.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + namespace Luau { @@ -198,6 +200,23 @@ struct AstJsonEncoder : public AstVisitor { writeString(name.value ? name.value : ""); } + void write(std::optional name) + { + if (name) + write(*name); + else + writeRaw("null"); + } + void write(AstArgumentName name) + { + writeRaw("{"); + bool c = pushComma(); + writeType("AstArgumentName"); + write("name", name.first); + write("location", name.second); + popComma(c); + writeRaw("}"); + } void write(const Position& position) { @@ -254,9 +273,14 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprGroup* node) { - writeNode(node, "AstExprGroup", [&]() { - write("expr", node->expr); - }); + writeNode( + node, + "AstExprGroup", + [&]() + { + write("expr", node->expr); + } + ); } void write(class AstExprConstantNil* node) @@ -266,37 +290,62 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprConstantBool* node) { - writeNode(node, "AstExprConstantBool", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantBool", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprConstantNumber* node) { - writeNode(node, "AstExprConstantNumber", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantNumber", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprConstantString* node) { - writeNode(node, "AstExprConstantString", [&]() { - write("value", node->value); - }); + writeNode( + node, + "AstExprConstantString", + [&]() + { + write("value", node->value); + } + ); } void write(class AstExprLocal* node) { - writeNode(node, "AstExprLocal", [&]() { - write("local", node->local); - }); + writeNode( + node, + "AstExprLocal", + [&]() + { + write("local", node->local); + } + ); } void write(class AstExprGlobal* node) { - writeNode(node, "AstExprGlobal", [&]() { - write("global", node->name); - }); + writeNode( + node, + "AstExprGlobal", + [&]() + { + write("global", node->name); + } + ); } void write(class AstExprVarargs* node) @@ -330,52 +379,71 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprCall* node) { - writeNode(node, "AstExprCall", [&]() { - PROP(func); - PROP(args); - PROP(self); - PROP(argLocation); - }); + writeNode( + node, + "AstExprCall", + [&]() + { + PROP(func); + PROP(args); + PROP(self); + PROP(argLocation); + } + ); } void write(class AstExprIndexName* node) { - writeNode(node, "AstExprIndexName", [&]() { - PROP(expr); - PROP(index); - PROP(indexLocation); - PROP(op); - }); + writeNode( + node, + "AstExprIndexName", + [&]() + { + PROP(expr); + PROP(index); + PROP(indexLocation); + PROP(op); + } + ); } void write(class AstExprIndexExpr* node) { - writeNode(node, "AstExprIndexExpr", [&]() { - PROP(expr); - PROP(index); - }); + writeNode( + node, + "AstExprIndexExpr", + [&]() + { + PROP(expr); + PROP(index); + } + ); } void write(class AstExprFunction* node) { - writeNode(node, "AstExprFunction", [&]() { - PROP(generics); - PROP(genericPacks); - if (node->self) - PROP(self); - PROP(args); - if (node->returnAnnotation) - PROP(returnAnnotation); - PROP(vararg); - PROP(varargLocation); - if (node->varargAnnotation) - PROP(varargAnnotation); - - PROP(body); - PROP(functionDepth); - PROP(debugname); - PROP(hasEnd); - }); + writeNode( + node, + "AstExprFunction", + [&]() + { + PROP(generics); + PROP(genericPacks); + if (node->self) + PROP(self); + PROP(args); + if (node->returnAnnotation) + PROP(returnAnnotation); + PROP(vararg); + PROP(varargLocation); + if (node->varargAnnotation) + PROP(varargAnnotation); + + PROP(body); + PROP(functionDepth); + PROP(debugname); + } + ); } void write(const std::optional& typeList) @@ -457,28 +525,43 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprIfElse* node) { - writeNode(node, "AstExprIfElse", [&]() { - PROP(condition); - PROP(hasThen); - PROP(trueExpr); - PROP(hasElse); - PROP(falseExpr); - }); + writeNode( + node, + "AstExprIfElse", + [&]() + { + PROP(condition); + PROP(hasThen); + PROP(trueExpr); + PROP(hasElse); + PROP(falseExpr); + } + ); } void write(class AstExprInterpString* node) { - writeNode(node, "AstExprInterpString", [&]() { - PROP(strings); - PROP(expressions); - }); + writeNode( + node, + "AstExprInterpString", + [&]() + { + PROP(strings); + PROP(expressions); + } + ); } void write(class AstExprTable* node) { - writeNode(node, "AstExprTable", [&]() { - PROP(items); - }); + writeNode( + node, + "AstExprTable", + [&]() + { + PROP(items); + } + ); } void write(AstExprUnary::Op op) @@ -496,10 +579,15 @@ struct AstJsonEncoder : public AstVisitor void write(class AstExprUnary* node) { - writeNode(node, "AstExprUnary", [&]() { - PROP(op); - PROP(expr); - }); + writeNode( + node, + "AstExprUnary", + [&]() + { + PROP(op); + PROP(expr); + } + ); } void write(AstExprBinary::Op op) @@ -514,6 +602,8 @@ struct AstJsonEncoder : public AstVisitor return writeString("Mul"); case AstExprBinary::Div: return writeString("Div"); + case AstExprBinary::FloorDiv: + return writeString("FloorDiv"); case AstExprBinary::Mod: return writeString("Mod"); case AstExprBinary::Pow: @@ -536,81 +626,117 @@ struct AstJsonEncoder : public AstVisitor return writeString("And"); case AstExprBinary::Or: return writeString("Or"); + default: + LUAU_ASSERT(!"Unknown Op"); } } void write(class AstExprBinary* node) { - writeNode(node, "AstExprBinary", [&]() { - PROP(op); - PROP(left); - PROP(right); - }); + writeNode( + node, + "AstExprBinary", + [&]() + { + PROP(op); + PROP(left); + PROP(right); + } + ); } void write(class AstExprTypeAssertion* node) { - writeNode(node, "AstExprTypeAssertion", [&]() { - PROP(expr); - PROP(annotation); - }); + writeNode( + node, + "AstExprTypeAssertion", + [&]() + { + PROP(expr); + PROP(annotation); + } + ); } void write(class AstExprError* node) { - writeNode(node, "AstExprError", [&]() { - PROP(expressions); - PROP(messageIndex); - }); + writeNode( + node, + "AstExprError", + [&]() + { + PROP(expressions); + PROP(messageIndex); + } + ); } void write(class AstStatBlock* node) { - writeNode(node, "AstStatBlock", [&]() { - writeRaw(",\"body\":["); - bool comma = false; - for (AstStat* stat : node->body) + writeNode( + node, + "AstStatBlock", + [&]() { - if (comma) - writeRaw(","); - else - comma = true; - - write(stat); + writeRaw(",\"hasEnd\":"); + write(node->hasEnd); + writeRaw(",\"body\":["); + bool comma = false; + for (AstStat* stat : node->body) + { + if (comma) + writeRaw(","); + else + comma = true; + + write(stat); + } + writeRaw("]"); } - writeRaw("]"); - }); + ); } void write(class AstStatIf* node) { - writeNode(node, "AstStatIf", [&]() { - PROP(condition); - PROP(thenbody); - if (node->elsebody) - PROP(elsebody); - write("hasThen", node->thenLocation.has_value()); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatIf", + [&]() + { + PROP(condition); + PROP(thenbody); + if (node->elsebody) + PROP(elsebody); + write("hasThen", node->thenLocation.has_value()); + } + ); } void write(class AstStatWhile* node) { - writeNode(node, "AstStatWhile", [&]() { - PROP(condition); - PROP(body); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatWhile", + [&]() + { + PROP(condition); + PROP(body); + PROP(hasDo); + } + ); } void write(class AstStatRepeat* node) { - writeNode(node, "AstStatRepeat", [&]() { - PROP(condition); - PROP(body); - PROP(hasUntil); - }); + writeNode( + node, + "AstStatRepeat", + [&]() + { + PROP(condition); + PROP(body); + } + ); } void write(class AstStatBreak* node) @@ -625,113 +751,188 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatReturn* node) { - writeNode(node, "AstStatReturn", [&]() { - PROP(list); - }); + writeNode( + node, + "AstStatReturn", + [&]() + { + PROP(list); + } + ); } void write(class AstStatExpr* node) { - writeNode(node, "AstStatExpr", [&]() { - PROP(expr); - }); + writeNode( + node, + "AstStatExpr", + [&]() + { + PROP(expr); + } + ); } void write(class AstStatLocal* node) { - writeNode(node, "AstStatLocal", [&]() { - PROP(vars); - PROP(values); - }); + writeNode( + node, + "AstStatLocal", + [&]() + { + PROP(vars); + PROP(values); + } + ); } void write(class AstStatFor* node) { - writeNode(node, "AstStatFor", [&]() { - PROP(var); - PROP(from); - PROP(to); - if (node->step) - PROP(step); - PROP(body); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatFor", + [&]() + { + PROP(var); + PROP(from); + PROP(to); + if (node->step) + PROP(step); + PROP(body); + PROP(hasDo); + } + ); } void write(class AstStatForIn* node) { - writeNode(node, "AstStatForIn", [&]() { - PROP(vars); - PROP(values); - PROP(body); - PROP(hasIn); - PROP(hasDo); - PROP(hasEnd); - }); + writeNode( + node, + "AstStatForIn", + [&]() + { + PROP(vars); + PROP(values); + PROP(body); + PROP(hasIn); + PROP(hasDo); + } + ); } void write(class AstStatAssign* node) { - writeNode(node, "AstStatAssign", [&]() { - PROP(vars); - PROP(values); - }); + writeNode( + node, + "AstStatAssign", + [&]() + { + PROP(vars); + PROP(values); + } + ); } void write(class AstStatCompoundAssign* node) { - writeNode(node, "AstStatCompoundAssign", [&]() { - PROP(op); - PROP(var); - PROP(value); - }); + writeNode( + node, + "AstStatCompoundAssign", + [&]() + { + PROP(op); + PROP(var); + PROP(value); + } + ); } void write(class AstStatFunction* node) { - writeNode(node, "AstStatFunction", [&]() { - PROP(name); - PROP(func); - }); + writeNode( + node, + "AstStatFunction", + [&]() + { + PROP(name); + PROP(func); + } + ); } void write(class AstStatLocalFunction* node) { - writeNode(node, "AstStatLocalFunction", [&]() { - PROP(name); - PROP(func); - }); + writeNode( + node, + "AstStatLocalFunction", + [&]() + { + PROP(name); + PROP(func); + } + ); } void write(class AstStatTypeAlias* node) { - writeNode(node, "AstStatTypeAlias", [&]() { - PROP(name); - PROP(generics); - PROP(genericPacks); - PROP(type); - PROP(exported); - }); + writeNode( + node, + "AstStatTypeAlias", + [&]() + { + PROP(name); + PROP(generics); + PROP(genericPacks); + PROP(type); + PROP(exported); + } + ); } void write(class AstStatDeclareFunction* node) { - writeNode(node, "AstStatDeclareFunction", [&]() { - PROP(name); - PROP(params); - PROP(retTypes); - PROP(generics); - PROP(genericPacks); - }); + writeNode( + node, + "AstStatDeclareFunction", + [&]() + { + // TODO: attributes + PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + + PROP(params); + + if (FFlag::LuauDeclarationExtraPropData) + { + PROP(paramNames); + PROP(vararg); + PROP(varargLocation); + } + + PROP(retTypes); + PROP(generics); + PROP(genericPacks); + } + ); } void write(class AstStatDeclareGlobal* node) { - writeNode(node, "AstStatDeclareGlobal", [&]() { - PROP(name); - PROP(type); - }); + writeNode( + node, + "AstStatDeclareGlobal", + [&]() + { + PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + + PROP(type); + } + ); } void write(const AstDeclaredClassProp& prop) @@ -739,28 +940,47 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); + + if (FFlag::LuauDeclarationExtraPropData) + write("nameLocation", prop.nameLocation); + writeType("AstDeclaredClassProp"); write("luauType", prop.ty); + + if (FFlag::LuauDeclarationExtraPropData) + write("location", prop.location); + popComma(c); writeRaw("}"); } void write(class AstStatDeclareClass* node) { - writeNode(node, "AstStatDeclareClass", [&]() { - PROP(name); - if (node->superName) - write("superName", *node->superName); - PROP(props); - }); + writeNode( + node, + "AstStatDeclareClass", + [&]() + { + PROP(name); + if (node->superName) + write("superName", *node->superName); + PROP(props); + PROP(indexer); + } + ); } void write(class AstStatError* node) { - writeNode(node, "AstStatError", [&]() { - PROP(expressions); - PROP(statements); - }); + writeNode( + node, + "AstStatError", + [&]() + { + PROP(expressions); + PROP(statements); + } + ); } void write(struct AstTypeOrPack node) @@ -773,15 +993,20 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeReference* node) { - writeNode(node, "AstTypeReference", [&]() { - if (node->prefix) - PROP(prefix); - if (node->prefixLocation) - write("prefixLocation", *node->prefixLocation); - PROP(name); - PROP(nameLocation); - PROP(parameters); - }); + writeNode( + node, + "AstTypeReference", + [&]() + { + if (node->prefix) + PROP(prefix); + if (node->prefixLocation) + write("prefixLocation", *node->prefixLocation); + PROP(name); + PROP(nameLocation); + PROP(parameters); + } + ); } void write(const AstTableProp& prop) @@ -800,10 +1025,15 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeTable* node) { - writeNode(node, "AstTypeTable", [&]() { - PROP(props); - PROP(indexer); - }); + writeNode( + node, + "AstTypeTable", + [&]() + { + PROP(props); + PROP(indexer); + } + ); } void write(struct AstTableIndexer* indexer) @@ -826,62 +1056,129 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeFunction* node) { - writeNode(node, "AstTypeFunction", [&]() { - PROP(generics); - PROP(genericPacks); - PROP(argTypes); - PROP(returnTypes); - }); + writeNode( + node, + "AstTypeFunction", + [&]() + { + PROP(generics); + PROP(genericPacks); + PROP(argTypes); + PROP(argNames); + PROP(returnTypes); + } + ); } void write(class AstTypeTypeof* node) { - writeNode(node, "AstTypeTypeof", [&]() { - PROP(expr); - }); + writeNode( + node, + "AstTypeTypeof", + [&]() + { + PROP(expr); + } + ); } void write(class AstTypeUnion* node) { - writeNode(node, "AstTypeUnion", [&]() { - PROP(types); - }); + writeNode( + node, + "AstTypeUnion", + [&]() + { + PROP(types); + } + ); } void write(class AstTypeIntersection* node) { - writeNode(node, "AstTypeIntersection", [&]() { - PROP(types); - }); + writeNode( + node, + "AstTypeIntersection", + [&]() + { + PROP(types); + } + ); } void write(class AstTypeError* node) { - writeNode(node, "AstTypeError", [&]() { - PROP(types); - PROP(messageIndex); - }); + writeNode( + node, + "AstTypeError", + [&]() + { + PROP(types); + PROP(messageIndex); + } + ); } void write(class AstTypePackExplicit* node) { - writeNode(node, "AstTypePackExplicit", [&]() { - PROP(typeList); - }); + writeNode( + node, + "AstTypePackExplicit", + [&]() + { + PROP(typeList); + } + ); } void write(class AstTypePackVariadic* node) { - writeNode(node, "AstTypePackVariadic", [&]() { - PROP(variadicType); - }); + writeNode( + node, + "AstTypePackVariadic", + [&]() + { + PROP(variadicType); + } + ); } void write(class AstTypePackGeneric* node) { - writeNode(node, "AstTypePackGeneric", [&]() { - PROP(genericName); - }); + writeNode( + node, + "AstTypePackGeneric", + [&]() + { + PROP(genericName); + } + ); + } + + bool visit(class AstTypeSingletonBool* node) override + { + writeNode( + node, + "AstTypeSingletonBool", + [&]() + { + write("value", node->value); + } + ); + return false; + } + + bool visit(class AstTypeSingletonString* node) override + { + writeNode( + node, + "AstTypeSingletonString", + [&]() + { + write("value", node->value); + } + ); + return false; } bool visit(class AstExprGroup* node) override diff --git a/third_party/luau/Analysis/src/AstQuery.cpp b/third_party/luau/Analysis/src/AstQuery.cpp index 38f3bdf5..243834f8 100644 --- a/third_party/luau/Analysis/src/AstQuery.cpp +++ b/third_party/luau/Analysis/src/AstQuery.cpp @@ -11,6 +11,9 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauFixBindingForGlobalPos, false); + namespace Luau { @@ -146,6 +149,16 @@ struct FindNode : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + bool visit(AstStatBlock* block) override { visit(static_cast(block)); @@ -164,50 +177,52 @@ struct FindNode : public AstVisitor } }; -struct FindFullAncestry final : public AstVisitor +} // namespace + +FindFullAncestry::FindFullAncestry(Position pos, Position documentEnd, bool includeTypes) + : pos(pos) + , documentEnd(documentEnd) + , includeTypes(includeTypes) { - std::vector nodes; - Position pos; - Position documentEnd; - bool includeTypes = false; +} - explicit FindFullAncestry(Position pos, Position documentEnd, bool includeTypes = false) - : pos(pos) - , documentEnd(documentEnd) - , includeTypes(includeTypes) - { - } +bool FindFullAncestry::visit(AstType* type) +{ + if (includeTypes) + return visit(static_cast(type)); + else + return false; +} - bool visit(AstType* type) override - { - if (includeTypes) - return visit(static_cast(type)); - else - return false; - } +bool FindFullAncestry::visit(AstStatFunction* node) +{ + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; +} - bool visit(AstNode* node) override +bool FindFullAncestry::visit(AstNode* node) +{ + if (node->location.contains(pos)) { - if (node->location.contains(pos)) - { - nodes.push_back(node); - return true; - } - - // Edge case: If we ask for the node at the position that is the very end of the document - // return the innermost AST element that ends at that position. + nodes.push_back(node); + return true; + } - if (node->location.end == documentEnd && pos >= documentEnd) - { - nodes.push_back(node); - return true; - } + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. - return false; + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; } -}; -} // namespace + return false; +} std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) { @@ -310,10 +325,20 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) { + // Bindings coming from global sources (e.g., definition files) have a zero position. + // They cannot be defined from a local statement + if (FFlag::LuauFixBindingForGlobalPos && binding.location == Location{{0, 0}, {0, 0}}) + return std::nullopt; + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); - auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { - return node->is(); - }); + auto iter = std::find_if( + nodes.rbegin(), + nodes.rend(), + [](AstNode* node) + { + return node->is(); + } + ); return iter != nodes.rend() ? std::make_optional((*iter)->as()) : std::nullopt; } @@ -452,7 +477,11 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) } static std::optional checkOverloadedDocumentationSymbol( - const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) + const Module& module, + const TypeId ty, + const AstExpr* parentExpr, + const std::optional documentationSymbol +) { if (!documentationSymbol) return std::nullopt; @@ -501,12 +530,28 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + } } else if (const ClassType* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + } } } } diff --git a/third_party/luau/Analysis/src/Autocomplete.cpp b/third_party/luau/Analysis/src/Autocomplete.cpp index 4b66568b..a4acfb85 100644 --- a/third_party/luau/Analysis/src/Autocomplete.cpp +++ b/third_party/luau/Analysis/src/Autocomplete.cpp @@ -5,16 +5,18 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Frontend.h" #include "Luau/ToString.h" +#include "Luau/Subtyping.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include #include #include -static const std::unordered_set kStatementStartingKeywords = { - "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + +static const std::unordered_set kStatementStartingKeywords = + {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; namespace Luau { @@ -139,17 +141,33 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}, scope}; - return unifier.canUnify(subTy, superTy).empty(); + return subtyping.isSubtype(subTy, superTy).isSubtype; + } + else + { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + + return unifier.canUnify(subTy, superTy).empty(); + } } static TypeCorrectKind checkTypeCorrectKind( - const Module& module, TypeArena* typeArena, NotNull builtinTypes, AstNode* node, Position position, TypeId ty) + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + AstNode* node, + Position position, + TypeId ty +) { ty = follow(ty); @@ -164,7 +182,8 @@ static TypeCorrectKind checkTypeCorrectKind( TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) { + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) + { if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); @@ -197,9 +216,18 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId rootTy, TypeId ty, - PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, - std::optional containingClass = std::nullopt) +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId rootTy, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result, + std::unordered_set& seen, + std::optional containingClass = std::nullopt +) { rootTy = follow(rootTy); ty = follow(ty); @@ -208,13 +236,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return; seen.insert(ty); - auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) + { if (indexType == PropIndexType::Key) return false; bool calledWithSelf = indexType == PropIndexType::Colon; - auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) { + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) + { // Strong match with definition is a success if (calledWithSelf == ftv->hasSelf) return true; @@ -253,17 +283,30 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul return calledWithSelf; }; - auto fillProps = [&](const ClassType::Props& props) { + auto fillProps = [&](const ClassType::Props& props) + { for (const auto& [name, prop] : props) { // We are walking up the class hierarchy, so if we encounter a property that we have // already populated, it takes precedence over the property we found just now. if (result.count(name) == 0 && name != kParseNameError) { - Luau::TypeId type = Luau::follow(prop.type()); + Luau::TypeId type; + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (auto ty = prop.readTy) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -278,12 +321,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul prop.documentationSymbol, {}, parens, + {}, + indexType == PropIndexType::Colon }; } } }; - auto fillMetatableProps = [&](const TableType* mtable) { + auto fillMetatableProps = [&](const TableType* mtable) + { auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { @@ -395,7 +441,11 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul } static void autocompleteKeywords( - const SourceModule& sourceModule, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) + const SourceModule& sourceModule, + const std::vector& ancestry, + Position position, + AutocompleteEntryMap& result +) { LUAU_ASSERT(!ancestry.empty()); @@ -415,15 +465,28 @@ static void autocompleteKeywords( } } -static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, PropIndexType indexType, - const std::vector& nodes, AutocompleteEntryMap& result) +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result +) { std::unordered_set seen; autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); } -AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull builtinTypes, TypeId ty, - PropIndexType indexType, const std::vector& nodes) +AutocompleteEntryMap autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes +) { AutocompleteEntryMap result; autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); @@ -448,9 +511,18 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) { - auto formatKey = [addQuotes](const std::string& key) { + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) + return; + else if (node->is()) + return; + } + + auto formatKey = [addQuotes](const std::string& key) + { if (addQuotes) return "\"" + escape(key) + "\""; @@ -578,14 +650,13 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit return {}; } -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; opts.scope = scope; ToStringResult name = toStringDetailed(ty, opts); @@ -595,6 +666,14 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) return name.name; } +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + return tryToStringDetailed(scope, ty, functionTypeArguments); +} + static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) { std::optional ty; @@ -676,9 +755,14 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n if (const IntersectionType* itv = get(expectedType)) { - return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) { - return get(Luau::follow(ty)) != nullptr; - }); + return std::all_of( + begin(itv->parts), + end(itv->parts), + [](auto&& ty) + { + return get(Luau::follow(ty)) != nullptr; + } + ); } if (const UnionType* utv = get(expectedType)) @@ -698,15 +782,31 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi for (const auto& [name, ty] : scope->exportedTypeBindings) { if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, ty.type->documentationSymbol}; + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; } for (const auto& [name, ty] : scope->privateTypeBindings) { if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, ty.type->documentationSymbol}; + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; } for (const auto& [name, _] : scope->importedTypeBindings) @@ -796,7 +896,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi else if (AstExprFunction* node = parent->as()) { // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* { + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* + { auto it = module.astExpectedTypes.find(expr); if (!it) @@ -1000,7 +1101,11 @@ static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& } static AutocompleteEntryMap autocompleteStatement( - const SourceModule& sourceModule, const Module& module, const std::vector& ancestry, Position position) + const SourceModule& sourceModule, + const Module& module, + const std::vector& ancestry, + Position position +) { // This is inefficient. :( ScopePtr scope = findScopeAtPosition(module, position); @@ -1022,8 +1127,18 @@ static AutocompleteEntryMap autocompleteStatement( std::string n = toString(name); if (!result.count(n)) - result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, TypeCorrectKind::None, std::nullopt, - std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)}; + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) + }; } scope = scope->parent; @@ -1034,15 +1149,27 @@ static AutocompleteEntryMap autocompleteStatement( for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->hasEnd) + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->hasEnd) + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->hasEnd) + else if (AstStatIf* statIf = (*it)->as()) + { + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) + { + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; + } + + if (!hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->hasEnd) + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->hasEnd) + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } @@ -1058,7 +1185,7 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->hasUntil) + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } @@ -1073,7 +1200,7 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->hasUntil) + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); return result; @@ -1081,7 +1208,11 @@ static AutocompleteEntryMap autocompleteStatement( // Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) static bool autocompleteIfElseExpression( - const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) + const AstNode* node, + const std::vector& ancestry, + const Position& position, + AutocompleteEntryMap& outResult +) { AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; if (!parent) @@ -1120,8 +1251,15 @@ static bool autocompleteIfElseExpression( } } -static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, - TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +static AutocompleteContext autocompleteExpression( + const SourceModule& sourceModule, + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + Position position, + AutocompleteEntryMap& result +) { LUAU_ASSERT(!ancestry.empty()); @@ -1156,8 +1294,18 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu { TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); - result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, - binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + typeCorrect, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, typeCorrect) + }; } } @@ -1178,14 +1326,20 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); + autocompleteStringSingleton(*ty, true, node, position, result); } return AutocompleteContext::Expression; } -static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull builtinTypes, - TypeArena* typeArena, const std::vector& ancestry, Position position) +static AutocompleteResult autocompleteExpression( + const SourceModule& sourceModule, + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + Position position +) { AutocompleteEntryMap result; AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); @@ -1271,8 +1425,13 @@ static std::optional getStringContents(const AstNode* node) } } -static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, - const std::vector& nodes, Position position, StringCompletionCallback callback) +static std::optional autocompleteStringParams( + const SourceModule& sourceModule, + const ModulePtr& module, + const std::vector& nodes, + Position position, + StringCompletionCallback callback +) { if (nodes.size() < 2) { @@ -1284,6 +1443,14 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } + if (!nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); if (!candidate) { @@ -1305,7 +1472,8 @@ static std::optional autocompleteStringParams(const Source std::optional candidateString = getStringContents(nodes.back()); - auto performCallback = [&](const FunctionType* funcType) -> std::optional { + auto performCallback = [&](const FunctionType* funcType) -> std::optional + { for (const std::string& tag : funcType->tags) { if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) @@ -1348,8 +1516,152 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; } -static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull builtinTypes, - TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, + Position position, + const AstNode* node, + const std::vector& ancestry +) +{ + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + +static AutocompleteResult autocomplete( + const SourceModule& sourceModule, + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + Scope* globalScope, + Position position, + StringCompletionCallback callback +) { if (isWithinComment(sourceModule, position)) return {}; @@ -1474,14 +1786,17 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(ancestry); (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && - !statWhile->condition->location.containsClosed(position))) + !statWhile->condition->location.containsClosed(position))) { return autocompleteWhileLoopKeywords(ancestry); } else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { - return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - ancestry, AutocompleteContext::Keyword}; + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, + AutocompleteContext::Keyword + }; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { @@ -1517,7 +1832,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); if (!key) { @@ -1529,7 +1844,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // suggest those too. if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); } } @@ -1554,6 +1869,37 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } } + else if (AstExprTable* exprTable = node->as()) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(exprTable)) + { + result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + } + + // Also offer general expression suggestions + autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } else if (isIdentifier(node) && (parent->is() || parent->is())) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1566,7 +1912,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteEntryMap result; if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); if (ancestry.size() >= 2) { @@ -1580,7 +1926,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); } } } @@ -1599,7 +1945,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; if (node->asExpr()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + { + AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1608,22 +1959,25 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + ModulePtr module; + if (FFlag::DebugLuauDeferredConstraintResolution) + module = frontend.moduleResolver.getModule(moduleName); + else + module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + if (!module) return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get(); + Scope* globalScope; + if (FFlag::DebugLuauDeferredConstraintResolution) + globalScope = frontend.globals.globalScope.get(); + else + globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); diff --git a/third_party/luau/Analysis/src/BuiltinDefinitions.cpp b/third_party/luau/Analysis/src/BuiltinDefinitions.cpp index c55a88eb..082d79e8 100644 --- a/third_party/luau/Analysis/src/BuiltinDefinitions.cpp +++ b/third_party/luau/Analysis/src/BuiltinDefinitions.cpp @@ -7,8 +7,10 @@ #include "Luau/Common.h" #include "Luau/ToString.h" #include "Luau/ConstraintSolver.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/NotNull.h" #include "Luau/TypeInfer.h" +#include "Luau/TypeFunction.h" #include "Luau/TypePack.h" #include "Luau/Type.h" #include "Luau/TypeUtils.h" @@ -21,19 +23,41 @@ * about a function that takes any number of values, but where each value must have some specific type. */ +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); @@ -57,26 +81,51 @@ TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t } TypeId makeFunction( - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes) + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, - std::initializer_list paramNames, std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked +) { - return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked); } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes) +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes, + bool checked +) { std::vector params; if (selfType) @@ -103,6 +152,8 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi ftv.argNames.push_back(std::nullopt); } + ftv.isCheckedFunction = checked; + return arena.addType(std::move(ftv)); } @@ -201,18 +252,6 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(GlobalTypes& globals) -{ - globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType}); - globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType}); - globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType}); - globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType}); - globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType}); - globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType}); - globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType}); - globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); -} - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); @@ -221,8 +260,12 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC TypeArena& arena = globals.globalTypes; NotNull builtinTypes = globals.builtinTypes; + if (FFlag::DebugLuauDeferredConstraintResolution) + builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( - globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); + globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete + ); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); @@ -259,21 +302,44 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); + // getmetatable : ({ @metatable MT, {+ +} }) -> MT addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on + if (FFlag::DebugLuauDeferredConstraintResolution) + { + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT}); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericT, genericMT}, + {}, + arena.addTypePack(TypePack{{genericT, genericMT}}), + arena.addTypePack(TypePack{{tMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } + else + { + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{tabTy, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } for (const auto& pair : globals.globalScope->bindings) { @@ -287,6 +353,21 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC } attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // declare function assert(value: T, errorMessage: string?): intersect + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId refinedTy = arena.addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {} + }); + + TypeId assertTy = arena.addType(FunctionType{ + {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}}) + }); + addGlobalBinding(globals, "assert", assertTy, "@luau"); + } + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); @@ -310,8 +391,577 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) +{ + const char* options = "cdiouxXeEfgGqs*"; + + std::vector result; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i < size && data[i] == '%') + continue; + + // we just ignore all characters (including flags/precision) up until first alphabetic character + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) + i++; + + if (i == size) + break; + + if (data[i] == 'q' || data[i] == 's') + result.push_back(builtinTypes->stringType); + else if (data[i] == '*') + result.push_back(builtinTypes->unknownType); + else if (strchr(options, data[i])) + result.push_back(builtinTypes->numberType); + else + result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); + } + } + + return result; +} + +std::optional> magicFunctionFormat( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* fmt = nullptr; + if (auto index = expr.func->as(); index && expr.self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!expr.self && expr.args.size > 0) + fmt = expr.args.data[0]->as(); + + if (!fmt) + return std::nullopt; + + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(paramPack); + + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; + + typechecker.unify(params[i + paramOffset], expected[i], scope, location); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + return WithPredicate{arena.addTypePack({typechecker.stringType})}; +} + +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(context.constraint, params[i + paramOffset], expected[i]); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(builtinTypes->optionalNumberType); + continue; + } + + ++depth; + result.push_back(builtinTypes->optionalStringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(builtinTypes->optionalStringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.constraint, params[0], context.solver->builtinTypes->stringType); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.constraint, params[0], context.solver->builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.constraint, params[2], optionalNumber); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull builtinTypes = context.solver->builtinTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(context.constraint, params[0], builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.constraint, params[2], optionalNumber); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(context.constraint, params[3], optionalBoolean); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + +TypeId makeStringMetatable(NotNull builtinTypes) +{ + NotNull arena{builtinTypes->arena.get()}; + + const TypeId nilType = builtinTypes->nilType; + const TypeId numberType = builtinTypes->numberType; + const TypeId booleanType = builtinTypes->booleanType; + const TypeId stringType = builtinTypes->stringType; + + const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); + + const TypePackId oneStringPack = arena->addTypePack({stringType}); + const TypePackId anyTypePack = builtinTypes->anyTypePack; + + const TypePackId variadicTailPack = FFlag::DebugLuauDeferredConstraintResolution ? builtinTypes->unknownTypePack : anyTypePack; + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); + const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); + + + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + formatFTV.isCheckedFunction = true; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); + + const TypeId replArgType = arena->addType(UnionType{ + {stringType, + arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)} + }); + const TypeId gsubFunc = + makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + FunctionType matchFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) + }; + matchFuncTy.isCheckedFunction = true; + const TypeId matchFunc = arena->addType(matchFuncTy); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + FunctionType findFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList}) + }; + findFuncTy.isCheckedFunction = true; + const TypeId findFunc = arena->addType(findFuncTy); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + + // string.byte : string -> number? -> number? -> ...number + FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; + stringDotByte.isCheckedFunction = true; + + // string.char : .... number -> string + FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; + stringDotChar.isCheckedFunction = true; + + // string.unpack : string -> string -> number? -> ...any + FunctionType stringDotUnpack{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + }; + stringDotUnpack.isCheckedFunction = true; + + TableType::Props stringLib = { + {"byte", {arena->addType(stringDotByte)}}, + {"char", {arena->addType(stringDotChar)}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, + {"upper", {stringToStringType}}, + {"split", + {makeFunction( + *arena, + stringType, + {}, + {}, + {optionalString}, + {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, + /* checked */ true + )}}, + {"pack", + {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"unpack", {arena->addType(stringDotUnpack)}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); +} + static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -397,7 +1047,11 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) } static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -454,6 +1108,18 @@ static std::optional> magicFunctionSetMetaTable( else if (get(target) || get(target) || isTableIntersection(target)) { } + else if (isTableUnion(target)) + { + const UnionType* ut = get(target); + LUAU_ASSERT(ut); + + std::vector resultParts; + + for (TypeId ty : ut) + resultParts.push_back(arena.addType(MetatableType{ty, mt})); + + return WithPredicate{arena.addTypePack({arena.addType(UnionType{std::move(resultParts)})})}; + } else { typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); @@ -463,7 +1129,11 @@ static std::optional> magicFunctionSetMetaTable( } static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, predicates] = withPredicate; @@ -493,7 +1163,11 @@ static std::optional> magicFunctionAssert( } static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { auto [paramPack, _predicates] = withPredicate; @@ -594,7 +1268,11 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) } static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) { TypeArena& arena = typechecker.currentModule->internalTypes; diff --git a/third_party/luau/Analysis/src/Clone.cpp b/third_party/luau/Analysis/src/Clone.cpp index 450b84af..7446846a 100644 --- a/third_party/luau/Analysis/src/Clone.cpp +++ b/third_party/luau/Analysis/src/Clone.cpp @@ -1,16 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" -#include "Luau/RecursionCounter.h" -#include "Luau/TxnLog.h" +#include "Luau/NotNull.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" -LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) -LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. +LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) namespace Luau { @@ -18,395 +17,444 @@ namespace Luau namespace { -Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) +using Kind = Variant; + +template +const T* get(const Kind& kind) { - if (FFlag::DebugLuauReadWriteProperties) - { - std::optional cloneReadTy; - if (auto ty = prop.readType()) - cloneReadTy = clone(*ty, dest, cloneState); - - std::optional cloneWriteTy; - if (auto ty = prop.writeType()) - cloneWriteTy = clone(*ty, dest, cloneState); - - std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); - LUAU_ASSERT(cloned); - cloned->deprecated = prop.deprecated; - cloned->deprecatedSuggestion = prop.deprecatedSuggestion; - cloned->location = prop.location; - cloned->tags = prop.tags; - cloned->documentationSymbol = prop.documentationSymbol; - return *cloned; - } - else - { - return Property{ - clone(prop.type(), dest, cloneState), - prop.deprecated, - prop.deprecatedSuggestion, - prop.location, - prop.tags, - prop.documentationSymbol, - }; - } + return get_if(&kind); } -struct TypePackCloner; +class TypeCloner +{ + NotNull arena; + NotNull builtinTypes; -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ + // A queue of kinds where we cloned it, but whose interior types hasn't + // been updated to point to new clones. Once all of its interior types + // has been updated, it gets removed from the queue. + std::vector queue; -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) - { - } - - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const FreeType& t); - void operator()(const GenericType& t); - void operator()(const BoundType& t); - void operator()(const ErrorType& t); - void operator()(const BlockedType& t); - void operator()(const PendingExpansionType& t); - void operator()(const PrimitiveType& t); - void operator()(const SingletonType& t); - void operator()(const FunctionType& t); - void operator()(const TableType& t); - void operator()(const MetatableType& t); - void operator()(const ClassType& t); - void operator()(const AnyType& t); - void operator()(const UnionType& t); - void operator()(const IntersectionType& t); - void operator()(const LazyType& t); - void operator()(const UnknownType& t); - void operator()(const NeverType& t); - void operator()(const NegationType& t); -}; + NotNull types; + NotNull packs; -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; + int steps = 0; - TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) +public: + TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + : arena(arena) + , builtinTypes(builtinTypes) + , types(types) + , packs(packs) { } - template - void defaultClone(const T& t) + TypeId clone(TypeId ty) { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + { + TypeId error = builtinTypes->errorRecoveryType(); + (*types)[ty] = error; + return error; + } + + return find(ty).value_or(builtinTypes->errorRecoveryType()); } - void operator()(const FreeTypePack& t) + TypePackId clone(TypePackId tp) { - defaultClone(t); + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + { + TypePackId error = builtinTypes->errorRecoveryTypePack(); + (*packs)[tp] = error; + return error; + } + + return find(tp).value_or(builtinTypes->errorRecoveryTypePack()); } - void operator()(const GenericTypePack& t) + +private: + bool hasExceededIterationLimit() const { - defaultClone(t); + if (FInt::LuauTypeCloneIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(FInt::LuauTypeCloneIterationLimit); } - void operator()(const ErrorTypePack& t) + + void run() { - defaultClone(t); + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + Kind kind = queue.back(); + queue.pop_back(); + + if (find(kind)) + continue; + + cloneChildren(kind); + } } - void operator()(const BlockedTypePack& t) + std::optional find(TypeId ty) const { - defaultClone(t); + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + if (auto it = types->find(ty); it != types->end()) + return it->second; + else if (ty->persistent) + return ty; + return std::nullopt; } - // While we are a-cloning, we can flatten out bound Types and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) + std::optional find(TypePackId tp) const { - TypePackId cloned = clone(t.boundTo, dest, cloneState); - if (FFlag::DebugLuauCopyBeforeNormalizing) - cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); - seenTypePacks[typePackId] = cloned; + tp = follow(tp); + if (auto it = packs->find(tp); it != packs->end()) + return it->second; + else if (tp->persistent) + return tp; + return std::nullopt; } - void operator()(const VariadicTypePack& t) + std::optional find(Kind kind) const { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); - seenTypePacks[typePackId] = cloned; + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } } - void operator()(const TypePack& t) +private: + TypeId shallowClone(TypeId ty) { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; + // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + + if (auto clone = find(ty)) + return *clone; + else if (ty->persistent) + return ty; + + TypeId target = arena->addType(ty->ty); + asMutable(target)->documentationSymbol = ty->documentationSymbol; + + if (auto generic = getMutable(target)) + generic->scope = nullptr; + else if (auto free = getMutable(target)) + free->scope = nullptr; + else if (auto fn = getMutable(target)) + fn->scope = nullptr; + else if (auto table = getMutable(target)) + table->scope = nullptr; + + (*types)[ty] = target; + queue.push_back(target); + return target; + } - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, cloneState)); + TypePackId shallowClone(TypePackId tp) + { + tp = follow(tp); - if (t.tail) - destTp->tail = clone(*t.tail, dest, cloneState); - } -}; + if (auto clone = find(tp)) + return *clone; + else if (tp->persistent) + return tp; -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; -} + TypePackId target = arena->addTypePack(tp->ty); -void TypeCloner::operator()(const FreeType& t) -{ - defaultClone(t); -} + if (auto generic = getMutable(target)) + generic->scope = nullptr; + else if (auto free = getMutable(target)) + free->scope = nullptr; -void TypeCloner::operator()(const GenericType& t) -{ - defaultClone(t); -} + (*packs)[tp] = target; + queue.push_back(target); + return target; + } -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, cloneState); - if (FFlag::DebugLuauCopyBeforeNormalizing) - boundTo = dest.addType(BoundType{boundTo}); - seenTypes[typeId] = boundTo; -} + Property shallowClone(const Property& p) + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + std::optional cloneReadTy; + if (auto ty = p.readTy) + cloneReadTy = shallowClone(*ty); + + std::optional cloneWriteTy; + if (auto ty = p.writeTy) + cloneWriteTy = shallowClone(*ty); + + Property cloned = Property::create(cloneReadTy, cloneWriteTy); + cloned.deprecated = p.deprecated; + cloned.deprecatedSuggestion = p.deprecatedSuggestion; + cloned.location = p.location; + cloned.tags = p.tags; + cloned.documentationSymbol = p.documentationSymbol; + cloned.typeLocation = p.typeLocation; + return cloned; + } + else + { + return Property{ + shallowClone(p.type()), + p.deprecated, + p.deprecatedSuggestion, + p.location, + p.tags, + p.documentationSymbol, + p.typeLocation, + }; + } + } -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} + void cloneChildren(TypeId ty) + { + return visit( + [&](auto&& t) + { + return cloneChildren(&t); + }, + asMutable(ty)->ty + ); + } -void TypeCloner::operator()(const BlockedType& t) -{ - defaultClone(t); -} + void cloneChildren(TypePackId tp) + { + return visit( + [&](auto&& t) + { + return cloneChildren(&t); + }, + asMutable(tp)->ty + ); + } -void TypeCloner::operator()(const PendingExpansionType& t) -{ - TypeId res = dest.addType(PendingExpansionType{t.prefix, t.name, t.typeArguments, t.packArguments}); - PendingExpansionType* petv = getMutable(res); - LUAU_ASSERT(petv); + void cloneChildren(Kind kind) + { + if (auto ty = get(kind)) + return cloneChildren(*ty); + else if (auto tp = get(kind)) + return cloneChildren(*tp); + else + LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?"); + } - seenTypes[typeId] = res; + // ErrorType and ErrorTypePack is an alias to this type. + void cloneChildren(Unifiable::Error* t) + { + // noop. + } - std::vector typeArguments; - for (TypeId arg : t.typeArguments) - typeArguments.push_back(clone(arg, dest, cloneState)); + void cloneChildren(BoundType* t) + { + t->boundTo = shallowClone(t->boundTo); + } - std::vector packArguments; - for (TypePackId arg : t.packArguments) - packArguments.push_back(clone(arg, dest, cloneState)); + void cloneChildren(FreeType* t) + { + if (t->lowerBound) + t->lowerBound = shallowClone(t->lowerBound); + if (t->upperBound) + t->upperBound = shallowClone(t->upperBound); + } - petv->typeArguments = std::move(typeArguments); - petv->packArguments = std::move(packArguments); -} + void cloneChildren(GenericType* t) + { + // TOOD: clone upper bounds. + } -void TypeCloner::operator()(const PrimitiveType& t) -{ - defaultClone(t); -} + void cloneChildren(PrimitiveType* t) + { + // noop. + } -void TypeCloner::operator()(const SingletonType& t) -{ - defaultClone(t); -} + void cloneChildren(BlockedType* t) + { + // TODO: In the new solver, we should ice. + } -void TypeCloner::operator()(const FunctionType& t) -{ - // FISHY: We always erase the scope when we clone things. clone() was - // originally written so that we could copy a module's type surface into an - // export arena. This probably dates to that. - TypeId result = dest.addType(FunctionType{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionType* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, cloneState); - ftv->argNames = t.argNames; - ftv->retTypes = clone(t.retTypes, dest, cloneState); - ftv->hasNoGenerics = t.hasNoGenerics; -} + void cloneChildren(PendingExpansionType* t) + { + // TODO: In the new solver, we should ice. + } -void TypeCloner::operator()(const TableType& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + void cloneChildren(SingletonType* t) { - TypeId boundTo = clone(*t.boundTo, dest, cloneState); - seenTypes[typeId] = boundTo; - return; + // noop. } - TypeId result = dest.addType(TableType{}); - TableType* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); + void cloneChildren(FunctionType* t) + { + for (TypeId& g : t->generics) + g = shallowClone(g); - *ttv = t; + for (TypePackId& gp : t->genericPacks) + gp = shallowClone(gp); - seenTypes[typeId] = result; + t->argTypes = shallowClone(t->argTypes); + t->retTypes = shallowClone(t->retTypes); + } - ttv->level = TypeLevel{0, 0}; + void cloneChildren(TableType* t) + { + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } - if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, cloneState); + for (auto& [_, p] : t->props) + p = shallowClone(p); - for (const auto& [name, prop] : t.props) - ttv->props[name] = clone(prop, dest, cloneState); + for (TypeId& ty : t->instantiatedTypeParams) + ty = shallowClone(ty); - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + for (TypePackId& tp : t->instantiatedTypePackParams) + tp = shallowClone(tp); + } - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, cloneState); + void cloneChildren(MetatableType* t) + { + t->table = shallowClone(t->table); + t->metatable = shallowClone(t->metatable); + } - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, cloneState); + void cloneChildren(ClassType* t) + { + for (auto& [_, p] : t->props) + p = shallowClone(p); - ttv->definitionModuleName = t.definitionModuleName; - ttv->definitionLocation = t.definitionLocation; - ttv->tags = t.tags; -} + if (t->parent) + t->parent = shallowClone(*t->parent); -void TypeCloner::operator()(const MetatableType& t) -{ - TypeId result = dest.addType(MetatableType{}); - MetatableType* mtv = getMutable(result); - seenTypes[typeId] = result; + if (t->metatable) + t->metatable = shallowClone(*t->metatable); - mtv->table = clone(t.table, dest, cloneState); - mtv->metatable = clone(t.metatable, dest, cloneState); -} + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } + } -void TypeCloner::operator()(const ClassType& t) -{ - TypeId result = dest.addType(ClassType{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); - ClassType* ctv = getMutable(result); + void cloneChildren(AnyType* t) + { + // noop. + } - seenTypes[typeId] = result; + void cloneChildren(UnionType* t) + { + for (TypeId& ty : t->options) + ty = shallowClone(ty); + } - for (const auto& [name, prop] : t.props) - ctv->props[name] = clone(prop, dest, cloneState); + void cloneChildren(IntersectionType* t) + { + for (TypeId& ty : t->parts) + ty = shallowClone(ty); + } - if (t.parent) - ctv->parent = clone(*t.parent, dest, cloneState); + void cloneChildren(LazyType* t) + { + if (auto unwrapped = t->unwrapped.load()) + t->unwrapped.store(shallowClone(unwrapped)); + } - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, cloneState); -} + void cloneChildren(UnknownType* t) + { + // noop. + } -void TypeCloner::operator()(const AnyType& t) -{ - defaultClone(t); -} + void cloneChildren(NeverType* t) + { + // noop. + } -void TypeCloner::operator()(const UnionType& t) -{ - std::vector options; - options.reserve(t.options.size()); + void cloneChildren(NegationType* t) + { + t->ty = shallowClone(t->ty); + } - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, cloneState)); + void cloneChildren(TypeFunctionInstanceType* t) + { + for (TypeId& ty : t->typeArguments) + ty = shallowClone(ty); - TypeId result = dest.addType(UnionType{std::move(options)}); - seenTypes[typeId] = result; -} + for (TypePackId& tp : t->packArguments) + tp = shallowClone(tp); + } -void TypeCloner::operator()(const IntersectionType& t) -{ - TypeId result = dest.addType(IntersectionType{}); - seenTypes[typeId] = result; + void cloneChildren(FreeTypePack* t) + { + // TODO: clone lower and upper bounds. + // TODO: In the new solver, we should ice. + } - IntersectionType* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + void cloneChildren(GenericTypePack* t) + { + // TOOD: clone upper bounds. + } - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, cloneState)); -} + void cloneChildren(BlockedTypePack* t) + { + // TODO: In the new solver, we should ice. + } -void TypeCloner::operator()(const LazyType& t) -{ - if (TypeId unwrapped = t.unwrapped.load()) + void cloneChildren(BoundTypePack* t) { - seenTypes[typeId] = clone(unwrapped, dest, cloneState); + t->boundTo = shallowClone(t->boundTo); } - else + + void cloneChildren(VariadicTypePack* t) { - defaultClone(t); + t->ty = shallowClone(t->ty); } -} -void TypeCloner::operator()(const UnknownType& t) -{ - defaultClone(t); -} + void cloneChildren(TypePack* t) + { + for (TypeId& ty : t->head) + ty = shallowClone(ty); -void TypeCloner::operator()(const NeverType& t) -{ - defaultClone(t); -} + if (t->tail) + t->tail = shallowClone(*t->tail); + } -void TypeCloner::operator()(const NegationType& t) -{ - TypeId result = dest.addType(AnyType{}); - seenTypes[typeId] = result; + void cloneChildren(TypeFunctionInstanceTypePack* t) + { + for (TypeId& ty : t->typeArguments) + ty = shallowClone(ty); - TypeId ty = clone(t.ty, dest, cloneState); - asMutable(result)->ty = NegationType{ty}; -} + for (TypePackId& tp : t->packArguments) + tp = shallowClone(tp); + } +}; -} // anonymous namespace +} // namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = cloneState.seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. - } - - return res; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(tp); } TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) @@ -414,54 +462,35 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = cloneState.seenTypes[typeId]; - - if (res == nullptr) - { - TypeCloner cloner{dest, typeId, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - { - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - } - - return res; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeFun result; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + + TypeFun copy = typeFun; - for (auto param : typeFun.typeParams) + for (auto& param : copy.typeParams) { - TypeId ty = clone(param.ty, dest, cloneState); - std::optional defaultValue; + param.ty = cloner.clone(param.ty); if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typeParams.push_back({ty, defaultValue}); + param.defaultValue = cloner.clone(*param.defaultValue); } - for (auto param : typeFun.typePackParams) + for (auto& param : copy.typePackParams) { - TypePackId tp = clone(param.tp, dest, cloneState); - std::optional defaultValue; + param.tp = cloner.clone(param.tp); if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); + param.defaultValue = cloner.clone(*param.defaultValue); } - result.type = clone(typeFun.type, dest, cloneState); + copy.type = cloner.clone(copy.type); - return result; + return copy; } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Constraint.cpp b/third_party/luau/Analysis/src/Constraint.cpp index 3a6417dc..a62879fa 100644 --- a/third_party/luau/Analysis/src/Constraint.cpp +++ b/third_party/luau/Analysis/src/Constraint.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Constraint.h" +#include "Luau/VisitType.h" namespace Luau { @@ -12,4 +13,127 @@ Constraint::Constraint(NotNull scope, const Location& location, Constrain { } +struct ReferenceCountInitializer : TypeOnceVisitor +{ + + DenseHashSet* result; + + ReferenceCountInitializer(DenseHashSet* result) + : result(result) + { + } + + bool visit(TypeId ty, const FreeType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + // ClassTypes never contain free types. + return false; + } +}; + +bool isReferenceCountedType(const TypeId typ) +{ + // n.b. this should match whatever `ReferenceCountInitializer` includes. + return get(typ) || get(typ) || get(typ); +} + +DenseHashSet Constraint::getMaybeMutatedFreeTypes() const +{ + // For the purpose of this function and reference counting in general, we are only considering + // mutations that affect the _bounds_ of the free type, and not something that may bind the free + // type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no + // contribution to the output set here. + + DenseHashSet types{{}}; + ReferenceCountInitializer rci{&types}; + + if (auto ec = get(*this)) + { + rci.traverse(ec->resultType); + // `EqualityConstraints` should not mutate `assignmentType`. + } + else if (auto sc = get(*this)) + { + rci.traverse(sc->subType); + rci.traverse(sc->superType); + } + else if (auto psc = get(*this)) + { + rci.traverse(psc->subPack); + rci.traverse(psc->superPack); + } + else if (auto itc = get(*this)) + { + for (TypeId ty : itc->variables) + rci.traverse(ty); + // `IterableConstraints` should not mutate `iterator`. + } + else if (auto nc = get(*this)) + { + rci.traverse(nc->namedType); + } + else if (auto taec = get(*this)) + { + rci.traverse(taec->target); + } + else if (auto fchc = get(*this)) + { + rci.traverse(fchc->argsPack); + } + else if (auto ptc = get(*this)) + { + rci.traverse(ptc->freeType); + } + else if (auto hpc = get(*this)) + { + rci.traverse(hpc->resultType); + // `HasPropConstraints` should not mutate `subjectType`. + } + else if (auto hic = get(*this)) + { + rci.traverse(hic->resultType); + // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. + } + else if (auto apc = get(*this)) + { + rci.traverse(apc->lhsType); + rci.traverse(apc->rhsType); + } + else if (auto aic = get(*this)) + { + rci.traverse(aic->lhsType); + rci.traverse(aic->indexType); + rci.traverse(aic->rhsType); + } + else if (auto uc = get(*this)) + { + for (TypeId ty : uc->resultPack) + rci.traverse(ty); + // `UnpackConstraint` should not mutate `sourcePack`. + } + else if (auto rpc = get(*this)) + { + rci.traverse(rpc->tp); + } + + return types; +} + } // namespace Luau diff --git a/third_party/luau/Analysis/src/ConstraintGenerator.cpp b/third_party/luau/Analysis/src/ConstraintGenerator.cpp new file mode 100644 index 00000000..036f4313 --- /dev/null +++ b/third_party/luau/Analysis/src/ConstraintGenerator.cpp @@ -0,0 +1,3528 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintGenerator.h" + +#include "Luau/Ast.h" +#include "Luau/Def.h" +#include "Luau/Common.h" +#include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" +#include "Luau/DcrLogger.h" +#include "Luau/DenseHash.h" +#include "Luau/ModuleResolver.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Refinement.h" +#include "Luau/Scope.h" +#include "Luau/Simplify.h" +#include "Luau/StringUtils.h" +#include "Luau/TableLiteralInference.h" +#include "Luau/TimeTrace.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" +#include "Luau/VisitType.h" + +#include +#include + +LUAU_FASTINT(LuauCheckRecursionLimit); +LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauDeclarationExtraPropData); + +namespace Luau +{ + +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp + +static std::optional matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +static bool matchSetmetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +static bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGenerator* cg) +{ + return Checkpoint{cg->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGenerator* cg, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cg->constraints[i]); +} + +struct HasFreeType : TypeOnceVisitor +{ + bool result = false; + + HasFreeType() {} + + bool visit(TypeId ty) override + { + if (result || ty->persistent) + return false; + return true; + } + + bool visit(TypePackId tp) override + { + if (result) + return false; + return true; + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + result = true; + return false; + } + + bool visit(TypePackId ty, const FreeTypePack&) override + { + result = true; + return false; + } +}; + +bool hasFreeType(TypeId ty) +{ + HasFreeType hft{}; + hft.traverse(ty); + return hft.result; +} + +} // namespace + +ConstraintGenerator::ConstraintGenerator( + ModulePtr module, + NotNull normalizer, + NotNull moduleResolver, + NotNull builtinTypes, + NotNull ice, + const ScopePtr& globalScope, + std::function prepareModuleScope, + DcrLogger* logger, + NotNull dfg, + std::vector requireCycles +) + : module(module) + , builtinTypes(builtinTypes) + , arena(normalizer->arena) + , rootScope(nullptr) + , dfg(dfg) + , normalizer(normalizer) + , moduleResolver(moduleResolver) + , ice(ice) + , globalScope(globalScope) + , prepareModuleScope(std::move(prepareModuleScope)) + , requireCycles(std::move(requireCycles)) + , logger(logger) +{ + LUAU_ASSERT(module); +} + +void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) +{ + LUAU_TIMETRACE_SCOPE("ConstraintGenerator::visitModuleRoot", "Typechecking"); + + LUAU_ASSERT(scopes.empty()); + LUAU_ASSERT(rootScope == nullptr); + ScopePtr scope = std::make_shared(globalScope); + rootScope = scope.get(); + scopes.emplace_back(block->location, scope); + rootScope->location = block->location; + module->astScopes[block] = NotNull{scope.get()}; + + rootScope->returnType = freshTypePack(scope); + + TypeId moduleFnTy = arena->addType(FunctionType{TypeLevel{}, rootScope, builtinTypes->anyTypePack, rootScope->returnType}); + interiorTypes.emplace_back(); + + prepopulateGlobalScope(scope, block); + + Checkpoint start = checkpoint(this); + + ControlFlow cf = visitBlockWithoutChildScope(scope, block); + if (cf == ControlFlow::None) + addConstraint(scope, block->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, rootScope->returnType}); + + Checkpoint end = checkpoint(this); + + TypeId result = arena->addType(BlockedType{}); + NotNull genConstraint = + addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())}); + getMutable(result)->setOwner(genConstraint); + forEachConstraint( + start, + end, + this, + [genConstraint](const ConstraintPtr& c) + { + genConstraint->dependencies.push_back(NotNull{c.get()}); + } + ); + + interiorTypes.pop_back(); + + fillInInferredBindings(scope, block); + + if (logger) + logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + { + if (d == ty) + continue; + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } +} + +TypeId ConstraintGenerator::freshType(const ScopePtr& scope) +{ + return Luau::freshType(arena, builtinTypes, scope.get()); +} + +TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope) +{ + FreeTypePack f{scope.get()}; + return arena->addTypePack(TypePackVar{std::move(f)}); +} + +TypePackId ConstraintGenerator::addTypePack(std::vector head, std::optional tail) +{ + if (head.empty()) + { + if (tail) + return *tail; + else + return builtinTypes->emptyTypePack; + } + else + return arena->addTypePack(TypePack{std::move(head), tail}); +} + +ScopePtr ConstraintGenerator::childScope(AstNode* node, const ScopePtr& parent) +{ + auto scope = std::make_shared(parent); + scopes.emplace_back(node->location, scope); + scope->location = node->location; + + scope->returnType = parent->returnType; + scope->varargPack = parent->varargPack; + + parent->children.push_back(NotNull{scope.get()}); + module->astScopes[node] = scope.get(); + + return scope; +} + +std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Location location, DefId def, bool prototype) +{ + if (get(def)) + return scope->lookup(def); + if (auto phi = get(def)) + { + if (auto found = scope->lookup(def)) + return *found; + else if (!prototype && phi->operands.size() == 1) + return lookup(scope, location, phi->operands.at(0), prototype); + else if (!prototype) + return std::nullopt; + + TypeId res = builtinTypes->neverType; + + for (DefId operand : phi->operands) + { + // `scope->lookup(operand)` may return nothing because we only bind a type to that operand + // once we've seen that particular `DefId`. In this case, we need to prototype those types + // and use those at a later time. + std::optional ty = lookup(scope, location, operand, /*prototype*/ false); + if (!ty) + { + ty = arena->addType(BlockedType{}); + localTypes.try_insert(*ty, {}); + rootScope->lvalueTypes[operand] = *ty; + } + + res = makeUnion(scope, location, res, *ty); + } + + scope->lvalueTypes[def] = res; + return res; + } + else + ice->ice("ConstraintGenerator::lookup is inexhaustive?"); +} + +NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +{ + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; +} + +NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, std::unique_ptr c) +{ + return NotNull{constraints.emplace_back(std::move(c)).get()}; +} + +void ConstraintGenerator::unionRefinements( + const ScopePtr& scope, + Location location, + const RefinementContext& lhs, + const RefinementContext& rhs, + RefinementContext& dest, + std::vector* constraints +) +{ + const auto intersect = [&](const std::vector& types) + { + if (1 == types.size()) + return types[0]; + else if (2 == types.size()) + return makeIntersect(scope, location, types[0], types[1]); + + return arena->addType(IntersectionType{types}); + }; + + for (auto& [def, partition] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + LUAU_ASSERT(!partition.discriminantTypes.empty()); + LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); + + TypeId leftDiscriminantTy = partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : intersect(partition.discriminantTypes); + + TypeId rightDiscriminantTy = + rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); + + dest.insert(def, {}); + dest.get(def)->discriminantTypes.push_back(makeUnion(scope, location, leftDiscriminantTy, rightDiscriminantTy)); + dest.get(def)->shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; + } +} + +void ConstraintGenerator::computeRefinement( + const ScopePtr& scope, + Location location, + RefinementId refinement, + RefinementContext* refis, + bool sense, + bool eq, + std::vector* constraints +) +{ + if (!refinement) + return; + else if (auto variadic = get(refinement)) + { + for (RefinementId refi : variadic->refinements) + computeRefinement(scope, location, refi, refis, sense, eq, constraints); + } + else if (auto negation = get(refinement)) + return computeRefinement(scope, location, negation->refinement, refis, !sense, eq, constraints); + else if (auto conjunction = get(refinement)) + { + RefinementContext lhsRefis; + RefinementContext rhsRefis; + + computeRefinement(scope, location, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); + + if (!sense) + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); + } + else if (auto disjunction = get(refinement)) + { + RefinementContext lhsRefis; + RefinementContext rhsRefis; + + computeRefinement(scope, location, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); + + if (sense) + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); + } + else if (auto equivalence = get(refinement)) + { + computeRefinement(scope, location, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->rhs, refis, sense, true, constraints); + } + else if (auto proposition = get(refinement)) + { + TypeId discriminantTy = proposition->discriminantTy; + + // if we have a negative sense, then we need to negate the discriminant + if (!sense) + discriminantTy = arena->addType(NegationType{discriminantTy}); + + if (eq) + discriminantTy = createTypeFunctionInstance(builtinTypeFunctions().singletonFunc, {discriminantTy}, {}, scope, location); + + for (const RefinementKey* key = proposition->key; key; key = key->parent) + { + refis->insert(key->def, {}); + refis->get(key->def)->discriminantTypes.push_back(discriminantTy); + + // Reached leaf node + if (!key->propName) + break; + + TypeId nextDiscriminantTy = arena->addType(TableType{}); + NotNull table{getMutable(nextDiscriminantTy)}; + // When we fully support read-write properties (i.e. when we allow properties with + // completely disparate read and write types), then the following property can be + // set to read-only since refinements only tell us about what we read. This cannot + // be allowed yet though because it causes read and write types to diverge. + table->props[*key->propName] = Property::rw(discriminantTy); + table->scope = scope.get(); + table->state = TableState::Sealed; + + discriminantTy = nextDiscriminantTy; + } + + // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. + LUAU_ASSERT(refis->get(proposition->key->def)); + refis->get(proposition->key->def)->shouldAppendNilType = (sense || !eq) && containsSubscriptedDefinition(proposition->key->def); + } +} + +namespace +{ + +/* + * Constraint generation may be called upon to simplify an intersection or union + * of types that are not sufficiently solved yet. We use + * FindSimplificationBlockers to recognize these types and defer the + * simplification until constraint solution. + */ +struct FindSimplificationBlockers : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId) override + { + return !found; + } + + bool visit(TypeId, const BlockedType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const FreeType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + found = true; + return false; + } + + // We do not need to know anything at all about a function's argument or + // return types in order to simplify it in an intersection or union. + bool visit(TypeId, const FunctionType&) override + { + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } +}; + +bool mustDeferIntersection(TypeId ty) +{ + FindSimplificationBlockers bts; + bts.traverse(ty); + return bts.found; +} +} // namespace + +void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) +{ + if (!refinement) + return; + + RefinementContext refinements; + std::vector constraints; + computeRefinement(scope, location, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); + + for (auto& [def, partition] : refinements) + { + if (std::optional defTy = lookup(scope, location, def)) + { + TypeId ty = *defTy; + if (partition.shouldAppendNilType) + ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); + + // Intersect ty with every discriminant type. If either type is not + // sufficiently solved, we queue the intersection up via an + // IntersectConstraint. + + for (TypeId dt : partition.discriminantTypes) + { + if (mustDeferIntersection(ty) || mustDeferIntersection(dt)) + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().refineFunc, {ty, dt}, {}, scope, location); + + ty = resultType; + } + else + { + switch (shouldSuppressErrors(normalizer, ty)) + { + case ErrorSuppression::DoNotSuppress: + { + if (!get(follow(ty))) + ty = makeIntersect(scope, location, ty, dt); + break; + } + case ErrorSuppression::Suppress: + ty = makeIntersect(scope, location, ty, dt); + ty = makeUnion(scope, location, ty, builtinTypes->errorType); + break; + case ErrorSuppression::NormalizationFailed: + reportError(location, NormalizationTooComplex{}); + ty = makeIntersect(scope, location, ty, dt); + break; + } + } + } + + scope->rvalueRefinements[def] = ty; + } + } + + for (auto& c : constraints) + addConstraint(scope, location, c); +} + +ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return ControlFlow::None; + } + + std::unordered_map aliasDefinitionLocations; + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) + { + auto it = aliasDefinitionLocations.find(alias->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); + continue; + } + + // A type alias might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (alias->name == kParseNameError || alias->name == "typeof") + continue; + + ScopePtr defnScope = childScope(alias, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) + { + initialFun.typeParams.push_back(gen); + } + + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) + { + initialFun.typePackParams.push_back(genPack); + } + + if (alias->exported) + scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); + else + scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + + astTypeAliasDefiningScopes[alias] = defnScope; + aliasDefinitionLocations[alias->name.value] = alias->location; + } + } + + std::optional firstControlFlow; + for (AstStat* stat : block->body) + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; + + if (auto s = stat->as()) + return visit(scope, s); + else if (auto i = stat->as()) + return visit(scope, i); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (stat->is()) + return ControlFlow::Breaks; + else if (stat->is()) + return ControlFlow::Continues; + else if (auto r = stat->as()) + return visit(scope, r); + else if (auto e = stat->as()) + { + checkPack(scope, e->expr); + + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + + return ControlFlow::None; + } + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto a = stat->as()) + return visit(scope, a); + else if (auto f = stat->as()) + return visit(scope, f); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else if (auto s = stat->as()) + return visit(scope, s); + else + { + LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); + return ControlFlow::None; + } +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* statLocal) +{ + std::vector annotatedTypes; + annotatedTypes.reserve(statLocal->vars.size); + bool hasAnnotation = false; + + std::vector> expectedTypes; + expectedTypes.reserve(statLocal->vars.size); + + std::vector assignees; + assignees.reserve(statLocal->vars.size); + + // Used to name the first value type, even if it's not placed in varTypes, + // for the purpose of synthetic name attribution. + std::optional firstValueType; + + for (AstLocal* local : statLocal->vars) + { + const Location location = local->location; + + TypeId assignee = arena->addType(BlockedType{}); + localTypes.try_insert(assignee, {}); + + assignees.push_back(assignee); + + if (!firstValueType) + firstValueType = assignee; + + if (local->annotation) + { + hasAnnotation = true; + TypeId annotationTy = resolveType(scope, local->annotation, /* inTypeArguments */ false); + annotatedTypes.push_back(annotationTy); + expectedTypes.push_back(annotationTy); + + scope->bindings[local] = Binding{annotationTy, location}; + } + else + { + // annotatedTypes must contain one type per local. If a particular + // local has no annotation at, assume the most conservative thing. + annotatedTypes.push_back(builtinTypes->unknownType); + + expectedTypes.push_back(std::nullopt); + scope->bindings[local] = Binding{builtinTypes->unknownType, location}; + + inferredBindings[local] = {scope.get(), location, {assignee}}; + } + + DefId def = dfg->getDef(local); + scope->lvalueTypes[def] = assignee; + } + + Checkpoint start = checkpoint(this); + TypePackId rvaluePack = checkPack(scope, statLocal->values, expectedTypes).tp; + Checkpoint end = checkpoint(this); + + if (hasAnnotation) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(annotatedTypes[i]); + } + + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); + } + else + { + std::vector valueTypes; + valueTypes.reserve(statLocal->vars.size); + + auto [head, tail] = flatten(rvaluePack); + + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + forEachConstraint( + start, + end, + this, + [&uc](const ConstraintPtr& runBefore) + { + uc->dependencies.push_back(NotNull{runBefore.get()}); + } + ); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(valueTypes[i]); + } + } + + if (statLocal->vars.size == 1 && statLocal->values.size == 1 && firstValueType && scope.get() == rootScope && !hasAnnotation) + { + AstLocal* var = statLocal->vars.data[0]; + AstExpr* value = statLocal->values.data[0]; + + if (value->is()) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + else if (const AstExprCall* call = value->as()) + { + if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + { + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } + } + } + + if (statLocal->values.size > 0) + { + // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. + for (size_t i = 0; i < statLocal->values.size && i < statLocal->vars.size; ++i) + { + const AstExprCall* call = statLocal->values.data[i]->as(); + if (!call) + continue; + + auto maybeRequire = matchRequire(*call); + if (!maybeRequire) + continue; + + AstExpr* require = *maybeRequire; + + auto moduleInfo = moduleResolver->resolveModuleInfo(module->name, *require); + if (!moduleInfo) + continue; + + ModulePtr module = moduleResolver->getModule(moduleInfo->name); + if (!module) + continue; + + const Name name{statLocal->vars.data[i]->name.value}; + scope->importedTypeBindings[name] = module->exportedTypeBindings; + scope->importedModules[name] = moduleInfo->name; + + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (path.empty() || path.front() != moduleInfo->name) + continue; + + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, builtinTypes->anyType}; + } + } + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_) +{ + TypeId annotationTy = builtinTypes->numberType; + if (for_->var->annotation) + annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); + + auto inferNumber = [&](AstExpr* expr) + { + if (!expr) + return; + + TypeId t = check(scope, expr).ty; + addConstraint(scope, expr->location, SubtypeConstraint{t, builtinTypes->numberType}); + }; + + inferNumber(for_->from); + inferNumber(for_->to); + inferNumber(for_->step); + + ScopePtr forScope = childScope(for_, scope); + forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; + + DefId def = dfg->getDef(for_->var); + forScope->lvalueTypes[def] = annotationTy; + forScope->rvalueRefinements[def] = annotationTy; + + visit(forScope, for_->body); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forIn) +{ + ScopePtr loopScope = childScope(forIn, scope); + TypePackId iterator = checkPack(scope, forIn->values).tp; + + std::vector variableTypes; + variableTypes.reserve(forIn->vars.size); + + for (AstLocal* var : forIn->vars) + { + TypeId assignee = arena->addType(BlockedType{}); + variableTypes.push_back(assignee); + + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].insert(assignee); + + if (var->annotation) + { + TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); + loopScope->bindings[var] = Binding{annotationTy, var->location}; + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); + } + else + loopScope->bindings[var] = Binding{loopVar, var->location}; + + DefId def = dfg->getDef(var); + loopScope->lvalueTypes[def] = loopVar; + } + + auto iterable = addConstraint( + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes} + ); + + for (TypeId var : variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } + + Checkpoint start = checkpoint(this); + visit(loopScope, forIn->body); + Checkpoint end = checkpoint(this); + + // This iter constraint must dispatch first. + forEachConstraint( + start, + end, + this, + [&iterable](const ConstraintPtr& runLater) + { + runLater->dependencies.push_back(iterable); + } + ); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatWhile* while_) +{ + RefinementId refinement = check(scope, while_->condition).refinement; + + ScopePtr whileScope = childScope(while_, scope); + applyRefinements(whileScope, while_->condition->location, refinement); + + visit(whileScope, while_->body); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatRepeat* repeat) +{ + ScopePtr repeatScope = childScope(repeat, scope); + + visitBlockWithoutChildScope(repeatScope, repeat->body); + + check(repeatScope, repeat->condition); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFunction* function) +{ + // Local + // Global + // Dotted path + // Self? + + TypeId functionType = nullptr; + auto ty = scope->lookup(function->name); + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. + + functionType = arena->addType(BlockedType{}); + scope->bindings[function->name] = Binding{functionType, function->name->location}; + + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); + sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->name->location}; + + bool sigFullyDefined = !hasFreeType(sig.signature); + if (sigFullyDefined) + emplaceType(asMutable(functionType), sig.signature); + + DefId def = dfg->getDef(function->name); + scope->lvalueTypes[def] = functionType; + scope->rvalueRefinements[def] = functionType; + sig.bodyScope->lvalueTypes[def] = sig.signature; + sig.bodyScope->rvalueRefinements[def] = sig.signature; + + Checkpoint start = checkpoint(this); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + if (!sigFullyDefined) + { + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); + + Constraint* previous = nullptr; + forEachConstraint( + start, + end, + this, + [&c, &previous](const ConstraintPtr& constraint) + { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + + getMutable(functionType)->setOwner(addConstraint(scope, std::move(c))); + module->astTypes[function->func] = functionType; + } + else + module->astTypes[function->func] = sig.signature; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + Checkpoint start = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); + bool sigFullyDefined = !hasFreeType(sig.signature); + + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + TypeId generalizedType = arena->addType(BlockedType{}); + if (sigFullyDefined) + emplaceType(asMutable(generalizedType), sig.signature); + else + { + const ScopePtr& constraintScope = sig.signatureScope ? sig.signatureScope : sig.bodyScope; + + NotNull c = addConstraint(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); + getMutable(generalizedType)->setOwner(c); + + Constraint* previous = nullptr; + forEachConstraint( + start, + end, + this, + [&c, &previous](const ConstraintPtr& constraint) + { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + } + + DefId def = dfg->getDef(function->name); + std::optional existingFunctionTy = follow(lookup(scope, function->name->location, def)); + + if (AstExprLocal* localName = function->name->as()) + { + visitLValue(scope, localName, generalizedType); + + scope->bindings[localName->local] = Binding{sig.signature, localName->location}; + scope->lvalueTypes[def] = sig.signature; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + if (!existingFunctionTy) + ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); + + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(*existingFunctionTy); bt && nullptr == bt->getOwner()) + emplaceType(asMutable(*existingFunctionTy), generalizedType); + + scope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + scope->lvalueTypes[def] = sig.signature; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + visitLValue(scope, indexName, generalizedType); + } + else if (AstExprError* err = function->name->as()) + { + generalizedType = builtinTypes->errorRecoveryType(); + } + + if (generalizedType == nullptr) + ice->ice("generalizedType == nullptr", function->location); + + scope->rvalueRefinements[def] = generalizedType; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatReturn* ret) +{ + // At this point, the only way scope->returnType should have anything + // interesting in it is if the function has an explicit return annotation. + // If this is the case, then we can expect that the return expression + // conforms to that. + std::vector> expectedTypes; + for (TypeId ty : scope->returnType) + expectedTypes.push_back(ty); + + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; + addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType, /*returns*/ true}); + + return ControlFlow::Returns; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatBlock* block) +{ + ScopePtr innerScope = childScope(block, scope); + + ControlFlow flow = visitBlockWithoutChildScope(innerScope, block); + + // An AstStatBlock has linear control flow, i.e. one entry and one exit, so we can inherit + // all the changes to the environment occurred by the statements in that block. + scope->inheritRefinements(innerScope); + scope->inheritAssignments(innerScope); + + return flow; +} + +// TODO Clip? +static void bindFreeType(TypeId a, TypeId b) +{ + FreeType* af = getMutable(a); + FreeType* bf = getMutable(b); + + LUAU_ASSERT(af || bf); + + if (!bf) + emplaceType(asMutable(a), b); + else if (!af) + emplaceType(asMutable(b), a); + else if (subsumes(bf->scope, af->scope)) + emplaceType(asMutable(a), b); + else if (subsumes(af->scope, bf->scope)) + emplaceType(asMutable(b), a); +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* assign) +{ + TypePackId resultPack = checkPack(scope, assign->values).tp; + + std::vector valueTypes; + valueTypes.reserve(assign->vars.size); + + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) + { + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. + + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < assign->vars.size; ++i) + { + visitLValue(scope, assign->vars.data[i], valueTypes[i]); + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +{ + AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; + TypeId resultTy = check(scope, &binop).ty; + module->astCompoundAssignResultTypes[assign] = resultTy; + + TypeId lhsType = check(scope, assign->var).ty; + visitLValue(scope, assign->var, lhsType); + + follow(lhsType); + follow(resultTy); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifStatement) +{ + RefinementId refinement = [&]() + { + InConditionalContext flipper{&typeContext}; + return check(scope, ifStatement->condition, std::nullopt).refinement; + }(); + + ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, ifStatement->condition->location, refinement); + + ScopePtr elseScope = childScope(ifStatement->elsebody ? ifStatement->elsebody : ifStatement, scope); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); + + ControlFlow thencf = visit(thenScope, ifStatement->thenbody); + ControlFlow elsecf = ControlFlow::None; + if (ifStatement->elsebody) + elsecf = visit(elseScope, ifStatement->elsebody); + + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + scope->inheritRefinements(thenScope); + + if (thencf == ControlFlow::None) + scope->inheritAssignments(thenScope); + if (elsecf == ControlFlow::None) + scope->inheritAssignments(elseScope); + + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* alias) +{ + if (alias->name == kParseNameError) + return ControlFlow::None; + + if (alias->name == "typeof") + { + reportError(alias->location, GenericError{"Type aliases cannot be named typeof"}); + return ControlFlow::None; + } + + scope->typeAliasLocations[alias->name.value] = alias->location; + scope->typeAliasNameLocations[alias->name.value] = alias->nameLocation; + + ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); + + std::unordered_map* typeBindings; + if (alias->exported) + typeBindings = &scope->exportedTypeBindings; + else + typeBindings = &scope->privateTypeBindings; + + // These will be undefined if the alias was a duplicate definition, in which + // case we just skip over it. + auto bindingIt = typeBindings->find(alias->name.value); + if (bindingIt == typeBindings->end() || defnScope == nullptr) + return ControlFlow::None; + + TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false, /* replaceErrorWithFresh */ false); + + TypeId aliasTy = bindingIt->second.type; + LUAU_ASSERT(get(aliasTy)); + if (occursCheck(aliasTy, ty)) + { + emplaceType(asMutable(aliasTy), builtinTypes->anyType); + reportError(alias->nameLocation, OccursCheckFailed{}); + } + else + emplaceType(asMutable(aliasTy), ty); + + std::vector typeParams; + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) + typeParams.push_back(tyParam.second.ty); + + std::vector typePackParams; + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) + typePackParams.push_back(tpParam.second.tp); + + addConstraint( + scope, + alias->type->location, + NameConstraint{ + ty, + alias->name.value, + /*synthetic=*/false, + std::move(typeParams), + std::move(typePackParams), + } + ); + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) +{ + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +{ + LUAU_ASSERT(global->type); + + TypeId globalTy = resolveType(scope, global->type, /* inTypeArguments */ false); + Name globalName(global->name.value); + + module->declaredGlobals[globalName] = globalTy; + rootScope->bindings[global->name] = Binding{globalTy, global->location}; + + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = globalTy; + rootScope->rvalueRefinements[def] = globalTy; + + return ControlFlow::None; +} + +static bool isMetamethod(const Name& name) +{ + return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || + name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + name == "__idiv"; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) +{ + std::optional superTy = std::make_optional(builtinTypes->classType); + if (declaredClass->superName) + { + Name superName = Name(declaredClass->superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); + return ControlFlow::None; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError( + declaredClass->location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)} + ); + + return ControlFlow::None; + } + } + + Name className(declaredClass->name.value); + + TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name, declaredClass->location)); + ClassType* ctv = getMutable(classTy); + + TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); + TableType* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + if (declaredClass->indexer) + { + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(declaredClass->indexer->location); + } + else + { + ctv->indexer = TableIndexer{ + resolveType(scope, declaredClass->indexer->indexType, /* inTypeArguments */ false), + resolveType(scope, declaredClass->indexer->resultType, /* inTypeArguments */ false), + }; + } + } + + for (const AstDeclaredClassProp& prop : declaredClass->props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, prop.ty, /* inTypeArguments */ false); + + bool assignToMetatable = isMetamethod(propName); + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionType* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack({classTy}, ftv->argTypes); + + ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = module->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } + } + } + + TableType::Props& props = assignToMetatable ? metatable->props : ctv->props; + + if (props.count(propName) == 0) + { + if (FFlag::LuauDeclarationExtraPropData) + props[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + else + props[propName] = {propTy}; + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Luau::Property& prop = props[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + else + { + TypeId currentTy = props[propName].type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); + + props[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); + + props[propName] = {intersection}; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +{ + std::vector> generics = createGenerics(scope, global->generics); + std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); + + std::vector genericTys; + genericTys.reserve(generics.size()); + for (auto& [name, generic] : generics) + { + genericTys.push_back(generic.ty); + } + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + for (auto& [name, generic] : genericPacks) + { + genericTps.push_back(generic.tp); + } + + ScopePtr funScope = scope; + if (!generics.empty() || !genericPacks.empty()) + funScope = childScope(global, scope); + + TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); + TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = module->name; + defn.definitionLocation = global->location; + defn.varargLocation = global->vararg ? std::make_optional(global->varargLocation) : std::nullopt; + defn.originalNameLocation = global->nameLocation; + } + + TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack, defn}); + FunctionType* ftv = getMutable(fnType); + ftv->isCheckedFunction = global->isCheckedFunction(); + + ftv->argNames.reserve(global->paramNames.size); + for (const auto& el : global->paramNames) + ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); + + Name fnName(global->name.value); + + module->declaredGlobals[fnName] = fnType; + scope->bindings[global->name] = Binding{fnType, global->location}; + + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = fnType; + rootScope->rvalueRefinements[def] = fnType; + + return ControlFlow::None; +} + +ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatError* error) +{ + for (AstStat* stat : error->statements) + visit(scope, stat); + for (AstExpr* expr : error->expressions) + check(scope, expr); + + return ControlFlow::None; +} + +InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes) +{ + std::vector head; + std::optional tail; + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + { + std::optional expectedType; + if (i < expectedTypes.size()) + expectedType = expectedTypes[i]; + head.push_back(check(scope, expr, expectedType).ty); + } + else + { + std::vector> expectedTailTypes; + if (i < expectedTypes.size()) + expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); + tail = checkPack(scope, expr, expectedTailTypes).tp; + } + } + + return InferencePack{addTypePack(std::move(head), tail)}; +} + +InferencePack ConstraintGenerator::checkPack( + const ScopePtr& scope, + AstExpr* expr, + const std::vector>& expectedTypes, + bool generalize +) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return InferencePack{builtinTypes->errorRecoveryTypePack()}; + } + + InferencePack result; + + if (AstExprCall* call = expr->as()) + result = checkPack(scope, call); + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = InferencePack{*scope->varargPack}; + else + result = InferencePack{builtinTypes->errorRecoveryTypePack()}; + } + else + { + std::optional expectedType; + if (!expectedTypes.empty()) + expectedType = expectedTypes[0]; + TypeId t = check(scope, expr, expectedType, /*forceSingletons*/ false, generalize).ty; + result = InferencePack{arena->addTypePack({t})}; + } + + LUAU_ASSERT(result.tp); + module->astTypePacks[expr] = result.tp; + return result; +} + +InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* call) +{ + std::vector exprArgs; + + std::vector returnRefinements; + std::vector> discriminantTypes; + + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + exprArgs.push_back(indexExpr->expr); + + if (auto key = dfg->getRefinementKey(indexExpr->expr)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(key, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + for (AstExpr* arg : call->args) + { + exprArgs.push_back(arg); + + if (auto key = dfg->getRefinementKey(arg)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(key, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + Checkpoint funcBeginCheckpoint = checkpoint(this); + + TypeId fnType = check(scope, call->func).ty; + + Checkpoint funcEndCheckpoint = checkpoint(this); + + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); + + module->astOriginalCallTypes[call] = fnType; + + Checkpoint argBeginCheckpoint = checkpoint(this); + + std::vector args; + std::optional argTail; + std::vector argumentRefinements; + + for (size_t i = 0; i < exprArgs.size(); ++i) + { + AstExpr* arg = exprArgs[i]; + + if (i == 0 && call->self) + { + // The self type has already been computed as a side effect of + // computing fnType. If computing that did not cause us to exceed a + // recursion limit, we can fetch it from astTypes rather than + // recomputing it. + TypeId* selfTy = module->astTypes.find(exprArgs[0]); + if (selfTy) + args.push_back(*selfTy); + else + args.push_back(freshType(scope)); + } + else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) + { + auto [ty, refinement] = check(scope, arg, /*expectedType*/ std::nullopt, /*forceSingleton*/ false, /*generalize*/ false); + args.push_back(ty); + argumentRefinements.push_back(refinement); + } + else + { + auto [tp, refis] = checkPack(scope, arg, {}); + argTail = tp; + argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); + } + } + + Checkpoint argEndCheckpoint = checkpoint(this); + + if (matchSetmetatable(*call)) + { + TypePack argTailPack; + if (argTail && args.size() < 2) + argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); + + TypeId target = nullptr; + TypeId mt = nullptr; + + if (args.size() + argTailPack.head.size() == 2) + { + target = args.size() > 0 ? args[0] : argTailPack.head[0]; + mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + } + else + { + std::vector unpackedTypes; + if (args.size() > 0) + target = args[0]; + else + { + target = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(target); + } + + mt = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(mt); + + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); + getMutable(mt)->setOwner(c); + if (auto b = getMutable(target); b && b->getOwner() == nullptr) + b->setOwner(c); + } + + LUAU_ASSERT(target); + LUAU_ASSERT(mt); + + target = follow(target); + + AstExpr* targetExpr = call->args.data[0]; + + TypeId resultTy = nullptr; + + if (isTableUnion(target)) + { + const UnionType* targetUnion = get(target); + std::vector newParts; + + for (TypeId ty : targetUnion) + newParts.push_back(arena->addType(MetatableType{ty, mt})); + + resultTy = arena->addType(UnionType{std::move(newParts)}); + } + else + resultTy = arena->addType(MetatableType{target, mt}); + + if (AstExprLocal* targetLocal = targetExpr->as()) + { + scope->bindings[targetLocal->local].typeId = resultTy; + + DefId def = dfg->getDef(targetLocal); + scope->lvalueTypes[def] = resultTy; // TODO: typestates: track this as an assignment + scope->rvalueRefinements[def] = resultTy; // TODO: typestates: track this as an assignment + + // HACK: If we have a targetLocal, it has already been added to the + // inferredBindings table. We want to replace it so that we don't + // infer a weird union like tbl | { @metatable something, tbl } + if (InferredBinding* ib = inferredBindings.find(targetLocal->local)) + ib->types.erase(target); + + recordInferredBinding(targetLocal->local, resultTy); + } + + return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; + } + else + { + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); + + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = addTypePack(std::move(args), argTail); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); + + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ + + NotNull checkConstraint = addConstraint( + scope, + call->func->location, + FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} + ); + + forEachConstraint( + funcBeginCheckpoint, + funcEndCheckpoint, + this, + [checkConstraint](const ConstraintPtr& constraint) + { + checkConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + NotNull callConstraint = addConstraint( + scope, + call->func->location, + FunctionCallConstraint{ + fnType, + argPack, + rets, + call, + std::move(discriminantTypes), + &module->astOverloadResolvedTypes, + } + ); + + getMutable(rets)->owner = callConstraint.get(); + + callConstraint->dependencies.push_back(checkConstraint); + + forEachConstraint( + argBeginCheckpoint, + argEndCheckpoint, + this, + [checkConstraint, callConstraint](const ConstraintPtr& constraint) + { + constraint->dependencies.emplace_back(checkConstraint); + + callConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton, bool generalize) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return Inference{builtinTypes->errorRecoveryType()}; + } + + Inference result; + + if (auto group = expr->as()) + result = check(scope, group->expr, expectedType, forceSingleton); + else if (auto stringExpr = expr->as()) + result = check(scope, stringExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{builtinTypes->numberType}; + else if (auto boolExpr = expr->as()) + result = check(scope, boolExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{builtinTypes->nilType}; + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (auto call = expr->as()) + result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too + else if (auto a = expr->as()) + result = check(scope, a, expectedType, generalize); + else if (auto indexName = expr->as()) + result = check(scope, indexName); + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); + else if (auto table = expr->as()) + result = check(scope, table, expectedType); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary, expectedType); + else if (auto ifElse = expr->as()) + result = check(scope, ifElse, expectedType); + else if (auto typeAssert = expr->as()) + result = check(scope, typeAssert); + else if (auto interpString = expr->as()) + result = check(scope, interpString); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + for (AstExpr* subExpr : err->expressions) + check(scope, subExpr); + + result = Inference{builtinTypes->errorRecoveryType()}; + } + else + { + LUAU_ASSERT(0); + result = Inference{freshType(scope)}; + } + + LUAU_ASSERT(result.ty); + module->astTypes[expr] = result.ty; + if (expectedType) + module->astExpectedTypes[expr] = *expectedType; + return result; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) +{ + if (forceSingleton) + return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft.upperBound = builtinTypes->stringType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, string->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->stringType}); + return Inference{freeTy}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) +{ + const TypeId singletonType = boolExpr->value ? builtinTypes->trueType : builtinTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = singletonType; + ft.upperBound = builtinTypes->booleanType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->booleanType}); + return Inference{freeTy}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) +{ + const RefinementKey* key = dfg->getRefinementKey(local); + std::optional rvalueDef = dfg->getRValueDefForCompoundAssign(local); + LUAU_ASSERT(key || rvalueDef); + + std::optional maybeTy; + + // if we have a refinement key, we can look up its type. + if (key) + maybeTy = lookup(scope, local->location, key->def); + + // if the current def doesn't have a type, we might be doing a compound assignment + // and therefore might need to look at the rvalue def instead. + if (!maybeTy && rvalueDef) + maybeTy = lookup(scope, local->location, *rvalueDef); + + if (maybeTy) + { + TypeId ty = follow(*maybeTy); + + recordInferredBinding(local->local, ty); + + return Inference{ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + } + else + ice->ice("CG: AstExprLocal came before its declaration?"); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* global) +{ + const RefinementKey* key = dfg->getRefinementKey(global); + std::optional rvalueDef = dfg->getRValueDefForCompoundAssign(global); + LUAU_ASSERT(key || rvalueDef); + + // we'll use whichever of the two definitions we have here. + DefId def = key ? key->def : *rvalueDef; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + if (auto ty = lookup(scope, global->location, def, /*prototype=*/false)) + { + rootScope->lvalueTypes[def] = *ty; + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + } + else + return Inference{builtinTypes->errorRecoveryType()}; +} + +Inference ConstraintGenerator::checkIndexName( + const ScopePtr& scope, + const RefinementKey* key, + AstExpr* indexee, + const std::string& index, + Location indexLocation +) +{ + TypeId obj = check(scope, indexee).ty; + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. + + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. + if (!tt) + { + if (TypeIds* localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(*localDomain->begin()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)} + ); + getMutable(result)->setOwner(c); + } + + if (key) + { + if (auto ty = lookup(scope, indexLocation, key->def)) + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + + scope->rvalueRefinements[key->def] = result; + } + + if (key) + return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; + else + return Inference{result}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexName* indexName) +{ + const RefinementKey* key = dfg->getRefinementKey(indexName); + return checkIndexName(scope, key, indexName->expr, indexName->index.value, indexName->indexLocation); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +{ + if (auto constantString = indexExpr->index->as()) + { + const RefinementKey* key = dfg->getRefinementKey(indexExpr); + return checkIndexName(scope, key, indexExpr->expr, constantString->value.data, indexExpr->location); + } + + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + + TypeId result = arena->addType(BlockedType{}); + + const RefinementKey* key = dfg->getRefinementKey(indexExpr); + if (key) + { + if (auto ty = lookup(scope, indexExpr->location, key->def)) + return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; + + scope->rvalueRefinements[key->def] = result; + } + + auto c = addConstraint(scope, indexExpr->expr->location, HasIndexerConstraint{result, obj, indexType}); + getMutable(result)->setOwner(c); + + if (key) + return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; + else + return Inference{result}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize) +{ + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, func, expectedType); + + interiorTypes.push_back(std::vector{}); + checkFunctionBody(sig.bodyScope, func); + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedType{}); + NotNull gc = + addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())}); + getMutable(generalizedTy)->setOwner(gc); + interiorTypes.pop_back(); + + Constraint* previous = nullptr; + forEachConstraint( + startCheckpoint, + endCheckpoint, + this, + [gc, &previous](const ConstraintPtr& constraint) + { + gc->dependencies.emplace_back(constraint.get()); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + } + ); + + if (generalize && hasFreeType(sig.signature)) + { + return Inference{generalizedTy}; + } + else + { + return Inference{sig.signature}; + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) +{ + auto [operandType, refinement] = check(scope, unary->expr); + + switch (unary->op) + { + case AstExprUnary::Op::Not: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().notFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + case AstExprUnary::Op::Len: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().lenFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + case AstExprUnary::Op::Minus: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().unmFunc, {operandType}, {}, scope, unary->location); + return Inference{resultType, refinementArena.negation(refinement)}; + } + default: // msvc can't prove that this is exhaustive. + LUAU_UNREACHABLE(); + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); + + switch (binary->op) + { + case AstExprBinary::Op::Add: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().addFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Sub: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().subFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Mul: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().mulFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Div: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().divFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::FloorDiv: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().idivFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Pow: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().powFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Mod: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().modFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Concat: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().concatFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::And: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().andFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Or: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().orFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareLt: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareGe: + { + TypeId resultType = createTypeFunctionInstance( + builtinTypeFunctions().ltFunc, + {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` + {}, + scope, + binary->location + ); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareLe: + { + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().leFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareGt: + { + TypeId resultType = createTypeFunctionInstance( + builtinTypeFunctions().leFunc, + {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` + {}, + scope, + binary->location + ); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + { + DefId leftDef = dfg->getDef(binary->left); + DefId rightDef = dfg->getDef(binary->right); + bool leftSubscripted = containsSubscriptedDefinition(leftDef); + bool rightSubscripted = containsSubscriptedDefinition(rightDef); + + if (leftSubscripted && rightSubscripted) + { + // we cannot add nil in this case because then we will blindly accept comparisons that we should not. + } + else if (leftSubscripted) + leftType = makeUnion(scope, binary->location, leftType, builtinTypes->nilType); + else if (rightSubscripted) + rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); + + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().eqFunc, {leftType, rightType}, {}, scope, binary->location); + return Inference{resultType, std::move(refinement)}; + } + case AstExprBinary::Op::Op__Count: + ice->ice("Op__Count should never be generated in an AST."); + default: // msvc can't prove that this is exhaustive. + LUAU_UNREACHABLE(); + } +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +{ + RefinementId refinement = [&]() + { + InConditionalContext flipper{&typeContext}; + ScopePtr condScope = childScope(ifElse->condition, scope); + return check(condScope, ifElse->condition).refinement; + }(); + + ScopePtr thenScope = childScope(ifElse->trueExpr, scope); + applyRefinements(thenScope, ifElse->trueExpr->location, refinement); + TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; + + ScopePtr elseScope = childScope(ifElse->falseExpr, scope); + applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); + TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; + + return Inference{expectedType ? *expectedType : makeUnion(scope, ifElse->location, thenType, elseType)}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +{ + check(scope, typeAssert->expr, std::nullopt); + return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString* interpString) +{ + for (AstExpr* expr : interpString->expressions) + check(scope, expr); + + return Inference{builtinTypes->stringType}; +} + +std::tuple ConstraintGenerator::checkBinary( + const ScopePtr& scope, + AstExprBinary* binary, + std::optional expectedType +) +{ + if (binary->op == AstExprBinary::And) + { + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftRefinement); + auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; + } + else if (binary->op == AstExprBinary::Or) + { + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); + auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; + } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + const RefinementKey* key = dfg->getRefinementKey(typeguard->target); + if (!key) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = builtinTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = builtinTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = builtinTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = builtinTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = builtinTypes->booleanType; + else if (typeguard->type == "thread") + discriminantTy = builtinTypes->threadType; + else if (typeguard->type == "buffer") + discriminantTy = builtinTypes->bufferType; + else if (typeguard->type == "table") + discriminantTy = builtinTypes->tableType; + else if (typeguard->type == "function") + discriminantTy = builtinTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + discriminantTy = builtinTypes->classType; + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = builtinTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); ctv && ctv->parent == builtinTypes->classType) + discriminantTy = ty; + } + + RefinementId proposition = refinementArena.proposition(key, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, refinementArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + // We are checking a binary expression of the form a op b + // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too + TypeId leftType = check(scope, binary->left, {}, true).ty; + TypeId rightType = check(scope, binary->right, {}, true).ty; + + RefinementId leftRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->left), rightType); + RefinementId rightRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->right), leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftRefinement = refinementArena.negation(leftRefinement); + rightRefinement = refinementArena.negation(rightRefinement); + } + + return {leftType, rightType, refinementArena.equivalence(leftRefinement, rightRefinement)}; + } + else + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + return {leftType, rightType, nullptr}; + } +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType) +{ + if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + { + // Nothing? + } + else + ice->ice("Unexpected lvalue expression", expr->location); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType) +{ + std::optional annotatedTy = scope->lookup(local->local); + LUAU_ASSERT(annotatedTy); + + const DefId defId = dfg->getDef(local); + std::optional ty = scope->lookupUnrefinedType(defId); + + if (ty) + { + TypeIds* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->insert(rhsType); + } + else + { + ty = arena->addType(BlockedType{}); + localTypes[*ty].insert(rhsType); + + if (annotatedTy) + { + switch (shouldSuppressErrors(normalizer, *annotatedTy)) + { + case ErrorSuppression::DoNotSuppress: + break; + case ErrorSuppression::Suppress: + ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; + break; + case ErrorSuppression::NormalizationFailed: + reportError(local->local->annotation->location, NormalizationTooComplex{}); + break; + } + } + + scope->lvalueTypes[defId] = *ty; + } + + recordInferredBinding(local->local, *ty); + + if (annotatedTy) + addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); + + if (TypeIds* localDomain = localTypes.find(*ty)) + localDomain->insert(rhsType); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) +{ + std::optional annotatedTy = scope->lookup(Symbol{global->name}); + if (annotatedTy) + { + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = rhsType; + + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); + } +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* expr, TypeId rhsType) +{ + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + + bool incremented = recordPropertyAssignment(lhsTy); + + auto apc = + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented}); + getMutable(propTy)->setOwner(apc); +} + +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) +{ + if (auto constantString = expr->index->as()) + { + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. + std::string propName{constantString->value.data, constantString->value.size}; + + bool incremented = recordPropertyAssignment(lhsTy); + + auto apc = addConstraint( + scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented} + ); + getMutable(propTy)->setOwner(apc); + + return; + } + + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId indexTy = check(scope, expr->index).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + auto aic = addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + getMutable(propTy)->setOwner(aic); +} + +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +{ + TypeId ty = arena->addType(TableType{}); + TableType* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + ttv->state = TableState::Unsealed; + ttv->definitionModuleName = module->name; + ttv->scope = scope.get(); + + interiorTypes.back().push_back(ty); + + TypeIds indexKeyLowerBound; + TypeIds indexValueLowerBound; + + auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType) + { + indexKeyLowerBound.insert(follow(currentIndexType)); + indexValueLowerBound.insert(follow(currentResultType)); + }; + + TypeIds valuesLowerBound; + + for (const AstExprTable::Item& item : expr->items) + { + // Expected types are threaded through table literals separately via the + // function matchLiteralType. + + TypeId itemTy = check(scope, item.value).ty; + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key).ty; + + if (AstExprConstantString* key = item.key->as()) + { + std::string propName{key->value.data, key->value.size}; + ttv->props[propName] = {itemTy, /*deprecated*/ false, {}, key->location}; + } + else + { + createIndexer(item.key->location, keyTy, itemTy); + } + } + else + { + TypeId numberType = builtinTypes->numberType; + // FIXME? The location isn't quite right here. Not sure what is + // right. + createIndexer(item.value->location, numberType, itemTy); + } + } + + if (!indexKeyLowerBound.empty()) + { + LUAU_ASSERT(!indexValueLowerBound.empty()); + + TypeId indexKey = indexKeyLowerBound.size() == 1 + ? *indexKeyLowerBound.begin() + : arena->addType(UnionType{std::vector(indexKeyLowerBound.begin(), indexKeyLowerBound.end())}); + + TypeId indexValue = indexValueLowerBound.size() == 1 + ? *indexValueLowerBound.begin() + : arena->addType(UnionType{std::vector(indexValueLowerBound.begin(), indexValueLowerBound.end())}); + + ttv->indexer = TableIndexer{indexKey, indexValue}; + } + + if (expectedType) + { + Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; + std::vector toBlock; + matchLiteralType( + NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock + ); + } + + return Inference{ty}; +} + +ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignature( + const ScopePtr& parent, + AstExprFunction* fn, + std::optional expectedType, + std::optional originalName +) +{ + ScopePtr signatureScope = nullptr; + ScopePtr bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + if (expectedType) + expectedType = follow(*expectedType); + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + signatureScope = childScope(fn, parent); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureScope); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body, signatureScope); + + if (hasGenerics) + { + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + } + + // Local variable works around an odd gcc 11.3 warning: may be used uninitialized + std::optional none = std::nullopt; + expectedType = none; + } + + std::vector argTypes; + std::vector> argNames; + TypePack expectedArgPack; + + const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; + // This check ensures that expectedType is precisely optional and not any (since any is also an optional type) + if (expectedType && isOptional(*expectedType) && !get(*expectedType)) + { + if (auto ut = get(*expectedType)) + { + for (auto u : ut) + { + if (get(u) && !isNil(u)) + { + expectedFunction = get(u); + break; + } + } + } + } + + if (expectedFunction) + { + expectedArgPack = extendTypePack(*arena, builtinTypes, expectedFunction->argTypes, fn->args.size); + + genericTypes = expectedFunction->generics; + genericTypePacks = expectedFunction->genericPacks; + } + + if (fn->self) + { + TypeId selfType = freshType(signatureScope); + argTypes.push_back(selfType); + argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); + signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; + + DefId def = dfg->getDef(fn->self); + signatureScope->lvalueTypes[def] = selfType; + signatureScope->rvalueRefinements[def] = selfType; + } + + for (size_t i = 0; i < fn->args.size; ++i) + { + AstLocal* local = fn->args.data[i]; + + TypeId argTy = nullptr; + if (local->annotation) + argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + else + { + if (i < expectedArgPack.head.size()) + argTy = expectedArgPack.head[i]; + else + argTy = freshType(signatureScope); + } + + argTypes.push_back(argTy); + argNames.emplace_back(FunctionArgument{local->name.value, local->location}); + + signatureScope->bindings[local] = Binding{argTy, local->location}; + + DefId def = dfg->getDef(local); + signatureScope->lvalueTypes[def] = argTy; + signatureScope->rvalueRefinements[def] = argTy; + } + + TypePackId varargPack = nullptr; + + if (fn->vararg) + { + if (fn->varargAnnotation) + { + TypePackId annotationType = + resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); + varargPack = annotationType; + } + else if (expectedArgPack.tail && get(*expectedArgPack.tail)) + varargPack = *expectedArgPack.tail; + else + varargPack = builtinTypes->anyTypePack; + + signatureScope->varargPack = varargPack; + bodyScope->varargPack = varargPack; + } + else + { + varargPack = arena->addTypePack(VariadicTypePack{builtinTypes->anyType, /*hidden*/ true}); + // We do not add to signatureScope->varargPack because ... is not valid + // in functions without an explicit ellipsis. + + signatureScope->varargPack = std::nullopt; + bodyScope->varargPack = std::nullopt; + } + + LUAU_ASSERT(nullptr != varargPack); + + // If there is both an annotation and an expected type, the annotation wins. + // Type checking will sort out any discrepancies later. + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = + resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + // We bind the annotated type directly here so that, when we need to + // generate constraints for return types, we have a guarantee that we + // know the annotated return type already, if one was provided. + LUAU_ASSERT(get(returnType)); + emplaceTypePack(asMutable(returnType), annotatedRetType); + } + else if (expectedFunction) + { + emplaceTypePack(asMutable(returnType), expectedFunction->retTypes); + } + + // TODO: Preserve argument names in the function's type. + + FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + actualFunction.argNames = std::move(argNames); + actualFunction.hasSelf = fn->self != nullptr; + + FunctionDefinition defn; + defn.definitionModuleName = module->name; + defn.definitionLocation = fn->location; + defn.varargLocation = fn->vararg ? std::make_optional(fn->varargLocation) : std::nullopt; + defn.originalNameLocation = originalName.value_or(Location(fn->location.begin, 0)); + actualFunction.definition = defn; + + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + module->astTypes[fn] = actualFunctionType; + + if (expectedType && get(*expectedType)) + bindFreeType(*expectedType, actualFunctionType); + + return { + /* signature */ actualFunctionType, + /* signatureScope */ signatureScope, + /* bodyScope */ bodyScope, + }; +} + +void ConstraintGenerator::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) +{ + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + ControlFlow cf = visitBlockWithoutChildScope(scope, fn->body); + if (cf == ControlFlow::None) + addConstraint(scope, fn->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, scope->returnType}); +} + +TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + if (FFlag::DebugLuauMagicTypes) + { + if (ref->name == "_luau_ice") + ice->ice("_luau_ice encountered", ty->location); + else if (ref->name == "_luau_print") + { + if (ref->parameters.size != 1 || !ref->parameters.data[0].type) + { + reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); + } + else + return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); + } + } + + std::optional alias; + + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) + { + // If the alias is not generic, we don't need to set up a blocked + // type and an instantiation constraint. + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) + { + result = alias->type; + } + else + { + std::vector parameters; + std::vector packParameters; + + for (const AstTypeOrPack& p : ref->parameters) + { + // We do not enforce the ordering of types vs. type packs here; + // that is done in the parser. + if (p.type) + { + parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); + } + else if (p.typePack) + { + packParameters.push_back(resolveTypePack(scope, p.typePack, /* inTypeArguments */ true)); + } + else + { + // This indicates a parser bug: one of these two pointers + // should be set. + LUAU_ASSERT(false); + } + } + + result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); + + // If we're not in a type argument context, we need to create a constraint that expands this. + // The dispatching of the above constraint will queue up additional constraints for nested + // type function applications. + if (!inTypeArguments) + addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); + } + } + else + { + result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); + } + } + else if (auto tab = ty->as()) + { + TableType::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + // TODO: Recursion limit. + TypeId propTy = resolveType(scope, prop.type, inTypeArguments); + + Property& p = props[prop.name.value]; + p.typeLocation = prop.location; + + switch (prop.access) + { + case AstTableAccess::ReadWrite: + p.readTy = propTy; + p.writeTy = propTy; + break; + case AstTableAccess::Read: + p.readTy = propTy; + break; + case AstTableAccess::Write: + reportError(*prop.accessLocation, GenericError{"write keyword is illegal here"}); + p.readTy = propTy; + p.writeTy = propTy; + break; + default: + ice->ice("Unexpected property access " + std::to_string(int(prop.access))); + break; + } + } + + if (AstTableIndexer* astIndexer = tab->indexer) + { + if (astIndexer->access == AstTableAccess::Read) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::Write) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::ReadWrite) + { + // TODO: Recursion limit. + indexer = TableIndexer{ + resolveType(scope, astIndexer->indexType, inTypeArguments), + resolveType(scope, astIndexer->resultType, inTypeArguments), + }; + } + else + ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); + } + + result = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + } + else if (auto fn = ty->as()) + { + // TODO: Recursion limit. + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + ScopePtr signatureScope = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + signatureScope = childScope(fn, scope); + + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope; + } + + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); + + // TODO: FunctionType needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; + ftv.isCheckedFunction = fn->isCheckedFunction(); + + // This replicates the behavior of the appropriate FunctionType + // constructors. + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv.argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv.argNames.push_back(std::nullopt); + } + } + + result = arena->addType(std::move(ftv)); + } + else if (auto tof = ty->as()) + { + // TODO: Recursion limit. + TypeId exprType = check(scope, tof->expr).ty; + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part, inTypeArguments)); + } + + result = arena->addType(UnionType{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part, inTypeArguments)); + } + + result = arena->addType(IntersectionType{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + if (boolAnnotation->value) + result = builtinTypes->trueType; + else + result = builtinTypes->falseType; + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonType(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); + } + else + { + LUAU_ASSERT(0); + result = builtinTypes->errorRecoveryType(); + } + + module->astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + if (std::optional lookup = scope->lookupPack(gen->genericName.value)) + { + result = *lookup; + } + else + { + reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); + result = builtinTypes->errorRecoveryTypePack(); + } + } + else + { + LUAU_ASSERT(0); + result = builtinTypes->errorRecoveryTypePack(); + } + + module->astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); + } + + return addTypePack(std::move(head), tail); +} + +std::vector> ConstraintGenerator::createGenerics( + const ScopePtr& scope, + AstArray generics, + bool useCache, + bool addTypes +) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = nullptr; + + if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); + scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; + } + + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + + if (addTypes) + scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); + } + + return result; +} + +std::vector> ConstraintGenerator::createGenericPacks( + const ScopePtr& scope, + AstArray generics, + bool useCache, + bool addTypes +) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy; + + if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); + useCache && it != scope->parent->typeAliasTypePackParameters.end()) + genericTy = it->second; + else + { + genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); + scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; + } + + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + + if (addTypes) + scope->privateTypePackBindings[generic.name.value] = genericTy; + + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); + } + + return result; +} + +Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) +{ + const auto& [tp, refinements] = pack; + RefinementId refinement = nullptr; + if (!refinements.empty()) + refinement = refinements[0]; + + if (auto f = first(tp)) + return Inference{*f, refinement}; + + TypeId typeResult = arena->addType(BlockedType{}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); + getMutable(typeResult)->setOwner(c); + + return Inference{typeResult, refinement}; +} + +void ConstraintGenerator::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, module->name, std::move(err)}); + + if (logger) + logger->captureGenerationError(errors.back()); +} + +void ConstraintGenerator::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, module->name, CodeTooComplex{}}); + + if (logger) + logger->captureGenerationError(errors.back()); +} + +TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + if (get(follow(lhs))) + return rhs; + if (get(follow(rhs))) + return lhs; + + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().unionFunc, {lhs, rhs}, {}, scope, location); + + return resultType; +} + +TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().intersectFunc, {lhs, rhs}, {}, scope, location); + + return resultType; +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + const NotNull dfg; + + GlobalPrepopulator(NotNull globalScope, NotNull arena, NotNull dfg) + : globalScope(globalScope) + , arena(arena) + , dfg(dfg) + { + } + + bool visit(AstExprGlobal* global) override + { + if (auto ty = globalScope->lookup(global->name)) + { + DefId def = dfg->getDef(global); + globalScope->lvalueTypes[def] = *ty; + } + + return true; + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + { + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt}; + } + + return true; + } + + bool visit(AstType*) override + { + return true; + } + + bool visit(class AstTypePack* node) override + { + return true; + } +}; + +void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; + + if (prepareModuleScope) + prepareModuleScope(module->name, globalScope); + + program->visit(&gp); +} + +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + queue.push_back(mt->table); + else if (TypeIds* localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + +void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) +{ + if (InferredBinding* ib = inferredBindings.find(local)) + ib->types.insert(ty); +} + +void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block) +{ + for (const auto& [symbol, p] : inferredBindings) + { + const auto& [scope, location, types] = p; + + std::vector tys(types.begin(), types.end()); + if (tys.size() == 1) + scope->bindings[symbol] = Binding{tys.front(), location}; + else + { + TypeId ty = createTypeFunctionInstance(builtinTypeFunctions().unionFunc, std::move(tys), {}, globalScope, location); + + scope->bindings[symbol] = Binding{ty, location}; + } + } +} + +std::vector> ConstraintGenerator::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) +{ + std::vector funTys; + if (auto it = get(follow(fnType))) + { + for (TypeId intersectionComponent : it) + { + funTys.push_back(intersectionComponent); + } + } + + std::vector> expectedTypes; + // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, + // emit a list of arguments that the function could take at each position + // by unioning the arguments at each place + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) + { + if (index == expectedTypes.size()) + expectedTypes.push_back(ty); + else if (ty) + { + auto& el = expectedTypes[index]; + + if (!el) + el = ty; + else + { + std::vector result = reduceUnion({*el, ty}); + if (result.empty()) + el = builtinTypes->neverType; + else if (result.size() == 1) + el = result[0]; + else + el = module->internalTypes.addType(UnionType{std::move(result)}); + } + } + }; + + for (const TypeId overload : funTys) + { + if (const FunctionType* ftv = get(follow(overload))) + { + auto [argsHead, argsTail] = flatten(ftv->argTypes); + size_t start = ftv->hasSelf ? 1 : 0; + size_t index = 0; + for (size_t i = start; i < argsHead.size(); ++i) + assignOption(index++, argsHead[i]); + if (argsTail) + { + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) + { + while (index < funTys.size()) + assignOption(index++, vtp->ty); + } + } + } + } + + // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? + + return expectedTypes; +} + +TypeId ConstraintGenerator::createTypeFunctionInstance( + const TypeFunction& function, + std::vector typeArguments, + std::vector packArguments, + const ScopePtr& scope, + Location location +) +{ + TypeId result = arena->addTypeFunction(function, typeArguments, packArguments); + addConstraint(scope, location, ReduceConstraint{result}); + return result; +} + +std::vector> borrowConstraints(const std::vector& constraints) +{ + std::vector> result; + result.reserve(constraints.size()); + + for (const auto& c : constraints) + result.emplace_back(c.get()); + + return result; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/ConstraintGraphBuilder.cpp b/third_party/luau/Analysis/src/ConstraintGraphBuilder.cpp deleted file mode 100644 index 611f420a..00000000 --- a/third_party/luau/Analysis/src/ConstraintGraphBuilder.cpp +++ /dev/null @@ -1,2741 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/ConstraintGraphBuilder.h" - -#include "Luau/Ast.h" -#include "Luau/Breadcrumb.h" -#include "Luau/Common.h" -#include "Luau/Constraint.h" -#include "Luau/ControlFlow.h" -#include "Luau/DcrLogger.h" -#include "Luau/ModuleResolver.h" -#include "Luau/RecursionCounter.h" -#include "Luau/Refinement.h" -#include "Luau/Scope.h" -#include "Luau/TypeUtils.h" -#include "Luau/Type.h" - -#include - -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_FASTFLAG(LuauNegatedClassTypes); - -namespace Luau -{ - -bool doesCallError(const AstExprCall* call); // TypeInfer.cpp -const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp - -static std::optional matchRequire(const AstExprCall& call) -{ - const char* require = "require"; - - if (call.args.size != 1) - return std::nullopt; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != require) - return std::nullopt; - - if (call.args.size != 1) - return std::nullopt; - - return call.args.data[0]; -} - -static bool matchSetmetatable(const AstExprCall& call) -{ - const char* smt = "setmetatable"; - - if (call.args.size != 2) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != smt) - return false; - - return true; -} - -struct TypeGuard -{ - bool isTypeof; - AstExpr* target; - std::string type; -}; - -static std::optional matchTypeGuard(const AstExprBinary* binary) -{ - if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) - return std::nullopt; - - AstExpr* left = binary->left; - AstExpr* right = binary->right; - if (right->is()) - std::swap(left, right); - - if (!right->is()) - return std::nullopt; - - AstExprCall* call = left->as(); - AstExprConstantString* string = right->as(); - if (!call || !string) - return std::nullopt; - - AstExprGlobal* callee = call->func->as(); - if (!callee) - return std::nullopt; - - if (callee->name != "type" && callee->name != "typeof") - return std::nullopt; - - if (call->args.size != 1) - return std::nullopt; - - return TypeGuard{ - /*isTypeof*/ callee->name == "typeof", - /*target*/ call->args.data[0], - /*type*/ std::string(string->value.data, string->value.size), - }; -} - -static bool matchAssert(const AstExprCall& call) -{ - if (call.args.size < 1) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != "assert") - return false; - - return true; -} - -namespace -{ - -struct Checkpoint -{ - size_t offset; -}; - -Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) -{ - return Checkpoint{cgb->constraints.size()}; -} - -template -void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) -{ - for (size_t i = start.offset; i < end.offset; ++i) - f(cgb->constraints[i]); -} - -} // namespace - -ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, - std::function prepareModuleScope, DcrLogger* logger, NotNull dfg) - : module(module) - , builtinTypes(builtinTypes) - , arena(arena) - , rootScope(nullptr) - , dfg(dfg) - , moduleResolver(moduleResolver) - , ice(ice) - , globalScope(globalScope) - , prepareModuleScope(std::move(prepareModuleScope)) - , logger(logger) -{ - LUAU_ASSERT(module); -} - -TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) -{ - return arena->addType(FreeType{scope.get()}); -} - -TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) -{ - FreeTypePack f{scope.get()}; - return arena->addTypePack(TypePackVar{std::move(f)}); -} - -ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& parent) -{ - auto scope = std::make_shared(parent); - scopes.emplace_back(node->location, scope); - - scope->returnType = parent->returnType; - scope->varargPack = parent->varargPack; - - parent->children.push_back(NotNull{scope.get()}); - module->astScopes[node] = scope.get(); - - return scope; -} - -NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) -{ - return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; -} - -NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) -{ - return NotNull{constraints.emplace_back(std::move(c)).get()}; -} - -struct RefinementPartition -{ - // Types that we want to intersect against the type of the expression. - std::vector discriminantTypes; - - // Sometimes the type we're discriminating against is implicitly nil. - bool shouldAppendNilType = false; -}; - -using RefinementContext = std::unordered_map; - -static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) -{ - for (auto& [def, partition] : lhs) - { - auto rhsIt = rhs.find(def); - if (rhsIt == rhs.end()) - continue; - - LUAU_ASSERT(!partition.discriminantTypes.empty()); - LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - - TypeId leftDiscriminantTy = - partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); - - TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] - : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); - - dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); - dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; - } -} - -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, - std::vector* constraints) -{ - if (!refinement) - return; - else if (auto variadic = get(refinement)) - { - for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, arena, eq, constraints); - } - else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); - else if (auto conjunction = get(refinement)) - { - RefinementContext lhsRefis; - RefinementContext rhsRefis; - - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); - - if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); - } - else if (auto disjunction = get(refinement)) - { - RefinementContext lhsRefis; - RefinementContext rhsRefis; - - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); - - if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); - } - else if (auto equivalence = get(refinement)) - { - computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); - } - else if (auto proposition = get(refinement)) - { - TypeId discriminantTy = proposition->discriminantTy; - if (!sense && !eq) - discriminantTy = arena->addType(NegationType{proposition->discriminantTy}); - else if (eq) - { - discriminantTy = arena->addType(BlockedType{}); - constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); - } - - RefinementContext uncommittedRefis; - uncommittedRefis[proposition->breadcrumb->def].discriminantTypes.push_back(discriminantTy); - - // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. - if ((sense || !eq) && getMetadata(proposition->breadcrumb)) - uncommittedRefis[proposition->breadcrumb->def].shouldAppendNilType = true; - - for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous) - { - LUAU_ASSERT(get(current->def)); - - // If this current breadcrumb has no metadata, it's no-op for the purpose of building a discriminant type. - if (!current->metadata) - continue; - else if (auto field = getMetadata(current)) - { - TableType::Props props{{field->prop, Property{discriminantTy}}}; - discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); - uncommittedRefis[current->previous->def].discriminantTypes.push_back(discriminantTy); - } - } - - // And now it's time to commit it. - for (auto& [def, partition] : uncommittedRefis) - { - for (TypeId discriminantTy : partition.discriminantTypes) - (*refis)[def].discriminantTypes.push_back(discriminantTy); - - (*refis)[def].shouldAppendNilType |= partition.shouldAppendNilType; - } - } -} - -void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) -{ - if (!refinement) - return; - - RefinementContext refinements; - std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); - - for (auto& [def, partition] : refinements) - { - if (std::optional defTy = scope->lookup(def)) - { - TypeId ty = *defTy; - if (partition.shouldAppendNilType) - ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - - partition.discriminantTypes.push_back(ty); - scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); - } - } - - for (auto& c : constraints) - addConstraint(scope, location, c); -} - -void ConstraintGraphBuilder::visit(AstStatBlock* block) -{ - LUAU_ASSERT(scopes.empty()); - LUAU_ASSERT(rootScope == nullptr); - ScopePtr scope = std::make_shared(globalScope); - rootScope = scope.get(); - scopes.emplace_back(block->location, scope); - module->astScopes[block] = NotNull{scope.get()}; - - rootScope->returnType = freshTypePack(scope); - - prepopulateGlobalScope(scope, block); - - visitBlockWithoutChildScope(scope, block); - - if (logger) - logger->captureGenerationModule(module); -} - -ControlFlow ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(block->location); - return ControlFlow::None; - } - - std::unordered_map aliasDefinitionLocations; - - // In order to enable mutually-recursive type aliases, we need to - // populate the type bindings before we actually check any of the - // alias statements. - for (AstStat* stat : block->body) - { - if (auto alias = stat->as()) - { - if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value)) - { - auto it = aliasDefinitionLocations.find(alias->name.value); - LUAU_ASSERT(it != aliasDefinitionLocations.end()); - reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); - continue; - } - - ScopePtr defnScope = childScope(alias, scope); - - TypeId initialType = arena->addType(BlockedType{}); - TypeFun initialFun{initialType}; - - for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) - { - initialFun.typeParams.push_back(gen); - } - - for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) - { - initialFun.typePackParams.push_back(genPack); - } - - if (alias->exported) - scope->exportedTypeBindings[alias->name.value] = std::move(initialFun); - else - scope->privateTypeBindings[alias->name.value] = std::move(initialFun); - - astTypeAliasDefiningScopes[alias] = defnScope; - aliasDefinitionLocations[alias->name.value] = alias->location; - } - } - - std::optional firstControlFlow; - for (AstStat* stat : block->body) - { - ControlFlow cf = visit(scope, stat); - if (cf != ControlFlow::None && !firstControlFlow) - firstControlFlow = cf; - } - - return firstControlFlow.value_or(ControlFlow::None); -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) -{ - RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; - - if (auto s = stat->as()) - return visit(scope, s); - else if (auto i = stat->as()) - return visit(scope, i); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else if (stat->is() || stat->is()) - { - // Nothing - return ControlFlow::None; // TODO: ControlFlow::Break/Continue - } - else if (auto r = stat->as()) - return visit(scope, r); - else if (auto e = stat->as()) - { - checkPack(scope, e->expr); - - if (auto call = e->expr->as(); call && doesCallError(call)) - return ControlFlow::Throws; - - return ControlFlow::None; - } - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto a = stat->as()) - return visit(scope, a); - else if (auto a = stat->as()) - return visit(scope, a); - else if (auto f = stat->as()) - return visit(scope, f); - else if (auto f = stat->as()) - return visit(scope, f); - else if (auto a = stat->as()) - return visit(scope, a); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else if (auto s = stat->as()) - return visit(scope, s); - else - { - LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); - return ControlFlow::None; - } -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) -{ - std::vector varTypes; - varTypes.reserve(local->vars.size); - - // Used to name the first value type, even if it's not placed in varTypes, - // for the purpose of synthetic name attribution. - std::optional firstValueType; - - for (AstLocal* local : local->vars) - { - TypeId ty = nullptr; - - if (local->annotation) - ty = resolveType(scope, local->annotation, /* inTypeArguments */ false); - - varTypes.push_back(ty); - } - - for (size_t i = 0; i < local->values.size; ++i) - { - AstExpr* value = local->values.data[i]; - const bool hasAnnotation = i < local->vars.size && nullptr != local->vars.data[i]->annotation; - - if (value->is()) - { - // HACK: we leave nil-initialized things floating under the - // assumption that they will later be populated. - // - // See the test TypeInfer/infer_locals_with_nil_value. Better flow - // awareness should make this obsolete. - - if (!varTypes[i]) - varTypes[i] = freshType(scope); - } - // Only function calls and vararg expressions can produce packs. All - // other expressions produce exactly one value. - else if (i != local->values.size - 1 || (!value->is() && !value->is())) - { - std::optional expectedType; - if (hasAnnotation) - expectedType = varTypes.at(i); - - TypeId exprType = check(scope, value, ValueContext::RValue, expectedType).ty; - if (i < varTypes.size()) - { - if (varTypes[i]) - addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); - else - varTypes[i] = exprType; - } - - if (i == 0) - firstValueType = exprType; - } - else - { - std::vector> expectedTypes; - if (hasAnnotation) - expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); - - TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; - - if (i < local->vars.size) - { - TypePack packTypes = extendTypePack(*arena, builtinTypes, exprPack, varTypes.size() - i); - - // fill out missing values in varTypes with values from exprPack - for (size_t j = i; j < varTypes.size(); ++j) - { - if (!varTypes[j]) - { - if (j - i < packTypes.head.size()) - varTypes[j] = packTypes.head[j - i]; - else - varTypes[j] = arena->addType(BlockedType{}); - } - } - - std::vector tailValues{varTypes.begin() + i, varTypes.end()}; - TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack}); - } - } - } - - if (local->vars.size == 1 && local->values.size == 1 && firstValueType && scope.get() == rootScope) - { - AstLocal* var = local->vars.data[0]; - AstExpr* value = local->values.data[0]; - - if (value->is()) - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - else if (const AstExprCall* call = value->as()) - { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") - { - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - } - } - - for (size_t i = 0; i < local->vars.size; ++i) - { - AstLocal* l = local->vars.data[i]; - Location location = l->location; - - if (!varTypes[i]) - varTypes[i] = freshType(scope); - - scope->bindings[l] = Binding{varTypes[i], location}; - - // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but - // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. - BreadcrumbId bc = dfg->getBreadcrumb(l); - scope->dcrRefinements[bc->def] = varTypes[i]; - } - - if (local->values.size > 0) - { - // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. - for (size_t i = 0; i < local->values.size && i < local->vars.size; ++i) - { - const AstExprCall* call = local->values.data[i]->as(); - if (!call) - continue; - - if (auto maybeRequire = matchRequire(*call)) - { - AstExpr* require = *maybeRequire; - - if (auto moduleInfo = moduleResolver->resolveModuleInfo(module->name, *require)) - { - const Name name{local->vars.data[i]->name.value}; - - if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) - { - scope->importedTypeBindings[name] = module->exportedTypeBindings; - scope->importedModules[name] = moduleInfo->name; - } - } - } - } - } - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) -{ - TypeId annotationTy = builtinTypes->numberType; - if (for_->var->annotation) - annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); - - auto inferNumber = [&](AstExpr* expr) { - if (!expr) - return; - - TypeId t = check(scope, expr).ty; - addConstraint(scope, expr->location, SubtypeConstraint{t, builtinTypes->numberType}); - }; - - inferNumber(for_->from); - inferNumber(for_->to); - inferNumber(for_->step); - - ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(for_->var); - forScope->dcrRefinements[bc->def] = annotationTy; - - visit(forScope, for_->body); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) -{ - ScopePtr loopScope = childScope(forIn, scope); - - TypePackId iterator = checkPack(scope, forIn->values).tp; - - std::vector variableTypes; - variableTypes.reserve(forIn->vars.size); - for (AstLocal* var : forIn->vars) - { - TypeId ty = freshType(loopScope); - loopScope->bindings[var] = Binding{ty, var->location}; - variableTypes.push_back(ty); - - BreadcrumbId bc = dfg->getBreadcrumb(var); - loopScope->dcrRefinements[bc->def] = ty; - } - - // It is always ok to provide too few variables, so we give this pack a free tail. - TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); - - addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); - - visit(loopScope, forIn->body); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) -{ - check(scope, while_->condition); - - ScopePtr whileScope = childScope(while_, scope); - - visit(whileScope, while_->body); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) -{ - ScopePtr repeatScope = childScope(repeat, scope); - - visitBlockWithoutChildScope(repeatScope, repeat->body); - - check(repeatScope, repeat->condition); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) -{ - // Local - // Global - // Dotted path - // Self? - - TypeId functionType = nullptr; - auto ty = scope->lookup(function->name); - LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - - functionType = arena->addType(BlockedType{}); - scope->bindings[function->name] = Binding{functionType, function->name->location}; - - FunctionSignature sig = checkFunctionSignature(scope, function->func); - sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(function->name); - scope->dcrRefinements[bc->def] = functionType; - sig.bodyScope->dcrRefinements[bc->def] = sig.signature; - - Checkpoint start = checkpoint(this); - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { - c->dependencies.push_back(NotNull{constraint.get()}); - }); - - addConstraint(scope, std::move(c)); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) -{ - // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. - // With or without self - - TypeId generalizedType = arena->addType(BlockedType{}); - - Checkpoint start = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, function->func); - - std::unordered_set excludeList; - - if (AstExprLocal* localName = function->name->as()) - { - std::optional existingFunctionTy = scope->lookup(localName->local); - if (existingFunctionTy) - { - addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); - - Symbol sym{localName->local}; - scope->bindings[sym].typeId = generalizedType; - } - else - scope->bindings[localName->local] = Binding{generalizedType, localName->location}; - - sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; - } - else if (AstExprGlobal* globalName = function->name->as()) - { - std::optional existingFunctionTy = scope->lookup(globalName->name); - if (!existingFunctionTy) - ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); - - generalizedType = *existingFunctionTy; - - sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; - } - else if (AstExprIndexName* indexName = function->name->as()) - { - Checkpoint check1 = checkpoint(this); - TypeId lvalueType = checkLValue(scope, indexName); - Checkpoint check2 = checkpoint(this); - - forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { - excludeList.insert(c.get()); - }); - - // TODO figure out how to populate the location field of the table Property. - - if (get(lvalueType)) - asMutable(lvalueType)->ty.emplace(generalizedType); - else - addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); - } - else if (AstExprError* err = function->name->as()) - { - generalizedType = builtinTypes->errorRecoveryType(); - } - - if (generalizedType == nullptr) - ice->ice("generalizedType == nullptr", function->location); - - if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) - scope->dcrRefinements[bc->def] = generalizedType; - - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - - forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { - if (!excludeList.count(constraint.get())) - c->dependencies.push_back(NotNull{constraint.get()}); - }); - - addConstraint(scope, std::move(c)); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) -{ - // At this point, the only way scope->returnType should have anything - // interesting in it is if the function has an explicit return annotation. - // If this is the case, then we can expect that the return expression - // conforms to that. - std::vector> expectedTypes; - for (TypeId ty : scope->returnType) - expectedTypes.push_back(ty); - - TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; - addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); - - return ControlFlow::Returns; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) -{ - ScopePtr innerScope = childScope(block, scope); - - ControlFlow flow = visitBlockWithoutChildScope(innerScope, block); - scope->inheritRefinements(innerScope); - - return flow; -} - -static void bindFreeType(TypeId a, TypeId b) -{ - FreeType* af = getMutable(a); - FreeType* bf = getMutable(b); - - LUAU_ASSERT(af || bf); - - if (!bf) - asMutable(a)->ty.emplace(b); - else if (!af) - asMutable(b)->ty.emplace(a); - else if (subsumes(bf->scope, af->scope)) - asMutable(a)->ty.emplace(b); - else if (subsumes(af->scope, bf->scope)) - asMutable(b)->ty.emplace(a); -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) -{ - std::vector varTypes = checkLValues(scope, assign->vars); - - std::vector> expectedTypes; - expectedTypes.reserve(varTypes.size()); - - for (TypeId ty : varTypes) - { - ty = follow(ty); - if (get(ty)) - expectedTypes.push_back(std::nullopt); - else - expectedTypes.push_back(ty); - } - - TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp; - TypePackId varPack = arena->addTypePack({varTypes}); - - addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) -{ - // We need to tweak the BinaryConstraint that we emit, so we cannot use the - // strategy of falsifying an AST fragment. - TypeId varTy = checkLValue(scope, assign->var); - TypeId valueTy = check(scope, assign->value).ty; - - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, assign->location, - BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) -{ - RefinementId refinement = check(scope, ifStatement->condition, ValueContext::RValue, std::nullopt).refinement; - - ScopePtr thenScope = childScope(ifStatement->thenbody, scope); - applyRefinements(thenScope, ifStatement->condition->location, refinement); - - ScopePtr elseScope = childScope(ifStatement->elsebody ? ifStatement->elsebody : ifStatement, scope); - applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); - - ControlFlow thencf = visit(thenScope, ifStatement->thenbody); - ControlFlow elsecf = ControlFlow::None; - if (ifStatement->elsebody) - elsecf = visit(elseScope, ifStatement->elsebody); - - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) - scope->inheritRefinements(elseScope); - else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) - scope->inheritRefinements(thenScope); - - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) - return ControlFlow::Returns; - else - return ControlFlow::None; -} - -static bool occursCheck(TypeId needle, TypeId haystack) -{ - LUAU_ASSERT(get(needle)); - haystack = follow(haystack); - - auto checkHaystack = [needle](TypeId haystack) { - return occursCheck(needle, haystack); - }; - - if (needle == haystack) - return true; - else if (auto ut = get(haystack)) - return std::any_of(begin(ut), end(ut), checkHaystack); - else if (auto it = get(haystack)) - return std::any_of(begin(it), end(it), checkHaystack); - - return false; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) -{ - ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); - - std::unordered_map* typeBindings; - if (alias->exported) - typeBindings = &scope->exportedTypeBindings; - else - typeBindings = &scope->privateTypeBindings; - - // These will be undefined if the alias was a duplicate definition, in which - // case we just skip over it. - auto bindingIt = typeBindings->find(alias->name.value); - if (bindingIt == typeBindings->end() || defnScope == nullptr) - return ControlFlow::None; - - TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); - - TypeId aliasTy = bindingIt->second.type; - LUAU_ASSERT(get(aliasTy)); - - if (occursCheck(aliasTy, ty)) - { - asMutable(aliasTy)->ty.emplace(builtinTypes->anyType); - reportError(alias->nameLocation, OccursCheckFailed{}); - } - else - asMutable(aliasTy)->ty.emplace(ty); - - std::vector typeParams; - for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) - typeParams.push_back(tyParam.second.ty); - - std::vector typePackParams; - for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) - typePackParams.push_back(tpParam.second.tp); - - addConstraint(scope, alias->type->location, - NameConstraint{ - ty, - alias->name.value, - /*synthetic=*/false, - std::move(typeParams), - std::move(typePackParams), - }); - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) -{ - LUAU_ASSERT(global->type); - - TypeId globalTy = resolveType(scope, global->type, /* inTypeArguments */ false); - Name globalName(global->name.value); - - module->declaredGlobals[globalName] = globalTy; - rootScope->bindings[global->name] = Binding{globalTy, global->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(global); - rootScope->dcrRefinements[bc->def] = globalTy; - - return ControlFlow::None; -} - -static bool isMetamethod(const Name& name) -{ - return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || - name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) -{ - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; - if (declaredClass->superName) - { - Name superName = Name(declaredClass->superName->value); - std::optional lookupType = scope->lookupType(superName); - - if (!lookupType) - { - reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); - return ControlFlow::None; - } - - // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; - - if (!get(follow(*superTy))) - { - reportError(declaredClass->location, - GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); - - return ControlFlow::None; - } - } - - Name className(declaredClass->name.value); - - TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name)); - ClassType* ctv = getMutable(classTy); - - TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); - TableType* metatable = getMutable(metaTy); - - ctv->metatable = metaTy; - - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; - - for (const AstDeclaredClassProp& prop : declaredClass->props) - { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, prop.ty, /* inTypeArguments */ false); - - bool assignToMetatable = isMetamethod(propName); - - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) - { - if (FunctionType* ftv = getMutable(propTy)) - { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = arena->addTypePack(TypePack{{classTy}, ftv->argTypes}); - - ftv->hasSelf = true; - } - } - - if (ctv->props.count(propName) == 0) - { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; - } - else - { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionType* itv = get(currentTy)) - { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = arena->addType(IntersectionType{std::move(options)}); - - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); - - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; - } - else - { - reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); - } - } - } - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) -{ - std::vector> generics = createGenerics(scope, global->generics); - std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); - - std::vector genericTys; - genericTys.reserve(generics.size()); - for (auto& [name, generic] : generics) - { - genericTys.push_back(generic.ty); - } - - std::vector genericTps; - genericTps.reserve(genericPacks.size()); - for (auto& [name, generic] : genericPacks) - { - genericTps.push_back(generic.tp); - } - - ScopePtr funScope = scope; - if (!generics.empty() || !genericPacks.empty()) - funScope = childScope(global, scope); - - TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); - TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); - TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); - FunctionType* ftv = getMutable(fnType); - - ftv->argNames.reserve(global->paramNames.size); - for (const auto& el : global->paramNames) - ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); - - Name fnName(global->name.value); - - module->declaredGlobals[fnName] = fnType; - scope->bindings[global->name] = Binding{fnType, global->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(global); - rootScope->dcrRefinements[bc->def] = fnType; - - return ControlFlow::None; -} - -ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) -{ - for (AstStat* stat : error->statements) - visit(scope, stat); - for (AstExpr* expr : error->expressions) - check(scope, expr); - - return ControlFlow::None; -} - -InferencePack ConstraintGraphBuilder::checkPack( - const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes) -{ - std::vector head; - std::optional tail; - - for (size_t i = 0; i < exprs.size; ++i) - { - AstExpr* expr = exprs.data[i]; - if (i < exprs.size - 1) - { - std::optional expectedType; - if (i < expectedTypes.size()) - expectedType = expectedTypes[i]; - head.push_back(check(scope, expr, ValueContext::RValue, expectedType).ty); - } - else - { - std::vector> expectedTailTypes; - if (i < expectedTypes.size()) - expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); - tail = checkPack(scope, expr, expectedTailTypes).tp; - } - } - - if (head.empty() && tail) - return InferencePack{*tail}; - else - return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; -} - -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(expr->location); - return InferencePack{builtinTypes->errorRecoveryTypePack()}; - } - - InferencePack result; - - if (AstExprCall* call = expr->as()) - result = checkPack(scope, call); - else if (AstExprVarargs* varargs = expr->as()) - { - if (scope->varargPack) - result = InferencePack{*scope->varargPack}; - else - result = InferencePack{builtinTypes->errorRecoveryTypePack()}; - } - else - { - std::optional expectedType; - if (!expectedTypes.empty()) - expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, ValueContext::RValue, expectedType).ty; - result = InferencePack{arena->addTypePack({t})}; - } - - LUAU_ASSERT(result.tp); - module->astTypePacks[expr] = result.tp; - return result; -} - -InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call) -{ - std::vector exprArgs; - - std::vector returnRefinements; - std::vector> discriminantTypes; - - if (call->self) - { - AstExprIndexName* indexExpr = call->func->as(); - if (!indexExpr) - ice->ice("method call expression has no 'self'"); - - exprArgs.push_back(indexExpr->expr); - - if (auto bc = dfg->getBreadcrumb(indexExpr->expr)) - { - TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); - discriminantTypes.push_back(discriminantTy); - } - else - discriminantTypes.push_back(std::nullopt); - } - - for (AstExpr* arg : call->args) - { - exprArgs.push_back(arg); - - if (auto bc = dfg->getBreadcrumb(arg)) - { - TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); - discriminantTypes.push_back(discriminantTy); - } - else - discriminantTypes.push_back(std::nullopt); - } - - Checkpoint startCheckpoint = checkpoint(this); - TypeId fnType = check(scope, call->func).ty; - Checkpoint fnEndCheckpoint = checkpoint(this); - - std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); - - module->astOriginalCallTypes[call->func] = fnType; - - TypePackId expectedArgPack = arena->freshTypePack(scope.get()); - TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self}); - - TypeId instantiatedFnType = arena->addType(BlockedType{}); - addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); - - NotNull extractArgsConstraint = addConstraint(scope, call->location, SubtypeConstraint{instantiatedFnType, expectedFunctionType}); - - // Fully solve fnType, then extract its argument list as expectedArgPack. - forEachConstraint(startCheckpoint, fnEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { - extractArgsConstraint->dependencies.emplace_back(constraint.get()); - }); - - const AstExpr* lastArg = exprArgs.size() ? exprArgs[exprArgs.size() - 1] : nullptr; - const bool needTail = lastArg && (lastArg->is() || lastArg->is()); - - TypePack expectedArgs; - - if (!needTail) - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size(), expectedTypesForCall); - else - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1, expectedTypesForCall); - - std::vector args; - std::optional argTail; - std::vector argumentRefinements; - - Checkpoint argCheckpoint = checkpoint(this); - - for (size_t i = 0; i < exprArgs.size(); ++i) - { - AstExpr* arg = exprArgs[i]; - std::optional expectedType; - if (i < expectedArgs.head.size()) - expectedType = expectedArgs.head[i]; - - if (i == 0 && call->self) - { - // The self type has already been computed as a side effect of - // computing fnType. If computing that did not cause us to exceed a - // recursion limit, we can fetch it from astTypes rather than - // recomputing it. - TypeId* selfTy = module->astTypes.find(exprArgs[0]); - if (selfTy) - args.push_back(*selfTy); - else - args.push_back(arena->freshType(scope.get())); - } - else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) - { - auto [ty, refinement] = check(scope, arg, ValueContext::RValue, expectedType); - args.push_back(ty); - argumentRefinements.push_back(refinement); - } - else - { - auto [tp, refis] = checkPack(scope, arg, {}); - argTail = tp; - argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); - } - } - - Checkpoint argEndCheckpoint = checkpoint(this); - - // Do not solve argument constraints until after we have extracted the - // expected types from the callable. - forEachConstraint(argCheckpoint, argEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { - constraint->dependencies.push_back(extractArgsConstraint); - }); - - if (matchSetmetatable(*call)) - { - TypePack argTailPack; - if (argTail && args.size() < 2) - argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); - - TypeId target = nullptr; - TypeId mt = nullptr; - - if (args.size() + argTailPack.head.size() == 2) - { - target = args.size() > 0 ? args[0] : argTailPack.head[0]; - mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; - } - else - { - std::vector unpackedTypes; - if (args.size() > 0) - target = args[0]; - else - { - target = arena->addType(BlockedType{}); - unpackedTypes.emplace_back(target); - } - - mt = arena->addType(BlockedType{}); - unpackedTypes.emplace_back(mt); - TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); - - addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); - } - - LUAU_ASSERT(target); - LUAU_ASSERT(mt); - - AstExpr* targetExpr = call->args.data[0]; - - MetatableType mtv{target, mt}; - TypeId resultTy = arena->addType(mtv); - - if (AstExprLocal* targetLocal = targetExpr->as()) - { - scope->bindings[targetLocal->local].typeId = resultTy; - - BreadcrumbId bc = dfg->getBreadcrumb(targetLocal); - scope->dcrRefinements[bc->def] = resultTy; // TODO: typestates: track this as an assignment - } - - return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; - } - else - { - if (matchAssert(*call) && !argumentRefinements.empty()) - applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); - - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); - - NotNull fcc = addConstraint(scope, call->func->location, - FunctionCallConstraint{ - fnType, - argPack, - rets, - call, - std::move(discriminantTypes), - }); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - forEachConstraint(fnEndCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { - fcc->dependencies.emplace_back(constraint.get()); - }); - - return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; - } -} - -Inference ConstraintGraphBuilder::check( - const ScopePtr& scope, AstExpr* expr, ValueContext context, std::optional expectedType, bool forceSingleton) -{ - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(expr->location); - return Inference{builtinTypes->errorRecoveryType()}; - } - - Inference result; - - if (auto group = expr->as()) - result = check(scope, group->expr, ValueContext::RValue, expectedType, forceSingleton); - else if (auto stringExpr = expr->as()) - result = check(scope, stringExpr, expectedType, forceSingleton); - else if (expr->is()) - result = Inference{builtinTypes->numberType}; - else if (auto boolExpr = expr->as()) - result = check(scope, boolExpr, expectedType, forceSingleton); - else if (expr->is()) - result = Inference{builtinTypes->nilType}; - else if (auto local = expr->as()) - result = check(scope, local, context); - else if (auto global = expr->as()) - result = check(scope, global); - else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); - else if (auto call = expr->as()) - result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too - else if (auto a = expr->as()) - { - Checkpoint startCheckpoint = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); - checkFunctionBody(sig.bodyScope, a); - Checkpoint endCheckpoint = checkpoint(this); - - TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); - - forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { - gc->dependencies.emplace_back(constraint.get()); - }); - - result = Inference{generalizedTy}; - } - else if (auto indexName = expr->as()) - result = check(scope, indexName); - else if (auto indexExpr = expr->as()) - result = check(scope, indexExpr); - else if (auto table = expr->as()) - result = check(scope, table, expectedType); - else if (auto unary = expr->as()) - result = check(scope, unary); - else if (auto binary = expr->as()) - result = check(scope, binary, expectedType); - else if (auto ifElse = expr->as()) - result = check(scope, ifElse, expectedType); - else if (auto typeAssert = expr->as()) - result = check(scope, typeAssert); - else if (auto interpString = expr->as()) - result = check(scope, interpString); - else if (auto err = expr->as()) - { - // Open question: Should we traverse into this? - for (AstExpr* subExpr : err->expressions) - check(scope, subExpr); - - result = Inference{builtinTypes->errorRecoveryType()}; - } - else - { - LUAU_ASSERT(0); - result = Inference{freshType(scope)}; - } - - LUAU_ASSERT(result.ty); - module->astTypes[expr] = result.ty; - if (expectedType) - module->astExpectedTypes[expr] = *expectedType; - return result; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) -{ - if (forceSingleton) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); - addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->stringType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - return Inference{builtinTypes->stringType}; - } - - return Inference{builtinTypes->stringType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) -{ - const TypeId singletonType = boolExpr->value ? builtinTypes->trueType : builtinTypes->falseType; - if (forceSingleton) - return Inference{singletonType}; - - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->booleanType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{singletonType}; - - return Inference{builtinTypes->booleanType}; - } - - return Inference{builtinTypes->booleanType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local, ValueContext context) -{ - BreadcrumbId bc = dfg->getBreadcrumb(local); - - if (auto ty = scope->lookup(bc->def); ty && context == ValueContext::RValue) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else if (auto ty = scope->lookup(local->local)) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else - ice->ice("AstExprLocal came before its declaration?"); -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) -{ - BreadcrumbId bc = dfg->getBreadcrumb(global); - - /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any - * global that is not already in-scope is definitely an unknown symbol. - */ - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - else if (auto ty = scope->lookup(global->name)) - { - rootScope->dcrRefinements[bc->def] = *ty; - return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; - } - else - { - reportError(global->location, UnknownSymbol{global->name.value}); - return Inference{builtinTypes->errorRecoveryType()}; - } -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) -{ - TypeId obj = check(scope, indexName->expr).ty; - TypeId result = arena->addType(BlockedType{}); - - NullableBreadcrumbId bc = dfg->getBreadcrumb(indexName); - if (bc) - { - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - - scope->dcrRefinements[bc->def] = result; - } - - addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); - - if (bc) - return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - else - return Inference{result}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) -{ - TypeId obj = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId result = freshType(scope); - - NullableBreadcrumbId bc = dfg->getBreadcrumb(indexExpr); - if (bc) - { - if (auto ty = scope->lookup(bc->def)) - return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - - scope->dcrRefinements[bc->def] = result; - } - - TableIndexer indexer{indexType, result}; - TypeId tableType = arena->addType(TableType{TableType::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); - - addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - - if (bc) - return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; - else - return Inference{result}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) -{ - auto [operandType, refinement] = check(scope, unary->expr); - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - - if (unary->op == AstExprUnary::Not) - return Inference{resultType, refinementArena.negation(refinement)}; - else - return Inference{resultType}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); - - TypeId resultType = arena->addType(BlockedType{}); - addConstraint(scope, binary->location, - BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - return Inference{resultType, std::move(refinement)}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) -{ - ScopePtr condScope = childScope(ifElse->condition, scope); - RefinementId refinement = check(condScope, ifElse->condition).refinement; - - ScopePtr thenScope = childScope(ifElse->trueExpr, scope); - applyRefinements(thenScope, ifElse->trueExpr->location, refinement); - TypeId thenType = check(thenScope, ifElse->trueExpr, ValueContext::RValue, expectedType).ty; - - ScopePtr elseScope = childScope(ifElse->falseExpr, scope); - applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); - TypeId elseType = check(elseScope, ifElse->falseExpr, ValueContext::RValue, expectedType).ty; - - return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) -{ - check(scope, typeAssert->expr, ValueContext::RValue, std::nullopt); - return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprInterpString* interpString) -{ - for (AstExpr* expr : interpString->expressions) - check(scope, expr); - - return Inference{builtinTypes->stringType}; -} - -std::tuple ConstraintGraphBuilder::checkBinary( - const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - if (binary->op == AstExprBinary::And) - { - std::optional relaxedExpectedLhs; - - if (expectedType) - relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); - - auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); - - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, leftRefinement); - auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); - - return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; - } - else if (binary->op == AstExprBinary::Or) - { - std::optional relaxedExpectedLhs; - - if (expectedType) - relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); - - auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); - - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); - auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); - - return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; - } - else if (auto typeguard = matchTypeGuard(binary)) - { - TypeId leftType = check(scope, binary->left).ty; - TypeId rightType = check(scope, binary->right).ty; - - NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); - if (!bc) - return {leftType, rightType, nullptr}; - - TypeId discriminantTy = builtinTypes->neverType; - if (typeguard->type == "nil") - discriminantTy = builtinTypes->nilType; - else if (typeguard->type == "string") - discriminantTy = builtinTypes->stringType; - else if (typeguard->type == "number") - discriminantTy = builtinTypes->numberType; - else if (typeguard->type == "boolean") - discriminantTy = builtinTypes->booleanType; - else if (typeguard->type == "thread") - discriminantTy = builtinTypes->threadType; - else if (typeguard->type == "table") - discriminantTy = builtinTypes->tableType; - else if (typeguard->type == "function") - discriminantTy = builtinTypes->functionType; - else if (typeguard->type == "userdata") - { - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - discriminantTy = builtinTypes->classType; - } - else if (!typeguard->isTypeof && typeguard->type == "vector") - discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type - else if (!typeguard->isTypeof) - discriminantTy = builtinTypes->neverType; - else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) - { - TypeId ty = follow(typeFun->type); - - // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) - discriminantTy = ty; - } - - RefinementId proposition = refinementArena.proposition(NotNull{bc}, discriminantTy); - if (binary->op == AstExprBinary::CompareEq) - return {leftType, rightType, proposition}; - else if (binary->op == AstExprBinary::CompareNe) - return {leftType, rightType, refinementArena.negation(proposition)}; - else - ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); - } - else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) - { - TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; - - RefinementId leftRefinement = nullptr; - if (auto bc = dfg->getBreadcrumb(binary->left)) - leftRefinement = refinementArena.proposition(NotNull{bc}, rightType); - - RefinementId rightRefinement = nullptr; - if (auto bc = dfg->getBreadcrumb(binary->right)) - rightRefinement = refinementArena.proposition(NotNull{bc}, leftType); - - if (binary->op == AstExprBinary::CompareNe) - { - leftRefinement = refinementArena.negation(leftRefinement); - rightRefinement = refinementArena.negation(rightRefinement); - } - - return {leftType, rightType, refinementArena.equivalence(leftRefinement, rightRefinement)}; - } - else - { - TypeId leftType = check(scope, binary->left, ValueContext::RValue).ty; - TypeId rightType = check(scope, binary->right, ValueContext::RValue).ty; - return {leftType, rightType, nullptr}; - } -} - -std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) -{ - std::vector types; - types.reserve(exprs.size); - - for (AstExpr* expr : exprs) - types.push_back(checkLValue(scope, expr)); - - return types; -} - -static bool isIndexNameEquivalent(AstExpr* expr) -{ - if (expr->is()) - return true; - - AstExprIndexExpr* e = expr->as(); - if (e == nullptr) - return false; - - if (!e->index->is()) - return false; - - return true; -} - -/** - * This function is mostly about identifying properties that are being inserted into unsealed tables. - * - * If expr has the form name.a.b.c - */ -TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) -{ - if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) - { - // An indexer is only interesting in an lvalue-ey way if it is at the - // tail of an expression. - // - // If the indexer is not at the tail, then we are not interested in - // augmenting the lhs data structure with a new indexer. Constraint - // generation can treat it as an ordinary lvalue. - // - // eg - // - // a.b.c[1] = 44 -- lvalue - // a.b[4].c = 2 -- rvalue - - TypeId resultType = arena->addType(BlockedType{}); - TypeId subjectType = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId propType = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType}); - - module->astTypes[expr] = propType; - - return propType; - } - else if (!isIndexNameEquivalent(expr)) - return check(scope, expr, ValueContext::LValue).ty; - - Symbol sym; - std::vector segments; - std::vector exprs; - - AstExpr* e = expr; - while (e) - { - if (auto global = e->as()) - { - sym = global->name; - break; - } - else if (auto local = e->as()) - { - sym = local->local; - break; - } - else if (auto indexName = e->as()) - { - segments.push_back(indexName->index.value); - exprs.push_back(e); - e = indexName->expr; - } - else if (auto indexExpr = e->as()) - { - if (auto strIndex = indexExpr->index->as()) - { - segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); - exprs.push_back(e); - e = indexExpr->expr; - } - else - { - return check(scope, expr, ValueContext::LValue).ty; - } - } - else - return check(scope, expr, ValueContext::LValue).ty; - } - - LUAU_ASSERT(!segments.empty()); - - std::reverse(begin(segments), end(segments)); - std::reverse(begin(exprs), end(exprs)); - - auto lookupResult = scope->lookupEx(sym); - if (!lookupResult) - return check(scope, expr, ValueContext::LValue).ty; - const auto [subjectBinding, symbolScope] = std::move(*lookupResult); - TypeId subjectType = subjectBinding->typeId; - - TypeId propTy = freshType(scope); - - std::vector segmentStrings(begin(segments), end(segments)); - - TypeId updatedType = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); - - TypeId prevSegmentTy = updatedType; - for (size_t i = 0; i < segments.size(); ++i) - { - TypeId segmentTy = arena->addType(BlockedType{}); - module->astTypes[exprs[i]] = segmentTy; - addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i]}); - prevSegmentTy = segmentTy; - } - - module->astTypes[expr] = prevSegmentTy; - module->astTypes[e] = updatedType; - - if (!subjectType->persistent) - { - symbolScope->bindings[sym].typeId = updatedType; - - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - if (auto bc = dfg->getBreadcrumb(e)) - symbolScope->dcrRefinements[bc->def] = updatedType; - } - - return propTy; -} - -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) -{ - TypeId ty = arena->addType(TableType{}); - TableType* ttv = getMutable(ty); - LUAU_ASSERT(ttv); - - ttv->state = TableState::Unsealed; - ttv->scope = scope.get(); - - auto createIndexer = [this, scope, ttv](const Location& location, TypeId currentIndexType, TypeId currentResultType) { - if (!ttv->indexer) - { - TypeId indexType = this->freshType(scope); - TypeId resultType = this->freshType(scope); - ttv->indexer = TableIndexer{indexType, resultType}; - } - - addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); - addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); - }; - - std::optional annotatedKeyType; - std::optional annotatedIndexResultType; - - if (expectedType) - { - if (const TableType* ttv = get(follow(*expectedType))) - { - if (ttv->indexer) - { - annotatedKeyType.emplace(follow(ttv->indexer->indexType)); - annotatedIndexResultType.emplace(ttv->indexer->indexResultType); - } - } - } - - bool isIndexedResultType = false; - std::optional pinnedIndexResultType; - - - for (const AstExprTable::Item& item : expr->items) - { - std::optional expectedValueType; - if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List) - isIndexedResultType = true; - - if (item.key && expectedType) - { - if (auto stringKey = item.key->as()) - { - ErrorVec errorVec; - std::optional propTy = - findTablePropertyRespectingMeta(builtinTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); - if (propTy) - expectedValueType = propTy; - else - { - expectedValueType = arena->addType(BlockedType{}); - addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); - } - } - } - - - // We'll resolve the expected index result type here with the following priority: - // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. - // In this case, the above if-statement will populate expectedValueType - // 2. Someone places an annotation on a General or List table - // Trust the annotation and have the solver inform them if they get it wrong - // 3. Someone omits the annotation on a general or List table - // Use the type of the first indexResultType as the expected type - std::optional checkExpectedIndexResultType; - if (expectedValueType) - { - checkExpectedIndexResultType = expectedValueType; - } - else if (annotatedIndexResultType) - { - checkExpectedIndexResultType = annotatedIndexResultType; - } - else if (pinnedIndexResultType) - { - checkExpectedIndexResultType = pinnedIndexResultType; - } - - TypeId itemTy = check(scope, item.value, ValueContext::RValue, checkExpectedIndexResultType).ty; - - if (isIndexedResultType && !pinnedIndexResultType) - pinnedIndexResultType = itemTy; - - if (item.key) - { - // Even though we don't need to use the type of the item's key if - // it's a string constant, we still want to check it to populate - // astTypes. - TypeId keyTy = check(scope, item.key, ValueContext::RValue, annotatedKeyType).ty; - - if (AstExprConstantString* key = item.key->as()) - { - ttv->props[key->value.begin()] = {itemTy}; - } - else - { - createIndexer(item.key->location, keyTy, itemTy); - } - } - else - { - TypeId numberType = builtinTypes->numberType; - // FIXME? The location isn't quite right here. Not sure what is - // right. - createIndexer(item.value->location, numberType, itemTy); - } - } - - return Inference{ty}; -} - -ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( - const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) -{ - ScopePtr signatureScope = nullptr; - ScopePtr bodyScope = nullptr; - TypePackId returnType = nullptr; - - std::vector genericTypes; - std::vector genericTypePacks; - - if (expectedType) - expectedType = follow(*expectedType); - - bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - - signatureScope = childScope(fn, parent); - - // We need to assign returnType before creating bodyScope so that the - // return type gets propogated to bodyScope. - returnType = freshTypePack(signatureScope); - signatureScope->returnType = returnType; - - bodyScope = childScope(fn->body, signatureScope); - - if (hasGenerics) - { - std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); - - // We do not support default values on function generics, so we only - // care about the types involved. - for (const auto& [name, g] : genericDefinitions) - { - genericTypes.push_back(g.ty); - } - - for (const auto& [name, g] : genericPackDefinitions) - { - genericTypePacks.push_back(g.tp); - } - - // Local variable works around an odd gcc 11.3 warning: may be used uninitialized - std::optional none = std::nullopt; - expectedType = none; - } - - std::vector argTypes; - std::vector> argNames; - TypePack expectedArgPack; - - const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; - // This check ensures that expectedType is precisely optional and not any (since any is also an optional type) - if (expectedType && isOptional(*expectedType) && !get(*expectedType)) - { - auto ut = get(*expectedType); - for (auto u : ut) - { - if (get(u) && !isNil(u)) - { - expectedFunction = get(u); - break; - } - } - } - - if (expectedFunction) - { - expectedArgPack = extendTypePack(*arena, builtinTypes, expectedFunction->argTypes, fn->args.size); - - genericTypes = expectedFunction->generics; - genericTypePacks = expectedFunction->genericPacks; - } - - if (fn->self) - { - TypeId selfType = freshType(signatureScope); - argTypes.push_back(selfType); - argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); - signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(fn->self); - signatureScope->dcrRefinements[bc->def] = selfType; - } - - for (size_t i = 0; i < fn->args.size; ++i) - { - AstLocal* local = fn->args.data[i]; - - TypeId argTy = nullptr; - if (local->annotation) - argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); - else - { - argTy = freshType(signatureScope); - - if (i < expectedArgPack.head.size()) - addConstraint(signatureScope, local->location, SubtypeConstraint{argTy, expectedArgPack.head[i]}); - } - - argTypes.push_back(argTy); - argNames.emplace_back(FunctionArgument{local->name.value, local->location}); - signatureScope->bindings[local] = Binding{argTy, local->location}; - - BreadcrumbId bc = dfg->getBreadcrumb(local); - signatureScope->dcrRefinements[bc->def] = argTy; - } - - TypePackId varargPack = nullptr; - - if (fn->vararg) - { - if (fn->varargAnnotation) - { - TypePackId annotationType = - resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); - varargPack = annotationType; - } - else if (expectedArgPack.tail && get(*expectedArgPack.tail)) - varargPack = *expectedArgPack.tail; - else - varargPack = builtinTypes->anyTypePack; - - signatureScope->varargPack = varargPack; - bodyScope->varargPack = varargPack; - } - else - { - varargPack = arena->addTypePack(VariadicTypePack{builtinTypes->anyType, /*hidden*/ true}); - // We do not add to signatureScope->varargPack because ... is not valid - // in functions without an explicit ellipsis. - - signatureScope->varargPack = std::nullopt; - bodyScope->varargPack = std::nullopt; - } - - LUAU_ASSERT(nullptr != varargPack); - - // If there is both an annotation and an expected type, the annotation wins. - // Type checking will sort out any discrepancies later. - if (fn->returnAnnotation) - { - TypePackId annotatedRetType = - resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); - // We bind the annotated type directly here so that, when we need to - // generate constraints for return types, we have a guarantee that we - // know the annotated return type already, if one was provided. - LUAU_ASSERT(get(returnType)); - asMutable(returnType)->ty.emplace(annotatedRetType); - } - else if (expectedFunction) - { - asMutable(returnType)->ty.emplace(expectedFunction->retTypes); - } - - // TODO: Preserve argument names in the function's type. - - FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; - actualFunction.hasNoGenerics = !hasGenerics; - actualFunction.generics = std::move(genericTypes); - actualFunction.genericPacks = std::move(genericTypePacks); - actualFunction.argNames = std::move(argNames); - actualFunction.hasSelf = fn->self != nullptr; - - TypeId actualFunctionType = arena->addType(std::move(actualFunction)); - LUAU_ASSERT(actualFunctionType); - module->astTypes[fn] = actualFunctionType; - - if (expectedType && get(*expectedType)) - bindFreeType(*expectedType, actualFunctionType); - - return { - /* signature */ actualFunctionType, - /* signatureScope */ signatureScope, - /* bodyScope */ bodyScope, - }; -} - -void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) -{ - visitBlockWithoutChildScope(scope, fn->body); - - // If it is possible for execution to reach the end of the function, the return type must be compatible with () - - if (nullptr != getFallthrough(fn->body)) - { - TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever - addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty}); - } -} - -TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) -{ - TypeId result = nullptr; - - if (auto ref = ty->as()) - { - if (FFlag::DebugLuauMagicTypes) - { - if (ref->name == "_luau_ice") - ice->ice("_luau_ice encountered", ty->location); - else if (ref->name == "_luau_print") - { - if (ref->parameters.size != 1 || !ref->parameters.data[0].type) - { - reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); - return builtinTypes->errorRecoveryType(); - } - else - return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); - } - } - - std::optional alias; - - if (ref->prefix.has_value()) - { - alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); - } - else - { - alias = scope->lookupType(ref->name.value); - } - - if (alias.has_value()) - { - // If the alias is not generic, we don't need to set up a blocked - // type and an instantiation constraint. - if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) - { - result = alias->type; - } - else - { - std::vector parameters; - std::vector packParameters; - - for (const AstTypeOrPack& p : ref->parameters) - { - // We do not enforce the ordering of types vs. type packs here; - // that is done in the parser. - if (p.type) - { - parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); - } - else if (p.typePack) - { - packParameters.push_back(resolveTypePack(scope, p.typePack, /* inTypeArguments */ true)); - } - else - { - // This indicates a parser bug: one of these two pointers - // should be set. - LUAU_ASSERT(false); - } - } - - result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); - - // If we're not in a type argument context, we need to create a constraint that expands this. - // The dispatching of the above constraint will queue up additional constraints for nested - // type function applications. - if (!inTypeArguments) - addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); - } - } - else - { - result = builtinTypes->errorRecoveryType(); - if (replaceErrorWithFresh) - result = freshType(scope); - } - } - else if (auto tab = ty->as()) - { - TableType::Props props; - std::optional indexer; - - for (const AstTableProp& prop : tab->props) - { - std::string name = prop.name.value; - // TODO: Recursion limit. - TypeId propTy = resolveType(scope, prop.type, inTypeArguments); - // TODO: Fill in location. - props[name] = {propTy}; - } - - if (tab->indexer) - { - // TODO: Recursion limit. - indexer = TableIndexer{ - resolveType(scope, tab->indexer->indexType, inTypeArguments), - resolveType(scope, tab->indexer->resultType, inTypeArguments), - }; - } - - result = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); - } - else if (auto fn = ty->as()) - { - // TODO: Recursion limit. - bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - ScopePtr signatureScope = nullptr; - - std::vector genericTypes; - std::vector genericTypePacks; - - // If we don't have generics, we do not need to generate a child scope - // for the generic bindings to live on. - if (hasGenerics) - { - signatureScope = childScope(fn, scope); - - std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); - - for (const auto& [name, g] : genericDefinitions) - { - genericTypes.push_back(g.ty); - } - - for (const auto& [name, g] : genericPackDefinitions) - { - genericTypePacks.push_back(g.tp); - } - } - else - { - // To eliminate the need to branch on hasGenerics below, we say that - // the signature scope is the parent scope if we don't have - // generics. - signatureScope = scope; - } - - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); - - // TODO: FunctionType needs a pointer to the scope so that we know - // how to quantify/instantiate it. - FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - - // This replicates the behavior of the appropriate FunctionType - // constructors. - ftv.hasNoGenerics = !hasGenerics; - ftv.generics = std::move(genericTypes); - ftv.genericPacks = std::move(genericTypePacks); - - ftv.argNames.reserve(fn->argNames.size); - for (const auto& el : fn->argNames) - { - if (el) - { - const auto& [name, location] = *el; - ftv.argNames.push_back(FunctionArgument{name.value, location}); - } - else - { - ftv.argNames.push_back(std::nullopt); - } - } - - result = arena->addType(std::move(ftv)); - } - else if (auto tof = ty->as()) - { - // TODO: Recursion limit. - TypeId exprType = check(scope, tof->expr).ty; - result = exprType; - } - else if (auto unionAnnotation = ty->as()) - { - std::vector parts; - for (AstType* part : unionAnnotation->types) - { - // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, inTypeArguments)); - } - - result = arena->addType(UnionType{parts}); - } - else if (auto intersectionAnnotation = ty->as()) - { - std::vector parts; - for (AstType* part : intersectionAnnotation->types) - { - // TODO: Recursion limit. - parts.push_back(resolveType(scope, part, inTypeArguments)); - } - - result = arena->addType(IntersectionType{parts}); - } - else if (auto boolAnnotation = ty->as()) - { - result = arena->addType(SingletonType(BooleanSingleton{boolAnnotation->value})); - } - else if (auto stringAnnotation = ty->as()) - { - result = arena->addType(SingletonType(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); - } - else if (ty->is()) - { - result = builtinTypes->errorRecoveryType(); - if (replaceErrorWithFresh) - result = freshType(scope); - } - else - { - LUAU_ASSERT(0); - result = builtinTypes->errorRecoveryType(); - } - - module->astResolvedTypes[ty] = result; - return result; -} - -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) -{ - TypePackId result; - if (auto expl = tp->as()) - { - result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); - } - else if (auto var = tp->as()) - { - TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); - result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); - } - else if (auto gen = tp->as()) - { - if (std::optional lookup = scope->lookupPack(gen->genericName.value)) - { - result = *lookup; - } - else - { - reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); - result = builtinTypes->errorRecoveryTypePack(); - } - } - else - { - LUAU_ASSERT(0); - result = builtinTypes->errorRecoveryTypePack(); - } - - module->astResolvedTypePacks[tp] = result; - return result; -} - -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) -{ - std::vector head; - - for (AstType* headTy : list.types) - { - head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); - } - - std::optional tail = std::nullopt; - if (list.tailType) - { - tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); - } - - return arena->addTypePack(TypePack{head, tail}); -} - -std::vector> ConstraintGraphBuilder::createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) -{ - std::vector> result; - for (const auto& generic : generics) - { - TypeId genericTy = nullptr; - - if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) - genericTy = it->second; - else - { - genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); - scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; - } - - std::optional defaultTy = std::nullopt; - - if (generic.defaultValue) - defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); - - if (addTypes) - scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; - - result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); - } - - return result; -} - -std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) -{ - std::vector> result; - for (const auto& generic : generics) - { - TypePackId genericTy; - - if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); - useCache && it != scope->parent->typeAliasTypePackParameters.end()) - genericTy = it->second; - else - { - genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); - scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; - } - - std::optional defaultTy = std::nullopt; - - if (generic.defaultValue) - defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); - - if (addTypes) - scope->privateTypePackBindings[generic.name.value] = genericTy; - - result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); - } - - return result; -} - -Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) -{ - const auto& [tp, refinements] = pack; - RefinementId refinement = nullptr; - if (!refinements.empty()) - refinement = refinements[0]; - - if (auto f = first(tp)) - return Inference{*f, refinement}; - - TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - addConstraint(scope, location, UnpackConstraint{resultPack, tp}); - - return Inference{typeResult, refinement}; -} - -void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) -{ - errors.push_back(TypeError{location, module->name, std::move(err)}); - - if (logger) - logger->captureGenerationError(errors.back()); -} - -void ConstraintGraphBuilder::reportCodeTooComplex(Location location) -{ - errors.push_back(TypeError{location, module->name, CodeTooComplex{}}); - - if (logger) - logger->captureGenerationError(errors.back()); -} - -struct GlobalPrepopulator : AstVisitor -{ - const NotNull globalScope; - const NotNull arena; - - GlobalPrepopulator(NotNull globalScope, NotNull arena) - : globalScope(globalScope) - , arena(arena) - { - } - - bool visit(AstStatFunction* function) override - { - if (AstExprGlobal* g = function->name->as()) - globalScope->bindings[g->name] = Binding{arena->addType(BlockedType{})}; - - return true; - } -}; - -void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) -{ - GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; - - if (prepareModuleScope) - prepareModuleScope(module->name, globalScope); - - program->visit(&gp); -} - -std::vector> ConstraintGraphBuilder::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) -{ - std::vector funTys; - if (auto it = get(follow(fnType))) - { - for (TypeId intersectionComponent : it) - { - funTys.push_back(intersectionComponent); - } - } - - std::vector> expectedTypes; - // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, - // emit a list of arguments that the function could take at each position - // by unioning the arguments at each place - auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { - if (index == expectedTypes.size()) - expectedTypes.push_back(ty); - else if (ty) - { - auto& el = expectedTypes[index]; - - if (!el) - el = ty; - else - { - std::vector result = reduceUnion({*el, ty}); - if (result.empty()) - el = builtinTypes->neverType; - else if (result.size() == 1) - el = result[0]; - else - el = module->internalTypes.addType(UnionType{std::move(result)}); - } - } - }; - - for (const TypeId overload : funTys) - { - if (const FunctionType* ftv = get(follow(overload))) - { - auto [argsHead, argsTail] = flatten(ftv->argTypes); - size_t start = ftv->hasSelf ? 1 : 0; - size_t index = 0; - for (size_t i = start; i < argsHead.size(); ++i) - assignOption(index++, argsHead[i]); - if (argsTail) - { - argsTail = follow(*argsTail); - if (const VariadicTypePack* vtp = get(*argsTail)) - { - while (index < funTys.size()) - assignOption(index++, vtp->ty); - } - } - } - } - - // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? - - return expectedTypes; -} - -std::vector> borrowConstraints(const std::vector& constraints) -{ - std::vector> result; - result.reserve(constraints.size()); - - for (const auto& c : constraints) - result.emplace_back(c.get()); - - return result; -} - -} // namespace Luau diff --git a/third_party/luau/Analysis/src/ConstraintSolver.cpp b/third_party/luau/Analysis/src/ConstraintSolver.cpp index ec63b25e..9cc8ef38 100644 --- a/third_party/luau/Analysis/src/ConstraintSolver.cpp +++ b/third_party/luau/Analysis/src/ConstraintSolver.cpp @@ -1,24 +1,36 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintSolver.h" #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" -#include "Luau/Clone.h" #include "Luau/Common.h" -#include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" +#include "Luau/Generalization.h" #include "Luau/Instantiation.h" +#include "Luau/Instantiation2.h" #include "Luau/Location.h" -#include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" +#include "Luau/OverloadResolution.h" #include "Luau/Quantify.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Simplify.h" +#include "Luau/TableLiteralInference.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" -#include "Luau/TypeUtils.h" #include "Luau/Type.h" -#include "Luau/Unifier.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" #include "Luau/VisitType.h" +#include +#include + LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false); +LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500); namespace Luau { @@ -51,8 +63,39 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const dumpBindings(child, opts); } -static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull builtinTypes, - const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) +// used only in asserts +[[maybe_unused]] static bool canMutate(TypeId ty, NotNull constraint) +{ + if (auto blocked = get(ty)) + { + Constraint* owner = blocked->getOwner(); + LUAU_ASSERT(owner); + return owner == constraint; + } + + return true; +} + +// used only in asserts +[[maybe_unused]] static bool canMutate(TypePackId tp, NotNull constraint) +{ + if (auto blocked = get(tp)) + { + Constraint* owner = blocked->owner; + LUAU_ASSERT(owner); + return owner == constraint; + } + + return true; +} + +static std::pair, std::vector> saturateArguments( + TypeArena* arena, + NotNull builtinTypes, + const TypeFun& fn, + const std::vector& rawTypeArguments, + const std::vector& rawPackArguments +) { std::vector saturatedTypeArguments; std::vector extraTypes; @@ -72,7 +115,7 @@ static std::pair, std::vector> saturateArguments // mutually exclusive with the type pack -> type conversion we do below: // extraTypes will only have elements in it if we have more types than we // have parameter slots for them to go into. - if (!extraTypes.empty()) + if (!extraTypes.empty() && !fn.typePackParams.empty()) { saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); } @@ -88,7 +131,7 @@ static std::pair, std::vector> saturateArguments { saturatedTypeArguments.push_back(*first(tp)); } - else + else if (saturatedPackArguments.size() < fn.typePackParams.size()) { saturatedPackArguments.push_back(tp); } @@ -176,6 +219,12 @@ static std::pair, std::vector> saturateArguments saturatedPackArguments.push_back(builtinTypes->errorRecoveryTypePack()); } + for (TypeId& arg : saturatedTypeArguments) + arg = follow(arg); + + for (TypePackId& pack : saturatedPackArguments) + pack = follow(pack); + // At this point, these two conditions should be true. If they aren't we // will run into access violations. LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); @@ -223,11 +272,59 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) auto it = cs->blockedConstraints.find(c); int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); + + if (FFlag::DebugLuauLogSolverIncludeDependencies) + { + for (NotNull dep : c->dependencies) + { + if (std::find(cs->unsolvedConstraints.begin(), cs->unsolvedConstraints.end(), dep) != cs->unsolvedConstraints.end()) + printf("\t\t|\t%s\n", toString(*dep, opts).c_str()); + } + } } } -ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) +struct InstantiationQueuer : TypeOnceVisitor +{ + ConstraintSolver* solver; + NotNull scope; + Location location; + + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) + : solver(solver) + , scope(scope) + , location(location) + { + } + + bool visit(TypeId ty, const PendingExpansionType& petv) override + { + solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override + { + solver->pushConstraint(scope, location, ReduceConstraint{ty}); + return true; + } + + bool visit(TypeId ty, const ClassType& ctv) override + { + return false; + } +}; + +ConstraintSolver::ConstraintSolver( + NotNull normalizer, + NotNull rootScope, + std::vector> constraints, + ModuleName moduleName, + NotNull moduleResolver, + std::vector requireCycles, + DcrLogger* logger, + TypeCheckLimits limits +) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) @@ -237,6 +334,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull normalizer, NotNullgetMaybeMutatedFreeTypes()) + { + // increment the reference count for `ty` + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; + } + for (NotNull dep : c->dependencies) { block(dep, c); @@ -273,12 +379,16 @@ void ConstraintSolver::randomize(unsigned seed) void ConstraintSolver::run() { + LUAU_TIMETRACE_SCOPE("ConstraintSolver::run", "Typechecking"); + if (isDone()) return; if (FFlag::DebugLuauLogSolver) { - printf("Starting solver\n"); + printf( + "Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str() + ); dump(this, opts); printf("Bindings:\n"); dumpBindings(rootScope, opts); @@ -289,7 +399,8 @@ void ConstraintSolver::run() logger->captureInitialSolverState(rootScope, unsolvedConstraints); } - auto runSolverPass = [&](bool force) { + auto runSolverPass = [&](bool force) + { bool progress = false; size_t i = 0; @@ -302,6 +413,11 @@ void ConstraintSolver::run() continue; } + if (limits.finishTime && TimeTrace::getClock() > *limits.finishTime) + throwTimeLimitError(); + if (limits.cancellationToken && limits.cancellationToken->requested()) + throwUserCancelError(); + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; StepSnapshot snapshot; @@ -319,6 +435,22 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + // decrement the referenced free types for this constraint if we dispatched successfully! + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + } + if (logger) { logger->commitStepSnapshot(snapshot); @@ -371,12 +503,17 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); - finalizeModule(); + if (!unsolvedConstraints.empty()) + reportError(ConstraintSolvingIncompleteError{}, Location{}); - if (FFlag::DebugLuauLogSolver) - { + // After we have run all the constraints, type functions should be generalized + // At this point, we can try to perform one final simplification to suss out + // whether type functions are truly uninhabited or if they can reduce + + finalizeTypeFunctions(); + + if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); - } if (logger) { @@ -384,22 +521,89 @@ void ConstraintSolver::run() } } +void ConstraintSolver::finalizeTypeFunctions() +{ + // At this point, we've generalized. Let's try to finish reducing as much as we can, we'll leave warning to the typechecker + for (auto [t, constraint] : typeFunctionsToFinalize) + { + TypeId ty = follow(t); + if (get(ty)) + { + FunctionGraphReductionResult result = + reduceTypeFunctions(t, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, NotNull{constraint}}, true); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + } + } +} + bool ConstraintSolver::isDone() { return unsolvedConstraints.empty(); } -void ConstraintSolver::finalizeModule() +namespace { - Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; - std::optional returnType = a.substitute(rootScope->returnType); - if (!returnType) - { - reportError(CodeTooComplex{}, Location{}); - rootScope->returnType = builtinTypes->errorTypePack; - } - else - rootScope->returnType = *returnType; + +struct TypeAndLocation +{ + TypeId typeId; + Location location; +}; + +} // namespace + +void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + boundTo = follow(boundTo); + if (get(ty) && ty == boundTo) + return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType); + + shiftReferences(ty, boundTo); + emplaceType(asMutable(ty), boundTo); + unblock(ty, constraint->location); +} + +void ConstraintSolver::bind(NotNull constraint, TypePackId tp, TypePackId boundTo) +{ + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + boundTo = follow(boundTo); + LUAU_ASSERT(tp != boundTo); + + emplaceTypePack(asMutable(tp), boundTo); + unblock(tp, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypeId ty, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + emplaceType(asMutable(ty), std::forward(args)...); + unblock(ty, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypePackId tp, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + emplaceTypePack(asMutable(tp), std::forward(args)...); + unblock(tp, constraint->location); } bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) @@ -415,12 +619,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*psc, constraint, force); else if (auto gc = get(*constraint)) success = tryDispatch(*gc, constraint, force); - else if (auto ic = get(*constraint)) - success = tryDispatch(*ic, constraint, force); - else if (auto uc = get(*constraint)) - success = tryDispatch(*uc, constraint, force); - else if (auto bc = get(*constraint)) - success = tryDispatch(*bc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); else if (auto nc = get(*constraint)) @@ -429,24 +627,29 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*taec, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto sottc = get(*constraint)) - success = tryDispatch(*sottc, constraint); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint, force); + else if (auto rpc = get(*constraint)) + success = tryDispatch(*rpc, constraint, force); + else if (auto eqc = get(*constraint)) + success = tryDispatch(*eqc, constraint, force); else LUAU_ASSERT(false); - if (success) - unblock(constraint); - return success; } @@ -457,7 +660,9 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) @@ -467,7 +672,9 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) @@ -479,319 +686,35 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) return block(generalizedType, constraint); - std::optional generalized = quantify(arena, c.sourceType, constraint->scope); - if (generalized) - { - if (get(generalizedType)) - asMutable(generalizedType)->ty.emplace(*generalized); - else - unify(generalizedType, *generalized, constraint->scope); - } + std::optional generalized; + + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); + if (generalizedTy) + generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks else - { reportError(CodeTooComplex{}, constraint->location); - asMutable(c.generalizedType)->ty.emplace(builtinTypes->errorRecoveryType()); - } - - unblock(c.generalizedType); - unblock(c.sourceType); - - return true; -} - -bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) -{ - if (isBlocked(c.superType)) - return block(c.superType, constraint); - Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - - std::optional instantiated = inst.substitute(c.superType); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - - LUAU_ASSERT(get(c.subType)); - asMutable(c.subType)->ty.emplace(*instantiated); - - unblock(c.subType); - - return true; -} - -bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) -{ - TypeId operandType = follow(c.operandType); - - if (isBlocked(operandType)) - return block(operandType, constraint); - - if (get(operandType)) - return block(operandType, constraint); - - LUAU_ASSERT(get(c.resultType)); - - switch (c.op) - { - case AstExprUnary::Not: - { - asMutable(c.resultType)->ty.emplace(builtinTypes->booleanType); - - unblock(c.resultType); - return true; - } - case AstExprUnary::Len: - { - // __len must return a number. - asMutable(c.resultType)->ty.emplace(builtinTypes->numberType); - - unblock(c.resultType); - return true; - } - case AstExprUnary::Minus: + if (generalized) { - if (isNumber(operandType) || get(operandType) || get(operandType) || get(operandType)) - { - asMutable(c.resultType)->ty.emplace(c.operandType); - } - else if (std::optional mm = findMetatableEntry(builtinTypes, errors, operandType, "__unm", constraint->location)) - { - TypeId mmTy = follow(*mm); - - if (get(mmTy) && !force) - return block(mmTy, constraint); - - TypePackId argPack = arena->addTypePack(TypePack{{operandType}, {}}); - TypePackId retPack = arena->addTypePack(BlockedTypePack{}); - - asMutable(c.resultType)->ty.emplace(constraint->scope); - - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{retPack, arena->addTypePack(TypePack{{c.resultType}})}); - - pushConstraint(constraint->scope, constraint->location, FunctionCallConstraint{mmTy, argPack, retPack, nullptr}); - } + if (get(generalizedType)) + bind(constraint, generalizedType, generalized->result); else - { - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - } - - unblock(c.resultType); - return true; - } - } - - LUAU_ASSERT(false); - return false; -} - -bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) -{ - TypeId leftType = follow(c.leftType); - TypeId rightType = follow(c.rightType); - TypeId resultType = follow(c.resultType); - - LUAU_ASSERT(get(resultType)); - - bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; - - /* Compound assignments create constraints of the form - * - * A <: Binary - * - * This constraint is the one that is meant to unblock A, so it doesn't - * make any sense to stop and wait for someone else to do it. - */ - - // If any is present, the expression must evaluate to any as well. - bool leftAny = get(leftType) || get(leftType); - bool rightAny = get(rightType) || get(rightType); - bool anyPresent = leftAny || rightAny; - - if (isBlocked(leftType) && leftType != resultType) - return block(c.leftType, constraint); - - if (isBlocked(rightType) && rightType != resultType) - return block(c.rightType, constraint); - - if (!force) - { - // Logical expressions may proceed if the LHS is free. - if (hasTypeInIntersection(leftType) && !isLogical) - return block(leftType, constraint); - } - - // Logical expressions may proceed if the LHS is free. - if (isBlocked(leftType) || (hasTypeInIntersection(leftType) && !isLogical)) - { - asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); - return true; - } - - // Metatables go first, even if there is primitive behavior. - if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) - { - // Metatables are not the same. The metamethod will not be invoked. - if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && - getMetatable(leftType, builtinTypes) != getMetatable(rightType, builtinTypes)) - { - // TODO: Boolean singleton false? The result is _always_ boolean false. - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - } - - std::optional mm; - - // The LHS metatable takes priority over the RHS metatable, where - // present. - if (std::optional leftMm = findMetatableEntry(builtinTypes, errors, leftType, it->second, constraint->location)) - mm = leftMm; - else if (std::optional rightMm = findMetatableEntry(builtinTypes, errors, rightType, it->second, constraint->location)) - mm = rightMm; - - if (mm) - { - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, constraint->scope}; - std::optional instantiatedMm = instantiation.substitute(*mm); - if (!instantiatedMm) - { - reportError(CodeTooComplex{}, constraint->location); - return true; - } - - // TODO: Is a table with __call legal here? - // TODO: Overloads - if (const FunctionType* ftv = get(follow(*instantiatedMm))) - { - TypePackId inferredArgs; - // For >= and > we invoke __lt and __le respectively with - // swapped argument ordering. - if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) - { - inferredArgs = arena->addTypePack({rightType, leftType}); - } - else - { - inferredArgs = arena->addTypePack({leftType, rightType}); - } - - unify(inferredArgs, ftv->argTypes, constraint->scope); - - TypeId mmResult; - - // Comparison operations always evaluate to a boolean, - // regardless of what the metamethod returns. - switch (c.op) - { - case AstExprBinary::Op::CompareEq: - case AstExprBinary::Op::CompareNe: - case AstExprBinary::Op::CompareGe: - case AstExprBinary::Op::CompareGt: - case AstExprBinary::Op::CompareLe: - case AstExprBinary::Op::CompareLt: - mmResult = builtinTypes->booleanType; - break; - default: - mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); - } - - asMutable(resultType)->ty.emplace(mmResult); - unblock(resultType); - - (*c.astOriginalCallTypes)[c.astFragment] = *mm; - (*c.astOverloadResolvedTypes)[c.astFragment] = *instantiatedMm; - return true; - } - } - - // If there's no metamethod available, fall back to primitive behavior. - } - - switch (c.op) - { - // For arithmetic operators, if the LHS is a number, the RHS must be a - // number as well. The result will also be a number. - case AstExprBinary::Op::Add: - case AstExprBinary::Op::Sub: - case AstExprBinary::Op::Mul: - case AstExprBinary::Op::Div: - case AstExprBinary::Op::Pow: - case AstExprBinary::Op::Mod: - if (hasTypeInIntersection(leftType) && force) - asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); - if (isNumber(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); - return true; - } + unify(constraint, generalizedType, generalized->result); - break; - // For concatenation, if the LHS is a string, the RHS must be a string as - // well. The result will also be a string. - case AstExprBinary::Op::Concat: - if (hasTypeInIntersection(leftType) && force) - asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->stringType); - if (isString(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); - unblock(resultType); - return true; - } - - break; - // Inexact comparisons require that the types be both numbers or both - // strings, and evaluate to a boolean. - case AstExprBinary::Op::CompareGe: - case AstExprBinary::Op::CompareGt: - case AstExprBinary::Op::CompareLe: - case AstExprBinary::Op::CompareLt: - if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) - { - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - } - - break; - // == and ~= always evaluate to a boolean, and impose no other constraints - // on their parameters. - case AstExprBinary::Op::CompareEq: - case AstExprBinary::Op::CompareNe: - asMutable(resultType)->ty.emplace(builtinTypes->booleanType); - unblock(resultType); - return true; - // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is - // truthy. - case AstExprBinary::Op::And: - { - TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->falsyType, leftType}}); + for (auto [free, gen] : generalized->insertedGenerics.pairings) + unify(constraint, free, gen); - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); - unblock(resultType); - return true; + for (auto [free, gen] : generalized->insertedGenericPacks.pairings) + unify(constraint, free, gen); } - // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if - // LHS is falsey. - case AstExprBinary::Op::Or: + else { - TypeId leftFilteredTy = arena->addType(IntersectionType{{builtinTypes->truthyType, leftType}}); - - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); - unblock(resultType); - return true; - } - default: - iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); - break; + reportError(CodeTooComplex{}, constraint->location); + bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType()); } - // We failed to either evaluate a metamethod or invoke primitive behavior. - unify(leftType, errorRecoveryType(), constraint->scope); - unify(rightType, errorRecoveryType(), constraint->scope); - asMutable(resultType)->ty.emplace(errorRecoveryType()); - unblock(resultType); + for (TypeId ty : c.interiorTypes) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); return true; } @@ -815,14 +738,15 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; - std::optional anyified = anyify.substitute(c.variables); - LUAU_ASSERT(anyified); - unify(*anyified, c.variables, constraint->scope); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } - TypeId nextTy = follow(iteratorTypes[0]); + TypeId nextTy = follow(iterator.head[0]); if (get(nextTy)) - return block_(nextTy); - - if (get(nextTy)) { - TypeId tableTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 2) - tableTy = iteratorTypes[1]; + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = + arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); - TypeId firstIndexTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 3) - firstIndexTy = iteratorTypes[2]; + unify(constraint, nextTy, tableTy); - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); - } - - else - return tryDispatchIterableTable(iteratorTypes[0], c, constraint, force); + auto it = begin(c.variables); + auto endIt = end(c.variables); - return true; -} + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + { + bind(constraint, *it, valueTy); + ++it; + } -bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) + while (it != endIt) + { + bind(constraint, *it, builtinTypes->nilType); + ++it; + } + + return true; + } + + if (get(nextTy)) + { + TypeId tableTy = builtinTypes->nilType; + if (iterator.head.size() >= 2) + tableTy = iterator.head[1]; + + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); + } + + else + return tryDispatchIterableTable(iterator.head[0], c, constraint, force); + + return true; +} + +bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) { if (isBlocked(c.namedType)) return block(c.namedType, constraint); @@ -913,8 +860,6 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull scope; - Location location; - - explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver) - : solver(solver) - , scope(scope) - , location(location) - { - } - - bool visit(TypeId ty, const PendingExpansionType& petv) override - { - solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); - return false; - } -}; - bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) { const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { - unblock(c.target); + unblock(c.target, constraint->location); // TODO: do we need this? any re-entrancy? return true; } - auto bindResult = [this, &c](TypeId result) { + auto bindResult = [this, &c, constraint](TypeId result) + { LUAU_ASSERT(get(c.target)); - asMutable(c.target)->ty.emplace(result); - unblock(c.target); + shiftReferences(c.target, result); + bind(constraint, c.target, result); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1006,16 +932,41 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } - auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); + // Due to how pending expansion types and TypeFun's are created + // If this check passes, we have created a cyclic / corecursive type alias + // of size 0 + TypeId lhs = c.target; + TypeId rhs = tf->type; + if (occursCheck(lhs, rhs)) + { + reportError(OccursCheckFailed{}, constraint->location); + bindResult(errorRecoveryType()); + return true; + } - bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { - return itp == p.ty; - }); + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); - bool samePacks = - std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) { + bool sameTypes = std::equal( + typeArguments.begin(), + typeArguments.end(), + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& itp, auto&& p) + { + return itp == p.ty; + } + ); + + bool samePacks = std::equal( + packArguments.begin(), + packArguments.end(), + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& itp, auto&& p) + { return itp == p.tp; - }); + } + ); // If we're instantiating the type with its generic saturatedTypeArguments we are // performing the identity substitution. We can just short-circuit and bind @@ -1044,9 +995,9 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // In order to prevent infinite types from being expanded and causing us to // cycle infinitely, we need to scan the type function for cases where we // expand the same alias with different type saturatedTypeArguments. See - // https://github.com/Roblox/luau/pull/68 for the RFC responsible for this. - // This is a little nicer than using a recursion limit because we can catch - // the infinite expansion before actually trying to expand it. + // https://github.com/luau-lang/luau/pull/68 for the RFC responsible for + // this. This is a little nicer than using a recursion limit because we can + // catch the infinite expansion before actually trying to expand it. InfiniteTypeFinder itf{this, signature, constraint->scope}; itf.traverse(tf->type); @@ -1098,7 +1049,20 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. - bool needsClone = follow(tf->type) == target; + const TableType* tfTable = getTableType(tf->type); + + //clang-format off + bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) || + std::any_of( + typeArguments.begin(), + typeArguments.end(), + [&](const auto& other) + { + return other == target; + } + ); + //clang-format on + // Only tables have the properties we're trying to set. TableType* ttv = getMutableTableType(target); @@ -1145,12 +1109,54 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull std::optional { + if (get(fn)) + { + emplaceTypePack(asMutable(c.result), builtinTypes->anyTypePack); + unblock(c.result, constraint->location); + return true; + } + + // if we're calling an error type, the result is an error type, and that's that. + if (get(fn)) + { + bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); + return true; + } + + if (get(fn)) + { + bind(constraint, c.result, builtinTypes->neverTypePack); + return true; + } + + auto [argsHead, argsTail] = flatten(argsPack); + + bool blocked = false; + for (TypeId t : argsHead) + { + if (isBlocked(t)) + { + block(t, constraint); + blocked = true; + } + } + + if (argsTail && isBlocked(*argsTail)) + { + block(*argsTail, constraint); + blocked = true; + } + + if (blocked) + return false; + + auto collapse = [](const auto* t) -> std::optional + { auto it = begin(t); auto endIt = end(t); @@ -1175,14 +1181,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { - std::vector args{fn}; + if (isBlocked(*callMm)) + return block(*callMm, constraint); - for (TypeId arg : c.argsPack) - args.push_back(arg); + argsHead.insert(argsHead.begin(), fn); - argsPack = arena->addTypePack(TypePack{args, {}}); - fn = *callMm; - asMutable(c.result)->ty.emplace(constraint->scope); + if (argsTail && isBlocked(*argsTail)) + return block(*argsTail, constraint); + + argsPack = arena->addTypePack(TypePack{std::move(argsHead), argsTail}); + fn = follow(*callMm); + emplace(constraint, c.result, constraint->scope); } else { @@ -1192,21 +1201,28 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction) - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); if (ftv->dcrMagicRefinement) ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); } if (!usedMagic) - asMutable(c.result)->ty.emplace(constraint->scope); + emplace(constraint, c.result, constraint->scope); } for (std::optional ty : c.discriminantTypes) { - if (!ty || !isBlocked(*ty)) + if (!ty) continue; + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + // We use `any` here because the discriminant type may be pointed at by both branches, // where the discriminant type is not negated, and the other where it is negated, i.e. // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` @@ -1214,105 +1230,223 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); } + OverloadResolver resolver{ + builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location + }; + auto [status, overload] = resolver.selectOverload(fn, argsPack); + TypeId overloadToUse = fn; + if (status == OverloadResolver::Analysis::Ok) + overloadToUse = overload; + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); + Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; - std::vector overloads = flattenIntersection(fn); + const bool occursCheckPassed = u2.unify(overloadToUse, inferredTy); - Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) + { + std::optional subst = instantiate2(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions), result); + if (!subst) + { + reportError(CodeTooComplex{}, constraint->location); + result = builtinTypes->errorTypePack; + } + else + result = *subst; + + if (c.result != result) + emplaceTypePack(asMutable(c.result), result); + } - for (TypeId overload : overloads) + for (const auto& [expanded, additions] : u2.expandedFreeTypes) { - overload = follow(overload); + for (TypeId addition : additions) + upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); + } + + if (occursCheckPassed && c.callSite) + (*c.astOverloadResolvedTypes)[c.callSite] = inferredTy; - std::optional instantiated = inst.substitute(overload); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(overloadToUse); + queuer.traverse(inferredTy); - Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; - u.useScopes = true; + unblock(c.result, constraint->location); - u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); + return true; +} - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) - { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; - } +bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull constraint) +{ + TypeId fn = follow(c.fn); + const TypePackId argsPack = follow(c.argsPack); + + if (isBlocked(fn)) + return block(fn, constraint); + + if (isBlocked(argsPack)) + return true; + + // We know the type of the function and the arguments it expects to receive. + // We also know the TypeIds of the actual arguments that will be passed. + // + // Bidirectional type checking: Force those TypeIds to be the expected + // arguments. If something is incoherent, we'll spot it in type checking. + // + // Most important detail: If a function argument is a lambda, we also want + // to force unannotated argument types of that lambda to be the expected + // types. + + // FIXME: Bidirectional type checking of overloaded functions is not yet supported. + const FunctionType* ftv = get(fn); + if (!ftv) + return true; - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); + DenseHashMap replacements{nullptr}; + DenseHashMap replacementPacks{nullptr}; - if (u.errors.empty()) + for (auto generic : ftv->generics) + replacements[generic] = builtinTypes->unknownType; + + for (auto genericPack : ftv->genericPacks) + replacementPacks[genericPack] = builtinTypes->unknownTypePack; + + // If the type of the function has generics, we don't actually want to push any of the generics themselves + // into the argument types as expected types because this creates an unnecessary loop. Instead, we want to + // replace these types with `unknown` (and `...unknown`) to keep any structure but not create the cycle. + if (!replacements.empty() || !replacementPacks.empty()) + { + Replacer replacer{arena, std::move(replacements), std::move(replacementPacks)}; + + std::optional res = replacer.substitute(fn); + if (res) { - // We found a matching overload. - const auto [changedTypes, changedPacks] = u.log.getChanges(); - u.log.commit(); - unblock(changedTypes); - unblock(changedPacks); + if (*res != fn) + { + FunctionType* ftvMut = getMutable(*res); + LUAU_ASSERT(ftvMut); + ftvMut->generics.clear(); + ftvMut->genericPacks.clear(); + } - unblock(c.result); - return true; + fn = *res; + ftv = get(*res); + LUAU_ASSERT(ftv); + + // we've potentially copied type functions here, so we need to reproduce their reduce constraint. + reproduceConstraints(constraint->scope, constraint->location, replacer); } } - // We found no matching overloads. - Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; - u.useScopes = true; + const std::vector expectedArgs = flatten(ftv->argTypes).first; + const std::vector argPackHead = flatten(argsPack).first; - u.tryUnify(inferredTy, builtinTypes->anyType); - u.tryUnify(fn, builtinTypes->anyType); + // If this is a self call, the types will have more elements than the AST call. + // We don't attempt to perform bidirectional inference on the self type. + const size_t typeOffset = c.callSite->self ? 1 : 0; - LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail + for (size_t i = 0; i < c.callSite->args.size && i + typeOffset < expectedArgs.size() && i + typeOffset < argPackHead.size(); ++i) + { + const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); + const TypeId actualArgTy = follow(argPackHead[i + typeOffset]); + const AstExpr* expr = c.callSite->args.data[i]; + + (*c.astExpectedTypes)[expr] = expectedArgTy; - const auto [changedTypes, changedPacks] = u.log.getChanges(); - u.log.commit(); + const FunctionType* expectedLambdaTy = get(expectedArgTy); + const FunctionType* lambdaTy = get(actualArgTy); + const AstExprFunction* lambdaExpr = expr->as(); - unblock(changedTypes); - unblock(changedPacks); + if (expectedLambdaTy && lambdaTy && lambdaExpr) + { + const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; + const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; + + for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) + { + if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) + { + shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); + bind(constraint, lambdaArgTys[j], expectedLambdaArgTys[j]); + } + } + } + else if (expr->is() || expr->is() || expr->is() || expr->is()) + { + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + u2.unify(actualArgTy, expectedArgTy); + } + else if (expr->is()) + { + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + std::vector toBlock; + (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); + for (auto t : toBlock) + block(t, constraint); + if (!toBlock.empty()) + return false; + } + } - unblock(c.result); return true; } bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) { - TypeId expectedType = follow(c.expectedType); - if (isBlocked(expectedType) || get(expectedType)) - return block(expectedType, constraint); + std::optional expectedType = c.expectedType ? std::make_optional(follow(*c.expectedType)) : std::nullopt; + if (expectedType && (isBlocked(*expectedType) || get(*expectedType))) + return block(*expectedType, constraint); + + const FreeType* freeType = get(follow(c.freeType)); + + // if this is no longer a free type, then we're done. + if (!freeType) + return true; + + // We will wait if there are any other references to the free type mentioned here. + // This is probably the only thing that makes this not insane to do. + if (auto refCount = unresolvedConstraints.find(c.freeType); refCount && *refCount > 1) + { + block(c.freeType, constraint); + return false; + } - LUAU_ASSERT(get(c.resultType)); + TypeId bindTo = c.primitiveType; - TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; - asMutable(c.resultType)->ty.emplace(bindTo); - unblock(c.resultType); + if (freeType->upperBound != c.primitiveType && maybeSingleton(freeType->upperBound)) + bindTo = freeType->lowerBound; + else if (expectedType && maybeSingleton(*expectedType)) + bindTo = freeType->lowerBound; + + shiftReferences(c.freeType, bindTo); + bind(constraint, c.freeType, bindTo); return true; } bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); + const TypeId subjectType = follow(c.subjectType); + const TypeId resultType = follow(c.resultType); - LUAU_ASSERT(get(c.resultType)); + LUAU_ASSERT(get(resultType)); + LUAU_ASSERT(canMutate(resultType, constraint)); - if (isBlocked(subjectType) || get(subjectType)) + if (isBlocked(subjectType) || get(subjectType) || get(subjectType)) return block(subjectType, constraint); - if (get(subjectType)) + if (const TableType* subjectTable = getTableType(subjectType)) { - TableType& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); - ttv.props[c.prop] = Property{c.resultType}; - asMutable(c.resultType)->ty.emplace(constraint->scope); - unblock(c.resultType); - return true; + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } } - auto [blocked, result] = lookupTableProp(subjectType, c.prop); + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { for (TypeId blocked : blocked) @@ -1321,253 +1455,487 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->anyType)); - unblock(c.resultType); + bind(constraint, resultType, result.value_or(builtinTypes->anyType)); return true; } -static bool isUnsealedTable(TypeId ty) +bool ConstraintSolver::tryDispatchHasIndexer( + int& recursionDepth, + NotNull constraint, + TypeId subjectType, + TypeId indexType, + TypeId resultType, + Set& seen +) { - ty = follow(ty); - const TableType* ttv = get(ty); - return ttv && ttv->state == TableState::Unsealed; -} + RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit}; -/** - * Given a path into a set of nested unsealed tables `ty`, insert a new property `replaceTy` as the leaf-most property. - * - * Fails and does nothing if every table along the way is not unsealed. - * - * Mutates the innermost table type in-place. - */ -static void updateTheTableType( - NotNull builtinTypes, NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) -{ - if (path.empty()) - return; + subjectType = follow(subjectType); + indexType = follow(indexType); - // First walk the path and ensure that it's unsealed tables all the way - // to the end. + if (seen.contains(subjectType)) + return false; + seen.insert(subjectType); + + LUAU_ASSERT(get(resultType)); + LUAU_ASSERT(canMutate(resultType, constraint)); + + if (get(subjectType)) { - TypeId t = ty; - for (size_t i = 0; i < path.size() - 1; ++i) + bind(constraint, resultType, builtinTypes->anyType); + return true; + } + + if (auto ft = get(subjectType)) + { + if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) { - if (!isUnsealedTable(t)) - return; + unify(constraint, indexType, tbl->indexer->indexType); + bind(constraint, resultType, tbl->indexer->indexResultType); + return true; + } + else if (auto mt = get(follow(ft->upperBound))) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); - const TableType* tbl = get(t); - auto it = tbl->props.find(path[i]); - if (it == tbl->props.end()) - return; + FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; + emplace(constraint, resultType, freeResult); - t = follow(it->second.type()); - } + TypeId upperBound = + arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, ft->scope, TableState::Unsealed}); - // The last path segment should not be a property of the table at all. - // We are not changing property types. We are only admitting this one - // new property to be appended. - if (!isUnsealedTable(t)) - return; - const TableType* tbl = get(t); - if (0 != tbl->props.count(path.back())) - return; + unify(constraint, subjectType, upperBound); + + return true; } + else if (auto tt = getMutable(subjectType)) + { + if (auto indexer = tt->indexer) + { + unify(constraint, indexType, indexer->indexType); + bind(constraint, resultType, indexer->indexResultType); + return true; + } + + if (tt->state == TableState::Unsealed) + { + // FIXME this is greedy. - TypeId t = ty; - ErrorVec dummy; + FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; + emplace(constraint, resultType, freeResult); - for (size_t i = 0; i < path.size() - 1; ++i) + tt->indexer = TableIndexer{indexType, resultType}; + return true; + } + } + else if (auto mt = get(subjectType)) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); + else if (auto ct = get(subjectType)) + { + if (auto indexer = ct->indexer) + { + unify(constraint, indexType, indexer->indexType); + bind(constraint, resultType, indexer->indexResultType); + return true; + } + else if (isString(indexType)) + { + bind(constraint, resultType, builtinTypes->unknownType); + return true; + } + } + else if (auto it = get(subjectType)) { - auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], Location{}); - dummy.clear(); + // subjectType <: {[indexType]: resultType} + // + // 'a & ~(false | nil) <: {[indexType]: resultType} + // + // 'a <: {[indexType]: resultType} + // ~(false | nil) <: {[indexType]: resultType} + + Set parts{nullptr}; + for (TypeId part : it) + parts.insert(follow(part)); + + Set results{nullptr}; + + for (TypeId part : parts) + { + TypeId r = arena->addType(BlockedType{}); + getMutable(r)->setOwner(const_cast(constraint.get())); - if (!propTy) - return; + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; + + r = follow(r); + if (!get(r)) + results.insert(r); + } + + if (0 == results.size()) + bind(constraint, resultType, builtinTypes->errorType); + else if (1 == results.size()) + bind(constraint, resultType, *results.begin()); + else + emplace(constraint, resultType, std::vector(results.begin(), results.end())); - t = *propTy; + return true; } + else if (auto ut = get(subjectType)) + { + Set parts{nullptr}; + for (TypeId part : ut) + parts.insert(follow(part)); + + Set results{nullptr}; - const std::string& lastSegment = path.back(); + for (TypeId part : parts) + { + TypeId r = arena->addType(BlockedType{}); + getMutable(r)->setOwner(const_cast(constraint.get())); - t = follow(t); - TableType* tt = getMutable(t); - if (auto mt = get(t)) - tt = getMutable(mt->table); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; - if (!tt) - return; + r = follow(r); + if (!get(r)) + results.insert(r); + } + + if (0 == results.size()) + bind(constraint, resultType, builtinTypes->errorType); + else if (1 == results.size()) + { + TypeId firstResult = *results.begin(); + shiftReferences(resultType, firstResult); + bind(constraint, resultType, firstResult); + } + else + emplace(constraint, resultType, std::vector(results.begin(), results.end())); - tt->props[lastSegment].setType(replaceTy); + return true; + } + + bind(constraint, resultType, builtinTypes->errorType); + + return true; } -bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) +namespace { - TypeId subjectType = follow(c.subjectType); + +struct BlockedTypeFinder : TypeOnceVisitor +{ + std::optional blocked; + + bool visit(TypeId ty) override + { + // If we've already found one, stop traversing. + return !blocked.has_value(); + } + + bool visit(TypeId ty, const BlockedType&) override + { + blocked = ty; + return false; + } +}; + +} // namespace + +bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull constraint) +{ + const TypeId subjectType = follow(c.subjectType); + const TypeId indexType = follow(c.indexType); if (isBlocked(subjectType)) return block(subjectType, constraint); - if (!force && get(subjectType)) - return block(subjectType, constraint); + if (isBlocked(indexType)) + return block(indexType, constraint); - std::optional existingPropType = subjectType; - for (const std::string& segment : c.path) - { - if (!existingPropType) - break; + BlockedTypeFinder btf; - auto [blocked, result] = lookupTableProp(*existingPropType, segment); - if (!blocked.empty()) - { - for (TypeId blocked : blocked) - block(blocked, constraint); - return false; - } + btf.visit(subjectType); - existingPropType = result; - } + if (btf.blocked) + return block(*btf.blocked, constraint); + int recursionDepth = 0; - auto bind = [](TypeId a, TypeId b) { - asMutable(a)->ty.emplace(b); - }; + Set seen{nullptr}; + + return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); +} + +bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) +{ + TypeId lhsType = follow(c.lhsType); + const std::string& propName = c.propName; + const TypeId rhsType = follow(c.rhsType); + + if (isBlocked(lhsType)) + return block(lhsType, constraint); + + // 1. lhsType is a class that already has the prop + // 2. lhsType is a table that already has the prop (or a union or + // intersection that has the prop in aggregate) + // 3. lhsType has a metatable that already has the prop + // 4. lhsType is an unsealed table that does not have the prop, but has a + // string indexer + // 5. lhsType is an unsealed table that does not have the prop or a string + // indexer - if (existingPropType) + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + if (auto lhsClass = get(lhsType)) { - if (!isBlocked(c.propType)) - unify(c.propType, *existingPropType, constraint->scope); - bind(c.resultType, c.subjectType); - unblock(c.resultType); + const Property* prop = lookupClassProp(lhsClass, propName); + if (!prop || !prop->writeTy.has_value()) + return true; + + bind(constraint, c.propType, *prop->writeTy); + unify(constraint, rhsType, *prop->writeTy); return true; } - if (auto mt = get(subjectType)) - subjectType = follow(mt->table); - - if (get(subjectType)) + if (auto lhsFree = getMutable(lhsType)) { - TypeId ty = arena->freshType(constraint->scope); - - // Mint a chain of free tables per c.path - for (auto it = rbegin(c.path); it != rend(c.path); ++it) + if (get(lhsFree->upperBound) || get(lhsFree->upperBound)) + lhsType = lhsFree->upperBound; + else { - TableType t{TableState::Free, TypeLevel{}, constraint->scope}; - t.props[*it] = {ty}; + TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + TableType* upperTable = getMutable(newUpperBound); + LUAU_ASSERT(upperTable); - ty = arena->addType(std::move(t)); + upperTable->props[c.propName] = rhsType; + + // Food for thought: Could we block if simplification encounters a blocked type? + lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; + + bind(constraint, c.propType, rhsType); + return true; } + } - LUAU_ASSERT(ty); + // Handle the case that lhsType is a table that already has the property or + // a matching indexer. This also handles unions and intersections. + const auto [blocked, maybeTy] = lookupTableProp(constraint, lhsType, propName, ValueContext::LValue); + if (!blocked.empty()) + { + for (TypeId t : blocked) + block(t, constraint); + return false; + } - bind(subjectType, ty); - if (follow(c.resultType) != follow(ty)) - bind(c.resultType, ty); - unblock(subjectType); - unblock(c.resultType); + if (maybeTy) + { + const TypeId propTy = *maybeTy; + bind(constraint, c.propType, propTy); + unify(constraint, rhsType, propTy); return true; } - else if (auto ttv = getMutable(subjectType)) + + if (auto lhsMeta = get(lhsType)) + lhsType = follow(lhsMeta->table); + + // Handle the case where the lhs type is a table that does not have the + // named property. It could be a table with a string indexer, or an unsealed + // or free table that can grow. + if (auto lhsTable = getMutable(lhsType)) { - if (ttv->state == TableState::Free) + if (auto it = lhsTable->props.find(propName); it != lhsTable->props.end()) { - LUAU_ASSERT(!subjectType->persistent); + Property& prop = it->second; - ttv->props[c.path[0]] = Property{c.propType}; - bind(c.resultType, c.subjectType); - unblock(c.resultType); - return true; + if (prop.writeTy.has_value()) + { + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + LUAU_ASSERT(prop.isReadOnly()); + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + prop.writeTy = prop.readTy; + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + bind(constraint, c.propType, builtinTypes->errorType); + return true; + } + } } - else if (ttv->state == TableState::Unsealed) - { - LUAU_ASSERT(!subjectType->persistent); - updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, c.propType); - bind(c.resultType, c.subjectType); - unblock(subjectType); - unblock(c.resultType); + if (lhsTable->indexer && maybeString(lhsTable->indexer->indexType)) + { + bind(constraint, c.propType, rhsType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); return true; } - else + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - bind(c.resultType, subjectType); - unblock(c.resultType); + bind(constraint, c.propType, rhsType); + Property& newProp = lhsTable->props[propName]; + newProp.readTy = rhsType; + newProp.writeTy = rhsType; + newProp.location = c.propLocation; + + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + } + return true; } } - else - { - // Other kinds of types don't change shape when properties are assigned - // to them. (if they allow properties at all!) - bind(c.resultType, subjectType); - unblock(c.resultType); - return true; - } + + bind(constraint, c.propType, builtinTypes->errorType); + + return true; } -bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); - if (isBlocked(subjectType)) - return block(subjectType, constraint); + const TypeId lhsType = follow(c.lhsType); + const TypeId indexType = follow(c.indexType); + const TypeId rhsType = follow(c.rhsType); - if (auto ft = get(subjectType)) - { - Scope* scope = ft->scope; - TableType* tt = &asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, scope); - tt->indexer = TableIndexer{c.indexType, c.propType}; + if (isBlocked(lhsType)) + return block(lhsType, constraint); - asMutable(c.resultType)->ty.emplace(subjectType); - asMutable(c.propType)->ty.emplace(scope); - unblock(c.propType); - unblock(c.resultType); + // 0. lhsType could be an intersection or union. + // 1. lhsType is a class with an indexer + // 2. lhsType is a table with an indexer, or it has a metatable that has an indexer + // 3. lhsType is a free or unsealed table and can grow an indexer - return true; - } - else if (auto tt = get(subjectType)) + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + auto tableStuff = [&](TableType* lhsTable) -> std::optional { - if (tt->indexer) + if (lhsTable->indexer) { - // TODO This probably has to be invariant. - unify(c.indexType, tt->indexer->indexType, constraint->scope); - asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); - asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + unify(constraint, indexType, lhsTable->indexer->indexType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + bind(constraint, c.propType, lhsTable->indexer->indexResultType); return true; } - else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - auto mtt = getMutable(subjectType); - mtt->indexer = TableIndexer{c.indexType, c.propType}; - asMutable(c.propType)->ty.emplace(tt->scope); - asMutable(c.resultType)->ty.emplace(subjectType); - unblock(c.propType); - unblock(c.resultType); + lhsTable->indexer = TableIndexer{indexType, rhsType}; + bind(constraint, c.propType, rhsType); return true; } - // Do not augment sealed or generic tables that lack indexers + + return {}; + }; + + if (auto lhsFree = getMutable(lhsType)) + { + if (auto lhsTable = getMutable(lhsFree->upperBound)) + { + if (auto res = tableStuff(lhsTable)) + return *res; + } + + TypeId newUpperBound = + arena->addType(TableType{/*props*/ {}, TableIndexer{indexType, rhsType}, TypeLevel{}, constraint->scope, TableState::Free}); + const TableType* newTable = get(newUpperBound); + LUAU_ASSERT(newTable); + + unify(constraint, lhsType, newUpperBound); + + LUAU_ASSERT(newTable->indexer); + bind(constraint, c.propType, newTable->indexer->indexResultType); + return true; } - asMutable(c.propType)->ty.emplace(builtinTypes->errorRecoveryType()); - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(c.propType); - unblock(c.resultType); - return true; -} + if (auto lhsTable = getMutable(lhsType)) + { + std::optional res = tableStuff(lhsTable); + if (res.has_value()) + return *res; + } -bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) -{ - if (isBlocked(c.discriminantType)) - return false; + if (auto lhsClass = get(lhsType)) + { + while (true) + { + if (lhsClass->indexer) + { + unify(constraint, indexType, lhsClass->indexer->indexType); + unify(constraint, rhsType, lhsClass->indexer->indexResultType); + bind(constraint, c.propType, lhsClass->indexer->indexResultType); + return true; + } - TypeId followed = follow(c.discriminantType); + if (lhsClass->parent) + lhsClass = get(lhsClass->parent); + else + break; + } + return true; + } - // `nil` is a singleton type too! There's only one value of type `nil`. - if (c.negated && (get(followed) || isNil(followed))) - *asMutable(c.resultType) = NegationType{c.discriminantType}; - else if (!c.negated && get(followed)) - *asMutable(c.resultType) = BoundType{c.discriminantType}; - else - *asMutable(c.resultType) = BoundType{builtinTypes->unknownType}; + if (auto lhsIntersection = getMutable(lhsType)) + { + std::set parts; + + for (TypeId t : lhsIntersection) + { + if (auto tbl = getMutable(follow(t))) + { + if (tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + parts.insert(tbl->indexer->indexResultType); + } + + if (tbl->state == TableState::Unsealed || tbl->state == TableState::Free) + { + tbl->indexer = TableIndexer{indexType, rhsType}; + parts.insert(rhsType); + } + } + else if (auto cls = get(follow(t))) + { + while (true) + { + if (cls->indexer) + { + unify(constraint, indexType, cls->indexer->indexType); + parts.insert(cls->indexer->indexResultType); + break; + } + + if (cls->parent) + cls = get(cls->parent); + else + break; + } + } + } + + TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; + + unify(constraint, rhsType, res); + } + + // Other types do not support index assignment. + bind(constraint, c.propType, builtinTypes->errorType); return true; } @@ -1575,114 +1943,209 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) return block(sourcePack, constraint); - if (isBlocked(resultPack)) - { - asMutable(resultPack)->ty.emplace(sourcePack); - unblock(resultPack); - return true; - } - - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size()); - auto destIter = begin(resultPack); - auto destEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; - while (destIter != destEnd) + while (resultIter != resultEnd) { if (i >= srcPack.head.size()) break; + TypeId srcTy = follow(srcPack.head[i]); + TypeId resultTy = follow(*resultIter); + + LUAU_ASSERT(get(resultTy)); + LUAU_ASSERT(canMutate(resultTy, constraint)); - if (isBlocked(*destIter)) + if (get(resultTy)) { - if (follow(srcTy) == *destIter) + if (follow(srcTy) == resultTy) { - // Cyclic type dependency. (????) - asMutable(*destIter)->ty.emplace(constraint->scope); + // It is sometimes the case that we find that a blocked type + // is only blocked on itself. This doesn't actually + // constitute any meaningful constraint, so we replace it + // with a free type. + TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); + emplaceType(asMutable(resultTy), f); } else - asMutable(*destIter)->ty.emplace(srcTy); - unblock(*destIter); + bind(constraint, resultTy, srcTy); } else - unify(*destIter, srcTy, constraint->scope); + unify(constraint, srcTy, resultTy); - ++destIter; + unblock(resultTy, constraint->location); + + ++resultIter; ++i; } // We know that resultPack does not have a tail, but we don't know if // sourcePack is long enough to fill every value. Replace every remaining - // result TypeId with the error recovery type. + // result TypeId with `nil`. - while (destIter != destEnd) + while (resultIter != resultEnd) { - if (isBlocked(*destIter)) + TypeId resultTy = follow(*resultIter); + LUAU_ASSERT(canMutate(resultTy, constraint)); + if (get(resultTy) || get(resultTy)) { - asMutable(*destIter)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(*destIter); + bind(constraint, resultTy, builtinTypes->nilType); } - ++destIter; + ++resultIter; } return true; } -bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { - auto block_ = [&](auto&& t) { - if (force) + TypeId ty = follow(c.ty); + FunctionGraphReductionResult result = + reduceTypeFunctions(ty, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, constraint}, force); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + + bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); + + ty = follow(ty); + // If we couldn't reduce this type function, stick it in the set! + if (get(ty)) + typeFunctionsToFinalize[ty] = constraint; + + if (force || reductionFinished) + { + // if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock. + for (auto error : result.errors) { - // TODO: I believe it is the case that, if we are asked to force - // this constraint, then we can do nothing but fail. I'd like to - // find a code sample that gets here. - LUAU_ASSERT(false); + if (auto utf = get(error)) + uninhabitedTypeFunctions.insert(utf->ty); + else if (auto utpf = get(error)) + uninhabitedTypeFunctions.insert(utpf->tp); } - else - block(t, constraint); - return false; - }; + } + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return reductionFinished; +} + +bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force) +{ + TypePackId tp = follow(c.tp); + FunctionGraphReductionResult result = + reduceTypeFunctions(tp, constraint->location, TypeFunctionContext{NotNull{this}, constraint->scope, constraint}, force); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + + bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); - // We may have to block here if we don't know what the iteratee type is, - // if it's a free table, if we don't know it has a metatable, and so on. + if (force || reductionFinished) + { + // if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock. + for (auto error : result.errors) + { + if (auto utf = get(error)) + uninhabitedTypeFunctions.insert(utf->ty); + else if (auto utpf = get(error)) + uninhabitedTypeFunctions.insert(utpf->tp); + } + } + + if (force) + return true; + + for (TypeId b : result.blockedTypes) + block(b, constraint); + + for (TypePackId b : result.blockedPacks) + block(b, constraint); + + return reductionFinished; +} + +bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force) +{ + unify(constraint, c.resultType, c.assignmentType); + unify(constraint, c.assignmentType, c.resultType); + return true; +} + +bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +{ iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) - return block_(iteratorTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); + getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; - auto anyify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; - std::optional anyified = anyify.substitute(ty); - if (!anyified) - reportError(CodeTooComplex{}, constraint->location); - else - unify(*anyified, ty, constraint->scope); - }; + pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{iteratorTy, tableTy}); - auto errorify = [&](auto ty) { - Anyification anyify{arena, constraint->scope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; - std::optional errorified = anyify.substitute(ty); - if (!errorified) - reportError(CodeTooComplex{}, constraint->location); - else - unify(*errorified, ty, constraint->scope); + auto it = begin(c.variables); + auto endIt = end(c.variables); + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + bind(constraint, *it, valueTy); + + return true; + } + + auto unpack = [&](TypeId ty) + { + for (TypeId varTy : c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bind(constraint, varTy, ty); + } }; if (get(iteratorTy)) { - anyify(c.variables); + unpack(builtinTypes->anyType); return true; } if (get(iteratorTy)) { - errorify(c.variables); + unpack(builtinTypes->errorType); + return true; + } + + if (get(iteratorTy)) + { + unpack(builtinTypes->neverType); return true; } @@ -1696,18 +2159,28 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl * it's possible that there are other constraints on the table that will * clarify what we should do. * - * We should eventually introduce a type family to talk about iteration. + * We should eventually introduce a type function to talk about iteration. */ if (iteratorTable->state == TableState::Free && !force) return block(iteratorTy, constraint); if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(c.variables, expectedVariablePack, constraint->scope); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); + + for (size_t i = 0; i < c.variables.size(); ++i) + { + LUAU_ASSERT(c.variables[i] != expectedVariables[i]); + + unify(constraint, c.variables[i], expectedVariables[i]); + + bind(constraint, c.variables[i], expectedVariables[i]); + } } else - errorify(c.variables); + unpack(builtinTypes->errorType); } else if (std::optional iterFn = findMetatableEntry(builtinTypes, errors, iteratorTy, "__iter", Location{})) { @@ -1716,14 +2189,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return block(*iterFn, constraint); } - Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - - if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) + if (std::optional instantiatedIterFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, *iterFn)) { if (auto iterFtv = get(*instantiatedIterFn)) { TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); - unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); + unify(constraint, iterFtv->argTypes, expectedIterArgs); TypePack iterRets = extendTypePack(*arena, builtinTypes, iterFtv->retTypes, 2); @@ -1735,21 +2206,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } TypeId nextFn = iterRets.head[0]; - TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : arena->freshType(constraint->scope); - if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { - const TypeId firstIndex = arena->freshType(constraint->scope); + const FunctionType* nextFn = get(*instantiatedNextFn); - // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + // If nextFn is nullptr, then the iterator function has an improper signature. + if (nextFn) + unpackAndAssign(c.variables, nextFn->retTypes, constraint); - const TypeId expectedNextTy = arena->addType(FunctionType{nextArgPack, nextRetPack}); - unify(*instantiatedNextFn, expectedNextTy, constraint->scope); - - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + return true; } else { @@ -1768,53 +2234,34 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } else if (auto iteratorMetatable = get(iteratorTy)) { - TypeId metaTy = follow(iteratorMetatable->metatable); - if (get(metaTy)) - return block_(metaTy); - - LUAU_ASSERT(false); + // If the metatable does not contain a `__iter` metamethod, then we iterate over the table part of the metatable. + return tryDispatchIterableTable(iteratorMetatable->table, c, constraint, force); } + else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) + unpack(builtinTypes->unknownType); else - errorify(c.variables); + { + unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, + TypeId tableTy, + const IterableConstraint& c, + NotNull constraint, + bool force +) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } - - TypeId firstIndex; - TypeId retIndex; - if (isNil(firstIndexTy) || isOptional(firstIndexTy)) - { - firstIndex = arena->addType(UnionType{{arena->freshType(constraint->scope), builtinTypes->nilType}}); - retIndex = firstIndex; - } - else - { - firstIndex = firstIndexTy; - retIndex = arena->addType(UnionType{{firstIndexTy, builtinTypes->nilType}}); - } - - // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, firstIndex}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); + const FunctionType* nextFn = get(nextTy); + // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. + LUAU_ASSERT(nextFn); + const TypePackId nextRetPack = nextFn->retTypes; - const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); - unify(nextTy, expectedNextTy, constraint->scope); + // the type of the `nextAstFragment` is the `nextTy`. + (*c.astForInNextTypes)[c.nextAstFragment] = nextTy; auto it = begin(nextRetPack); std::vector modifiedNextRetHead; @@ -1834,22 +2281,58 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - auto psc = pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); - inheritBlocks(constraint, psc); + + auto unpackConstraint = unpackAndAssign(c.variables, modifiedNextRetPack, constraint); + + inheritBlocks(constraint, unpackConstraint); return true; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +NotNull ConstraintSolver::unpackAndAssign( + const std::vector destTypes, + TypePackId srcTypes, + NotNull constraint +) { - std::unordered_set seen; - return lookupTableProp(subjectType, propName, seen); + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); + + for (TypeId t : destTypes) + { + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); + } + + return c; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) -{ - if (!seen.insert(subjectType).second) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification +) +{ + DenseHashSet seen{nullptr}; + return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen); +} + +std::pair, std::optional> ConstraintSolver::lookupTableProp( + NotNull constraint, + TypeId subjectType, + const std::string& propName, + ValueContext context, + bool inConditional, + bool suppressSimplification, + DenseHashSet& seen +) +{ + if (seen.contains(subjectType)) return {}; + seen.insert(subjectType); subjectType = follow(subjectType); @@ -1862,19 +2345,64 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - return {{}, prop->second.type()}; - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + { + switch (context) + { + case ValueContext::RValue: + if (auto rt = prop->second.readTy) + return {{}, rt}; + break; + case ValueContext::LValue: + if (auto wt = prop->second.writeTy) + return {{}, wt}; + break; + } + } + + if (ttv->indexer && maybeString(ttv->indexer->indexType)) return {{}, ttv->indexer->indexResultType}; - else if (ttv->state == TableState::Free) + + if (ttv->state == TableState::Free) { - TypeId result = arena->freshType(ttv->scope); - ttv->props[propName] = Property{result}; + TypeId result = freshType(arena, builtinTypes, ttv->scope); + switch (context) + { + case ValueContext::RValue: + ttv->props[propName].readTy = result; + break; + case ValueContext::LValue: + if (auto it = ttv->props.find(propName); it != ttv->props.end() && it->second.isReadOnly()) + { + // We do infer read-only properties, but we do not infer + // separate read and write types. + // + // If we encounter a case where a free table has a read-only + // property that we subsequently sense a write to, we make + // the judgement that the property is read-write and that + // both the read and write types are the same. + + Property& prop = it->second; + + prop.writeTy = prop.readTy; + return {{}, *prop.readTy}; + } + else + ttv->props[propName] = Property::rw(result); + + break; + } return {{}, result}; } + + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ a table to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + if (inConditional) + return {{}, builtinTypes->unknownType}; } - else if (auto mt = get(subjectType)) + else if (auto mt = get(subjectType); mt && context == ValueContext::RValue) { - auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + auto [blocked, result] = lookupTableProp(constraint, mt->table, propName, context, inConditional, suppressSimplification, seen); if (!blocked.empty() || result) return {blocked, result}; @@ -1905,13 +2433,19 @@ std::pair, std::optional> ConstraintSolver::lookupTa } } else - return lookupTableProp(indexType, propName, seen); + return lookupTableProp(constraint, indexType, propName, context, inConditional, suppressSimplification, seen); } + else if (get(mtt)) + return lookupTableProp(constraint, mtt, propName, context, inConditional, suppressSimplification, seen); } else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) - return {{}, p->type()}; + return {{}, context == ValueContext::RValue ? p->readTy : p->writeTy}; + if (ct->indexer) + { + return {{}, ct->indexer->indexResultType}; + } } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1922,17 +2456,35 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type(), propName, seen); + return lookupTableProp(constraint, indexProp->second.type(), propName, context, inConditional, suppressSimplification, seen); } else if (auto ft = get(subjectType)) { - Scope* scope = ft->scope; + const TypeId upperBound = follow(ft->upperBound); + + if (get(upperBound) || get(upperBound)) + return lookupTableProp(constraint, upperBound, propName, context, inConditional, suppressSimplification, seen); + + // TODO: The upper bound could be an intersection that contains suitable tables or classes. + + NotNull scope{ft->scope}; + + const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope}); + TableType* tt = getMutable(newUpperBound); + LUAU_ASSERT(tt); + TypeId propType = freshType(arena, builtinTypes, scope); - TableType* tt = &asMutable(subjectType)->ty.emplace(); - tt->state = TableState::Free; - tt->scope = scope; - TypeId propType = arena->freshType(scope); - tt->props[propName] = Property{propType}; + switch (context) + { + case ValueContext::RValue: + tt->props[propName] = Property::readonly(propType); + break; + case ValueContext::LValue: + tt->props[propName] = Property::rw(propType); + break; + } + + unify(constraint, subjectType, newUpperBound); return {{}, propType}; } @@ -1943,7 +2495,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : utv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -1956,6 +2508,20 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + + // if we're in an lvalue context, we need the _common_ type here. + if (context == ValueContext::LValue) + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + + return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + } + // if we're in an lvalue context, we need the _common_ type here. + else if (context == ValueContext::LValue) + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; else return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } @@ -1966,7 +2532,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : itv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -1979,100 +2545,115 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; } + else if (auto pt = get(subjectType)) + { + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ a table to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + if (inConditional && pt->type == PrimitiveType::Table) + return {{}, builtinTypes->unknownType}; + } return {{}, std::nullopt}; } -static TypeId getErrorType(NotNull builtinTypes, TypeId) -{ - return builtinTypes->errorRecoveryType(); -} - -static TypePackId getErrorType(NotNull builtinTypes, TypePackId) -{ - return builtinTypes->errorRecoveryTypePack(); -} - -template -bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) +template +bool ConstraintSolver::unify(NotNull constraint, TID subTy, TID superTy) { - Unifier u{normalizer, Mode::Strict, constraint->scope, constraint->location, Covariant}; - u.useScopes = true; + Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}, &uninhabitedTypeFunctions}; - u.tryUnify(subTy, superTy); + const bool ok = u2.unify(subTy, superTy); - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + for (ConstraintV& c : u2.incompleteSubtypes) { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; + NotNull addition = pushConstraint(constraint->scope, constraint->location, std::move(c)); + inheritBlocks(constraint, addition); } - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); - - if (!u.errors.empty()) + if (ok) { - TID errorType = getErrorType(builtinTypes, TID{}); - u.tryUnify(subTy, errorType); - u.tryUnify(superTy, errorType); + for (const auto& [expanded, additions] : u2.expandedFreeTypes) + { + for (TypeId addition : additions) + upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); + } + } + else + { + reportError(OccursCheckFailed{}, constraint->location); + return false; } - - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); - - unblock(changedTypes); - unblock(changedPacks); return true; } -void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { - blocked[target].push_back(constraint); + // If a set is not present for the target, construct a new DenseHashSet for it, + // else grab the address of the existing set. + auto [iter, inserted] = blocked.try_emplace(target, nullptr); + auto& [key, blockVec] = *iter; + + if (blockVec.find(constraint)) + return false; + + blockVec.insert(constraint); - auto& count = blockedConstraints[constraint]; + size_t& count = blockedConstraints[constraint]; count += 1; + + return true; } void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); - - if (FFlag::DebugLuauLogSolver) - printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + const bool newBlock = block_(target.get(), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - block_(target.get(), constraint); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on constraint %s\n", toString(*constraint, opts).c_str(), toString(*target, opts).c_str()); + } } bool ConstraintSolver::block(TypeId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(follow(target), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on TypeId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str()); + } - block_(follow(target), constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target, constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("%s depends on TypePackId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str()); + } - block_(target, constraint); return false; } @@ -2083,9 +2664,9 @@ void ConstraintSolver::inheritBlocks(NotNull source, NotNullsecond) + for (const Constraint* blockedConstraint : blockedIt->second) { - block(addition, blockedConstraint); + block(addition, NotNull{blockedConstraint}); } } } @@ -2103,29 +2684,27 @@ struct Blocker : TypeOnceVisitor { } - bool visit(TypeId ty, const BlockedType&) + bool visit(TypeId ty, const PendingExpansionType&) override { blocked = true; solver->block(ty, constraint); return false; } - bool visit(TypeId ty, const PendingExpansionType&) + bool visit(TypeId ty, const ClassType&) override { - blocked = true; - solver->block(ty, constraint); return false; } }; -bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(target); return !blocker.blocked; } -bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(pack); @@ -2139,9 +2718,9 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (NotNull unblockedConstraint : it->second) + for (const Constraint* unblockedConstraint : it->second) { - auto& count = blockedConstraints[unblockedConstraint]; + auto& count = blockedConstraints[NotNull{unblockedConstraint}]; if (FFlag::DebugLuauLogSolver) printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); @@ -2164,18 +2743,30 @@ void ConstraintSolver::unblock(NotNull progressed) return unblock_(progressed.get()); } -void ConstraintSolver::unblock(TypeId progressed) +void ConstraintSolver::unblock(TypeId ty, Location location) { - if (logger) - logger->popBlock(progressed); + DenseHashSet seen{nullptr}; + + TypeId progressed = ty; + while (true) + { + if (seen.find(progressed)) + iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!", location); + seen.insert(progressed); - unblock_(progressed); + if (logger) + logger->popBlock(progressed); - if (auto bt = get(progressed)) - unblock(bt->boundTo); + unblock_(progressed); + + if (auto bt = get(progressed)) + progressed = bt->boundTo; + else + break; + } } -void ConstraintSolver::unblock(TypePackId progressed) +void ConstraintSolver::unblock(TypePackId progressed, Location) { if (logger) logger->popBlock(progressed); @@ -2183,70 +2774,57 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -void ConstraintSolver::unblock(const std::vector& types) +void ConstraintSolver::unblock(const std::vector& types, Location location) { for (TypeId t : types) - unblock(t); + unblock(t, location); } -void ConstraintSolver::unblock(const std::vector& packs) +void ConstraintSolver::unblock(const std::vector& packs, Location location) { for (TypePackId t : packs) - unblock(t); + unblock(t, location); } -bool ConstraintSolver::isBlocked(TypeId ty) -{ - return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); -} - -bool ConstraintSolver::isBlocked(TypePackId tp) -{ - return nullptr != get(follow(tp)); -} - -bool ConstraintSolver::isBlocked(NotNull constraint) -{ - auto blockedIt = blockedConstraints.find(constraint); - return blockedIt != blockedConstraints.end() && blockedIt->second > 0; -} - -void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) +void ConstraintSolver::reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(subType, superType); + for (auto [_, newTy] : subst.newTypes) + { + if (get(newTy)) + pushConstraint(scope, location, ReduceConstraint{newTy}); + } - if (!u.errors.empty()) + for (auto [_, newPack] : subst.newPacks) { - TypeId errorType = errorRecoveryType(); - u.tryUnify(subType, errorType); - u.tryUnify(superType, errorType); + if (get(newPack)) + pushConstraint(scope, location, ReducePackConstraint{newPack}); } +} - const auto [changedTypes, changedPacks] = u.log.getChanges(); +bool ConstraintSolver::isBlocked(TypeId ty) +{ + ty = follow(ty); - u.log.commit(); + if (auto tfit = get(ty)) + return uninhabitedTypeFunctions.contains(ty) == false; - unblock(changedTypes); - unblock(changedPacks); + return nullptr != get(ty) || nullptr != get(ty); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +bool ConstraintSolver::isBlocked(TypePackId tp) { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + tp = follow(tp); - u.tryUnify(subPack, superPack); + if (auto tfitp = get(tp)) + return uninhabitedTypeFunctions.contains(tp) == false; - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); + return nullptr != get(tp); +} - unblock(changedTypes); - unblock(changedPacks); +bool ConstraintSolver::isBlocked(NotNull constraint) +{ + auto blockedIt = blockedConstraints.find(constraint); + return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) @@ -2269,7 +2847,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? info.name : moduleResolver->getHumanReadableModuleName(info.name))) + if (!path.empty() && path.front() == info.name) return builtinTypes->anyType; } @@ -2314,6 +2892,52 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +void ConstraintSolver::shiftReferences(TypeId source, TypeId target) +{ + target = follow(target); + + // if the target isn't a reference counted type, there's nothing to do. + // this stops us from keeping unnecessary counts for e.g. primitive types. + if (!isReferenceCountedType(target)) + return; + + auto sourceRefs = unresolvedConstraints.find(source); + if (!sourceRefs) + return; + + // we read out the count before proceeding to avoid hash invalidation issues. + size_t count = *sourceRefs; + + auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); + targetRefs += count; +} + +std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables) +{ + TypeId t = follow(type); + if (get(t)) + { + auto refCount = unresolvedConstraints.find(t); + if (refCount && *refCount > 0) + return {}; + + // if no reference count is present, then that means the only constraints referring to + // this free type need only for it to be generalized. in principle, this means we could + // have actually never generated the free type in the first place, but we couldn't know + // that until all constraint generation is complete. + } + + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type, avoidSealingTables); +} + +bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) +{ + if (auto refCount = unresolvedConstraints.find(ty)) + return *refCount > 0; + + return false; +} + TypeId ConstraintSolver::errorRecoveryType() const { return builtinTypes->errorRecoveryType(); @@ -2324,39 +2948,48 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return builtinTypes->errorRecoveryTypePack(); } -TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) { - a = follow(a); - b = follow(b); + tp = follow(tp); - if (unifyFreeTypes && (get(a) || get(b))) + if (const VariadicTypePack* vtp = get(tp)) { - Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; - u.tryUnify(b, a); + TypeId ty = follow(vtp->ty); + return get(ty) ? builtinTypes->anyTypePack : tp; + } - if (u.errors.empty()) - { - u.log.commit(); - return a; - } - else - { - return builtinTypes->errorRecoveryType(builtinTypes->anyType); - } + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? builtinTypes->anyType : ty); } - if (*a == *b) - return a; + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); - std::vector types = reduceUnion({a, b}); - if (types.empty()) - return builtinTypes->neverType; + return arena->addTypePack(resultTypes, resultTail); +} - if (types.size() == 1) - return types[0]; +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() +{ + throw TimeLimitError(currentModuleName); +} - return arena->addType(UnionType{types}); +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() +{ + throw UserCancelError(currentModuleName); } +// Instantiate private template implementations for external callers +template bool ConstraintSolver::unify(NotNull constraint, TypeId subTy, TypeId superTy); +template bool ConstraintSolver::unify(NotNull constraint, TypePackId subTy, TypePackId superTy); + } // namespace Luau diff --git a/third_party/luau/Analysis/src/DataFlowGraph.cpp b/third_party/luau/Analysis/src/DataFlowGraph.cpp index e73c7e8c..57e45c3e 100644 --- a/third_party/luau/Analysis/src/DataFlowGraph.cpp +++ b/third_party/luau/Analysis/src/DataFlowGraph.cpp @@ -1,9 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" -#include "Luau/Breadcrumb.h" +#include "Luau/Ast.h" +#include "Luau/Def.h" +#include "Luau/Common.h" #include "Luau/Error.h" -#include "Luau/Refinement.h" +#include "Luau/TimeTrace.h" + +#include LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -11,112 +15,332 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { -NullableBreadcrumbId DataFlowGraph::getBreadcrumb(const AstExpr* expr) const +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp + +const RefinementKey* RefinementKeyArena::leaf(DefId def) { - // We need to skip through AstExprGroup because DFG doesn't try its best to transitively - while (auto group = expr->as()) - expr = group->expr; - if (auto bc = astBreadcrumbs.find(expr)) - return *bc; - return nullptr; + return allocator.allocate(RefinementKey{nullptr, def, std::nullopt}); +} + +const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId def, const std::string& propName) +{ + return allocator.allocate(RefinementKey{parent, def, propName}); } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstLocal* local) const +DefId DataFlowGraph::getDef(const AstExpr* expr) const { - auto bc = localBreadcrumbs.find(local); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = astDefs.find(expr); + LUAU_ASSERT(def); + return NotNull{*def}; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprLocal* local) const +std::optional DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const { - auto bc = astBreadcrumbs.find(local); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = compoundAssignDefs.find(expr); + return def ? std::optional(*def) : std::nullopt; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprGlobal* global) const +DefId DataFlowGraph::getDef(const AstLocal* local) const { - auto bc = astBreadcrumbs.find(global); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = localDefs.find(local); + LUAU_ASSERT(def); + return NotNull{*def}; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareGlobal* global) const +DefId DataFlowGraph::getDef(const AstStatDeclareGlobal* global) const { - auto bc = declaredBreadcrumbs.find(global); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = declaredDefs.find(global); + LUAU_ASSERT(def); + return NotNull{*def}; } -BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareFunction* func) const +DefId DataFlowGraph::getDef(const AstStatDeclareFunction* func) const { - auto bc = declaredBreadcrumbs.find(func); - LUAU_ASSERT(bc); - return NotNull{*bc}; + auto def = declaredDefs.find(func); + LUAU_ASSERT(def); + return NotNull{*def}; +} + +const RefinementKey* DataFlowGraph::getRefinementKey(const AstExpr* expr) const +{ + if (auto key = astRefinementKeys.find(expr)) + return *key; + + return nullptr; } -NullableBreadcrumbId DfgScope::lookup(Symbol symbol) const +std::optional DfgScope::lookup(Symbol symbol) const { for (const DfgScope* current = this; current; current = current->parent) { - if (auto breadcrumb = current->bindings.find(symbol)) - return *breadcrumb; + if (auto def = current->bindings.find(symbol)) + return NotNull{*def}; } - return nullptr; + return std::nullopt; } -NullableBreadcrumbId DfgScope::lookup(DefId def, const std::string& key) const +std::optional DfgScope::lookup(DefId def, const std::string& key) const { for (const DfgScope* current = this; current; current = current->parent) { - if (auto map = props.find(def)) + if (auto props = current->props.find(def)) { - if (auto it = map->find(key); it != map->end()) - return it->second; + if (auto it = props->find(key); it != props->end()) + return NotNull{it->second}; } } - return nullptr; + return std::nullopt; +} + +void DfgScope::inherit(const DfgScope* childScope) +{ + for (const auto& [k, a] : childScope->bindings) + { + if (lookup(k)) + bindings[k] = a; + } + + for (const auto& [k1, a1] : childScope->props) + { + for (const auto& [k2, a2] : a1) + props[k1][k2] = a2; + } +} + +bool DfgScope::canUpdateDefinition(Symbol symbol) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (current->bindings.find(symbol)) + return true; + else if (current->scopeType == DfgScope::Loop) + return false; + } + + return true; +} + +bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (auto props = current->props.find(def)) + return true; + else if (current->scopeType == DfgScope::Loop) + return false; + } + + return true; } DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) { + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); DataFlowGraphBuilder builder; builder.handle = handle; builder.moduleScope = builder.childScope(nullptr); // nullptr is the root DFG scope. builder.visitBlockWithoutChildScope(builder.moduleScope, block); + builder.resolveCaptures(); if (FFlag::DebugLuauFreezeArena) { - builder.defs->allocator.freeze(); - builder.breadcrumbs->allocator.freeze(); + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); } return std::move(builder.graph); } -DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +void DataFlowGraphBuilder::resolveCaptures() +{ + for (const auto& [_, capture] : captures) + { + std::vector operands; + for (size_t i = capture.versionOffset; i < capture.allVersions.size(); ++i) + collectOperands(capture.allVersions[i], &operands); + + for (DefId captureDef : capture.captureDefs) + { + Phi* phi = const_cast(get(captureDef)); + LUAU_ASSERT(phi); + LUAU_ASSERT(phi->operands.empty()); + phi->operands = operands; + } + } +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope, DfgScope::ScopeType scopeType) { - return scopes.emplace_back(new DfgScope{scope}).get(); + return scopes.emplace_back(new DfgScope{scope, scopeType}).get(); +} + +void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) +{ + joinBindings(p, *a, *b); + joinProps(p, *a, *b); +} + +void DataFlowGraphBuilder::joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b) +{ + for (const auto& [sym, def1] : a.bindings) + { + if (auto def2 = b.bindings.find(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + else if (auto def2 = p->lookup(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + } + + for (const auto& [sym, def1] : b.bindings) + { + if (auto def2 = p->lookup(sym)) + p->bindings[sym] = defArena->phi(NotNull{def1}, NotNull{*def2}); + } +} + +void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const DfgScope& b) +{ + auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable + { + auto& p = scope->props[parent]; + for (const auto& [k, defA] : a) + { + if (auto it = b.find(k); it != b.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defA}); + else if (auto it = p.find(k); it != p.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defA}); + else if (auto def2 = scope->lookup(parent, k)) + p[k] = defArena->phi(*def2, NotNull{defA}); + else + p[k] = defA; + } + + for (const auto& [k, defB] : b) + { + if (auto it = a.find(k); it != a.end()) + continue; + else if (auto it = p.find(k); it != p.end()) + p[k] = defArena->phi(NotNull{it->second}, NotNull{defB}); + else if (auto def2 = scope->lookup(parent, k)) + p[k] = defArena->phi(*def2, NotNull{defB}); + else + p[k] = defB; + } + }; + + for (const auto& [def, a1] : a.props) + { + result->props.try_insert(def, {}); + if (auto a2 = b.props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + else if (auto a2 = result->props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + } + + for (const auto& [def, a1] : b.props) + { + result->props.try_insert(def, {}); + if (a.props.find(def)) + continue; + else if (auto a2 = result->props.find(def)) + phinodify(result, a1, *a2, NotNull{def}); + } +} + +DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) +{ + // true if any of the considered scopes are a loop. + bool outsideLoopScope = false; + for (DfgScope* current = scope; current; current = current->parent) + { + outsideLoopScope = outsideLoopScope || current->scopeType == DfgScope::Loop; + + if (auto found = current->bindings.find(symbol)) + return NotNull{*found}; + else if (current->scopeType == DfgScope::Function) + { + FunctionCapture& capture = captures[symbol]; + DefId captureDef = defArena->phi({}); + capture.captureDefs.push_back(captureDef); + + // If we are outside of a loop scope, then we don't want to actually bind + // uses of `symbol` to this new phi node since it will not get populated. + if (!outsideLoopScope) + scope->bindings[symbol] = captureDef; + + return NotNull{captureDef}; + } + } + + DefId result = defArena->freshCell(); + scope->bindings[symbol] = result; + captures[symbol].allVersions.push_back(result); + return result; +} + +DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string& key) +{ + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto props = current->props.find(def)) + { + if (auto it = props->find(key); it != props->end()) + return NotNull{it->second}; + } + else if (auto phi = get(def); phi && phi->operands.empty()) // Unresolved phi nodes + { + DefId result = defArena->freshCell(); + scope->props[def][key] = result; + return result; + } + } + + if (auto phi = get(def)) + { + std::vector defs; + for (DefId operand : phi->operands) + defs.push_back(lookup(scope, operand, key)); + + DefId result = defArena->phi(defs); + scope->props[def][key] = result; + return result; + } + else if (get(def)) + { + DefId result = defArena->freshCell(); + scope->props[def][key] = result; + return result; + } + else + handle->ice("Inexhaustive lookup cases in DataFlowGraphBuilder::lookup"); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) { DfgScope* child = childScope(scope); - return visitBlockWithoutChildScope(child, b); + ControlFlow cf = visitBlockWithoutChildScope(child, b); + scope->inherit(child); + return cf; } -void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) { - for (AstStat* s : b->body) - visit(scope, s); + std::optional firstControlFlow; + for (AstStat* stat : b->body) + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) { if (auto b = s->as()) return visit(scope, b); @@ -150,6 +374,8 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) return visit(scope, l); else if (auto t = s->as()) return visit(scope, t); + else if (auto f = s->as()) + return visit(scope, f); else if (auto d = s->as()) return visit(scope, d); else if (auto d = s->as()) @@ -162,62 +388,91 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) { - // TODO: type states and control flow analysis visitExpr(scope, i->condition); - visit(scope, i->thenbody); + + DfgScope* thenScope = childScope(scope); + DfgScope* elseScope = childScope(scope); + + ControlFlow thencf = visit(thenScope, i->thenbody); + ControlFlow elsecf = ControlFlow::None; if (i->elsebody) - visit(scope, i->elsebody); + elsecf = visit(elseScope, i->elsebody); + + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + join(scope, scope, elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + join(scope, thenScope, scope); + else if ((thencf | elsecf) == ControlFlow::None) + join(scope, thenScope, elseScope); + + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* whileScope = childScope(scope); + DfgScope* whileScope = childScope(scope, DfgScope::Loop); visitExpr(whileScope, w->condition); visit(whileScope, w->body); + + scope->inherit(whileScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* repeatScope = childScope(scope); // TODO: loop scope. + DfgScope* repeatScope = childScope(scope, DfgScope::Loop); visitBlockWithoutChildScope(repeatScope, r->body); visitExpr(repeatScope, r->condition); + + scope->inherit(repeatScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) { - // TODO: Control flow analysis - return; // ok + return ControlFlow::Breaks; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) { - // TODO: Control flow analysis - return; // ok + return ControlFlow::Continues; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) { - // TODO: Control flow analysis for (AstExpr* e : r->list) visitExpr(scope, e); + + return ControlFlow::Returns; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) { visitExpr(scope, e->expr); + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + else + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) { // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) - std::vector bcs; - bcs.reserve(l->values.size); + std::vector defs; + defs.reserve(l->values.size); for (AstExpr* e : l->values) - bcs.push_back(visitExpr(scope, e)); + defs.push_back(visitExpr(scope, e).def); for (size_t i = 0; i < l->vars.size; ++i) { @@ -225,16 +480,29 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) if (local->annotation) visitType(scope, local->annotation); - // We need to create a new breadcrumb with new defs to intentionally avoid alias tracking. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell(), i < bcs.size() ? bcs[i]->metadata : std::nullopt); - graph.localBreadcrumbs[local] = bc; - scope->bindings[local] = bc; + // We need to create a new def to intentionally avoid alias tracking, but we'd like to + // make sure that the non-aliased defs are also marked as a subscript for refinements. + bool subscripted = i < defs.size() && containsSubscriptedDefinition(defs[i]); + DefId def = defArena->freshCell(subscripted); + if (i < l->values.size) + { + AstExpr* e = l->values.data[i]; + if (const AstExprTable* tbl = e->as()) + { + def = defs[i]; + } + } + graph.localDefs[local] = def; + scope->bindings[local] = def; + captures[local].allVersions.push_back(def); } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) { - DfgScope* forScope = childScope(scope); // TODO: loop scope. + DfgScope* forScope = childScope(scope, DfgScope::Loop); visitExpr(scope, f->from); visitExpr(scope, f->to); @@ -244,28 +512,32 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) if (f->var->annotation) visitType(forScope, f->var->annotation); - // TODO: RangeMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[f->var] = bc; - scope->bindings[f->var] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[f->var] = def; + scope->bindings[f->var] = def; + captures[f->var].allVersions.push_back(def); // TODO(controlflow): entry point has a back edge from exit point visit(forScope, f->body); + + scope->inherit(forScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) { - DfgScope* forScope = childScope(scope); // TODO: loop scope. + DfgScope* forScope = childScope(scope, DfgScope::Loop); for (AstLocal* local : f->vars) { if (local->annotation) visitType(forScope, local->annotation); - // TODO: IterMetadata (different from RangeMetadata) - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[local] = bc; - forScope->bindings[local] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[local] = def; + forScope->bindings[local] = def; + captures[local].allVersions.push_back(def); } // TODO(controlflow): entry point has a back edge from exit point @@ -274,32 +546,44 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) visitExpr(forScope, e); visit(forScope, f->body); + + scope->inherit(forScope); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) { - for (AstExpr* r : a->values) - visitExpr(scope, r); + std::vector defs; + defs.reserve(a->values.size); + for (AstExpr* e : a->values) + defs.push_back(visitExpr(scope, e).def); - for (AstExpr* l : a->vars) - visitLValue(scope, l); + for (size_t i = 0; i < a->vars.size; ++i) + { + AstExpr* v = a->vars.data[i]; + visitLValue(scope, v, i < defs.size() ? defs[i] : defArena->freshCell()); + } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) { // TODO: This needs revisiting because this is incorrect. The `c->var` part is both being read and written to, // but the `c->var` only has one pointer address, so we need to come up with a way to store both. // For now, it's not important because we don't have type states, but it is going to be important, e.g. // - // local a = 5 -- a[1] - // a += 5 -- a[2] = a[1] + 5 - // + // local a = 5 -- a-1 + // a += 5 -- a-2 = a-1 + 5 // We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2). - visitLValue(scope, c->var); - visitExpr(scope, c->value); + DefId def = visitExpr(scope, c->value).def; + visitLValue(scope, c->var, def, /* isCompoundAssignment */ true); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) { // In the old solver, we assumed that the name of the function is always a function in the body // but this isn't true, e.g. the following example will print `5`, not a function address. @@ -311,52 +595,82 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - visitLValue(scope, f->name); + visitLValue(scope, f->name, defArena->freshCell()); visitExpr(scope, f->func); + + if (auto local = f->name->as()) + { + // local f + // function f() + // if cond() then + // f() -- should reference only the function version and other future version, and nothing prior + // end + // end + FunctionCapture& capture = captures[local->local]; + capture.versionOffset = capture.allVersions.size() - 1; + } + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) { - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[l->name] = bc; - scope->bindings[l->name] = bc; - + DefId def = defArena->freshCell(); + graph.localDefs[l->name] = def; + scope->bindings[l->name] = def; + captures[l->name].allVersions.push_back(def); visitExpr(scope, l->func); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) { DfgScope* unreachable = childScope(scope); visitGenerics(unreachable, t->generics); visitGenericPacks(unreachable, t->genericPacks); visitType(unreachable, t->type); + + return ControlFlow::None; +} + +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeFunction* f) +{ + DfgScope* unreachable = childScope(scope); + visitExpr(unreachable, f->body); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) { - // TODO: AmbientDeclarationMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.declaredBreadcrumbs[d] = bc; - scope->bindings[d->name] = bc; + DefId def = defArena->freshCell(); + graph.declaredDefs[d] = def; + scope->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); visitType(scope, d->type); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) { - // TODO: AmbientDeclarationMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.declaredBreadcrumbs[d] = bc; - scope->bindings[d->name] = bc; + DefId def = defArena->freshCell(); + graph.declaredDefs[d] = def; + scope->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); DfgScope* unreachable = childScope(scope); visitGenerics(unreachable, d->generics); visitGenericPacks(unreachable, d->genericPacks); visitTypeList(unreachable, d->params); visitTypeList(unreachable, d->retTypes); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) { // This declaration does not "introduce" any bindings in value namespace, // so there's no symbolic value to begin with. We'll traverse the properties @@ -364,139 +678,149 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) DfgScope* unreachable = childScope(scope); for (AstDeclaredClassProp prop : d->props) visitType(unreachable, prop.ty); + + return ControlFlow::None; } -void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) { DfgScope* unreachable = childScope(scope); for (AstStat* s : error->statements) visit(unreachable, s); for (AstExpr* e : error->expressions) visitExpr(unreachable, e); -} -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) -{ - if (auto g = e->as()) - return visitExpr(scope, g->expr); - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto l = e->as()) - return visitExpr(scope, l); - else if (auto g = e->as()) - return visitExpr(scope, g); - else if (auto v = e->as()) - return breadcrumbs->add(nullptr, defs->freshCell()); // ok - else if (auto c = e->as()) - return visitExpr(scope, c); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto f = e->as()) - return visitExpr(scope, f); - else if (auto t = e->as()) - return visitExpr(scope, t); - else if (auto u = e->as()) - return visitExpr(scope, u); - else if (auto b = e->as()) - return visitExpr(scope, b); - else if (auto t = e->as()) - return visitExpr(scope, t); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto i = e->as()) - return visitExpr(scope, i); - else if (auto error = e->as()) - return visitExpr(scope, error); - else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); + return ControlFlow::None; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) { - NullableBreadcrumbId breadcrumb = scope->lookup(l->local); - if (!breadcrumb) - handle->ice("AstExprLocal came before its declaration?"); + // Some subexpressions could be visited two times. If we've already seen it, just extract it. + if (auto def = graph.astDefs.find(e)) + { + auto key = graph.astRefinementKeys.find(e); + return {NotNull{*def}, key ? *key : nullptr}; + } + + auto go = [&]() -> DataFlowResult + { + if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto l = e->as()) + return visitExpr(scope, l); + else if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto v = e->as()) + return {defArena->freshCell(), nullptr}; // ok + else if (auto c = e->as()) + return visitExpr(scope, c); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto f = e->as()) + return visitExpr(scope, f); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto u = e->as()) + return visitExpr(scope, u); + else if (auto b = e->as()) + return visitExpr(scope, b); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto error = e->as()) + return visitExpr(scope, error); + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); + }; - graph.astBreadcrumbs[l] = breadcrumb; - return NotNull{breadcrumb}; + auto [def, key] = go(); + graph.astDefs[e] = def; + if (key) + graph.astRefinementKeys[e] = key; + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* group) { - NullableBreadcrumbId bc = scope->lookup(g->name); - if (!bc) - { - bc = breadcrumbs->add(nullptr, defs->freshCell()); - moduleScope->bindings[g->name] = bc; - } + return visitExpr(scope, group->expr); +} - graph.astBreadcrumbs[g] = bc; - return NotNull{bc}; +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +{ + DefId def = lookup(scope, l->local); + const RefinementKey* key = keyArena->leaf(def); + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +{ + DefId def = lookup(scope, g->name); + return {def, keyArena->leaf(def)}; +} + +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) { visitExpr(scope, c->func); for (AstExpr* arg : c->args) visitExpr(scope, arg); - return breadcrumbs->add(nullptr, defs->freshCell()); + // calls should be treated as subscripted. + return {defArena->freshCell(/* subscripted */ true), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + auto [parentDef, parentKey] = visitExpr(scope, i->expr); - std::string key = i->index.value; - NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; - if (!propBreadcrumb) - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + std::string index = i->index.value; - graph.astBreadcrumbs[i] = propBreadcrumb; - return NotNull{propBreadcrumb}; + DefId def = lookup(scope, parentDef, index); + return {def, keyArena->node(parentKey, def, index)}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); - BreadcrumbId key = visitExpr(scope, i->index); + auto [parentDef, parentKey] = visitExpr(scope, i->expr); + visitExpr(scope, i->index); if (auto string = i->index->as()) { - std::string key{string->value.data, string->value.size}; - NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; - if (!propBreadcrumb) - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + std::string index{string->value.data, string->value.size}; - graph.astBreadcrumbs[i] = NotNull{propBreadcrumb}; - return NotNull{propBreadcrumb}; + DefId def = lookup(scope, parentDef, index); + return {def, keyArena->node(parentKey, def, index)}; } - return breadcrumbs->emplace(nullptr, defs->freshCell(), key); + return {defArena->freshCell(/* subscripted= */ true), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) { - DfgScope* signatureScope = childScope(scope); + DfgScope* signatureScope = childScope(scope, DfgScope::Function); if (AstLocal* self = f->self) { // There's no syntax for `self` to have an annotation if using `function t:m()` LUAU_ASSERT(!self->annotation); - // TODO: ParameterMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[self] = bc; - signatureScope->bindings[self] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[self] = def; + signatureScope->bindings[self] = def; + captures[self].allVersions.push_back(def); } for (AstLocal* param : f->args) @@ -504,10 +828,10 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f if (param->annotation) visitType(signatureScope, param->annotation); - // TODO: ParameterMetadata. - BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); - graph.localBreadcrumbs[param] = bc; - signatureScope->bindings[param] = bc; + DefId def = defArena->freshCell(); + graph.localDefs[param] = def; + signatureScope->bindings[param] = def; + captures[param].allVersions.push_back(def); } if (f->varargAnnotation) @@ -526,144 +850,176 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f // g() --> 5 visit(signatureScope, f->body); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) { + DefId tableCell = defArena->freshCell(); + scope->props[tableCell] = {}; for (AstExprTable::Item item : t->items) { + DataFlowResult result = visitExpr(scope, item.value); if (item.key) + { visitExpr(scope, item.key); - visitExpr(scope, item.value); + if (auto string = item.key->as()) + scope->props[tableCell][string->value.data] = result.def; + } } - return breadcrumbs->add(nullptr, defs->freshCell()); + return {tableCell, nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) { visitExpr(scope, u->expr); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) { visitExpr(scope, b->left); visitExpr(scope, b->right); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) { - // TODO: TypeAssertionMetadata? - BreadcrumbId bc = visitExpr(scope, t->expr); + auto [def, key] = visitExpr(scope, t->expr); visitType(scope, t->annotation); - return bc; + return {def, key}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) { visitExpr(scope, i->condition); visitExpr(scope, i->trueExpr); visitExpr(scope, i->falseExpr); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) { for (AstExpr* e : i->expressions) visitExpr(scope, e); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) { DfgScope* unreachable = childScope(scope); for (AstExpr* e : error->expressions) visitExpr(unreachable, e); - return breadcrumbs->add(nullptr, defs->freshCell()); + return {defArena->freshCell(), nullptr}; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e) +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment) { - if (auto l = e->as()) - return visitLValue(scope, l); - else if (auto g = e->as()) - return visitLValue(scope, g); - else if (auto i = e->as()) - return visitLValue(scope, i); - else if (auto i = e->as()) - return visitLValue(scope, i); - else if (auto error = e->as()) + auto go = [&]() { - visitExpr(scope, error); // TODO: is this right? - return; - } - else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + if (auto l = e->as()) + return visitLValue(scope, l, incomingDef, isCompoundAssignment); + else if (auto g = e->as()) + return visitLValue(scope, g, incomingDef, isCompoundAssignment); + else if (auto i = e->as()) + return visitLValue(scope, i, incomingDef); + else if (auto i = e->as()) + return visitLValue(scope, i, incomingDef); + else if (auto error = e->as()) + return visitLValue(scope, error, incomingDef); + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + }; + + graph.astDefs[e] = go(); } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment) { - // Bug compatibility: we don't support type states yet, so we need to do this. - NullableBreadcrumbId bc = scope->lookup(l->local); - LUAU_ASSERT(bc); + // We need to keep the previous def around for a compound assignment. + if (isCompoundAssignment) + { + DefId def = lookup(scope, l->local); + graph.compoundAssignDefs[l] = def; + } - graph.astBreadcrumbs[l] = bc; - scope->bindings[l->local] = bc; + // In order to avoid alias tracking, we need to clip the reference to the parent def. + if (scope->canUpdateDefinition(l->local)) + { + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->bindings[l->local] = updated; + captures[l->local].allVersions.push_back(updated); + return updated; + } + else + return visitExpr(scope, static_cast(l)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment) { - // Bug compatibility: we don't support type states yet, so we need to do this. - NullableBreadcrumbId bc = scope->lookup(g->name); - if (!bc) - bc = breadcrumbs->add(nullptr, defs->freshCell()); + // We need to keep the previous def around for a compound assignment. + if (isCompoundAssignment) + { + DefId def = lookup(scope, g->name); + graph.compoundAssignDefs[g] = def; + } - graph.astBreadcrumbs[g] = bc; - scope->bindings[g->name] = bc; + // In order to avoid alias tracking, we need to clip the reference to the parent def. + if (scope->canUpdateDefinition(g->name)) + { + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->bindings[g->name] = updated; + captures[g->name].allVersions.push_back(updated); + return updated; + } + else + return visitExpr(scope, static_cast(g)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef) { - // Bug compatibility: we don't support type states yet, so we need to do this. - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + DefId parentDef = visitExpr(scope, i->expr).def; - std::string key = i->index.value; - NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); - if (!propBreadcrumb) + if (scope->canUpdateDefinition(parentDef, i->index.value)) { - propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); - moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->props[parentDef][i->index.value] = updated; + return updated; } - - graph.astBreadcrumbs[i] = propBreadcrumb; + else + return visitExpr(scope, static_cast(i)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef) { - BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + DefId parentDef = visitExpr(scope, i->expr).def; visitExpr(scope, i->index); if (auto string = i->index->as()) { - std::string key{string->value.data, string->value.size}; - NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); - if (!propBreadcrumb) + if (scope->canUpdateDefinition(parentDef, string->value.data)) { - propBreadcrumb = breadcrumbs->add(parentBreadcrumb, parentBreadcrumb->def); - moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + scope->props[parentDef][string->value.data] = updated; + return updated; } - - graph.astBreadcrumbs[i] = propBreadcrumb; + else + return visitExpr(scope, static_cast(i)).def; } + else + return defArena->freshCell(/*subscripted=*/true); +} + +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef) +{ + return visitExpr(scope, error).def; } void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) diff --git a/third_party/luau/Analysis/src/DcrLogger.cpp b/third_party/luau/Analysis/src/DcrLogger.cpp index 9f66b022..f013b985 100644 --- a/third_party/luau/Analysis/src/DcrLogger.cpp +++ b/third_party/luau/Analysis/src/DcrLogger.cpp @@ -124,7 +124,8 @@ void write(JsonEmitter& emitter, const ConstraintBlock& block) ObjectEmitter o = emitter.writeObject(); o.writePair("stringification", block.stringification); - auto go = [&o](auto&& t) { + auto go = [&o](auto&& t) + { using T = std::decay_t; o.writePair("id", toPointerId(t)); @@ -350,8 +351,12 @@ void DcrLogger::popBlock(NotNull block) } } -static void snapshotTypeStrings(const std::vector& interestedExprs, - const std::vector& interestedAnnots, DenseHashMap& map, ToStringOptions& opts) +static void snapshotTypeStrings( + const std::vector& interestedExprs, + const std::vector& interestedAnnots, + DenseHashMap& map, + ToStringOptions& opts +) { for (const ExprTypesAtLocation& tys : interestedExprs) { @@ -368,7 +373,10 @@ static void snapshotTypeStrings(const std::vector& interest } void DcrLogger::captureBoundaryState( - BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints) + BoundarySnapshot& target, + const Scope* rootScope, + const std::vector>& unsolvedConstraints +) { target.rootScope = snapshotScope(rootScope, opts); target.unsolvedConstraints.clear(); @@ -391,7 +399,11 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec } StepSnapshot DcrLogger::prepareStepSnapshot( - const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) + const Scope* rootScope, + NotNull current, + bool force, + const std::vector>& unsolvedConstraints +) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); DenseHashMap constraints{nullptr}; diff --git a/third_party/luau/Analysis/src/Def.cpp b/third_party/luau/Analysis/src/Def.cpp index 7be075c2..6d58b28f 100644 --- a/third_party/luau/Analysis/src/Def.cpp +++ b/third_party/luau/Analysis/src/Def.cpp @@ -1,12 +1,62 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Def.h" +#include "Luau/Common.h" + +#include + namespace Luau { -DefId DefArena::freshCell() +bool containsSubscriptedDefinition(DefId def) +{ + if (auto cell = get(def)) + return cell->subscripted; + else if (auto phi = get(def)) + return std::any_of(phi->operands.begin(), phi->operands.end(), containsSubscriptedDefinition); + else + return false; +} + +void collectOperands(DefId def, std::vector* operands) +{ + LUAU_ASSERT(operands); + if (std::find(operands->begin(), operands->end(), def) != operands->end()) + return; + else if (get(def)) + operands->push_back(def); + else if (auto phi = get(def)) + { + // A trivial phi node has no operands to populate, so we push this definition in directly. + if (phi->operands.empty()) + return operands->push_back(def); + + for (const Def* operand : phi->operands) + collectOperands(NotNull{operand}, operands); + } +} + +DefId DefArena::freshCell(bool subscripted) +{ + return NotNull{allocator.allocate(Def{Cell{subscripted}})}; +} + +DefId DefArena::phi(DefId a, DefId b) +{ + return phi({a, b}); +} + +DefId DefArena::phi(const std::vector& defs) { - return NotNull{allocator.allocate(Def{Cell{}})}; + std::vector operands; + for (DefId operand : defs) + collectOperands(operand, &operands); + + // There's no need to allocate a Phi node for a singleton set. + if (operands.size() == 1) + return operands[0]; + else + return NotNull{allocator.allocate(Def{Phi{std::move(operands)}})}; } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Differ.cpp b/third_party/luau/Analysis/src/Differ.cpp new file mode 100644 index 00000000..25687e11 --- /dev/null +++ b/third_party/luau/Analysis/src/Differ.cpp @@ -0,0 +1,967 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Differ.h" +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" +#include +#include +#include +#include + +namespace Luau +{ +std::string DiffPathNode::toString() const +{ + switch (kind) + { + case DiffPathNode::Kind::TableProperty: + { + if (!tableProperty.has_value()) + throw InternalCompilerError{"DiffPathNode has kind TableProperty but tableProperty is nullopt"}; + return *tableProperty; + break; + } + case DiffPathNode::Kind::FunctionArgument: + { + if (!index.has_value()) + return "Arg[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Arg[" + std::to_string(*index + 1) + "]"; + } + case DiffPathNode::Kind::FunctionReturn: + { + if (!index.has_value()) + return "Ret[Variadic]"; + // Add 1 because Lua is 1-indexed + return "Ret[" + std::to_string(*index + 1) + "]"; + } + case DiffPathNode::Kind::Negation: + { + return "Negation"; + } + default: + { + throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"}; + } + } +} + +DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty) +{ + return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt}; +} + +DiffPathNode DiffPathNode::constructWithKindAndIndex(Kind kind, size_t index) +{ + return DiffPathNode{kind, std::nullopt, index}; +} + +DiffPathNode DiffPathNode::constructWithKind(Kind kind) +{ + return DiffPathNode{kind, std::nullopt, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsNormal(TypeId ty) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsTableProperty(TypeId ty, Name tableProperty) +{ + return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsUnionIndex(TypeId ty, size_t index) +{ + return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, index}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::detailsLength(int minLength, bool isVariadic) +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic, std::nullopt}; +} + +DiffPathNodeLeaf DiffPathNodeLeaf::nullopts() +{ + return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false, std::nullopt}; +} + +std::string DiffPath::toString(bool prependDot) const +{ + std::string pathStr; + bool isFirstInForLoop = !prependDot; + for (auto node = path.rbegin(); node != path.rend(); node++) + { + if (isFirstInForLoop) + { + isFirstInForLoop = false; + } + else + { + pathStr += "."; + } + pathStr += node->toString(); + } + return pathStr; +} +std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const +{ + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; + std::string pathStr{rootName + diffPath.toString(true)}; + switch (kind) + { + case DiffError::Kind::Normal: + { + checkNonMissingPropertyLeavesHaveNulloptTableProperty(); + return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); + } + case DiffError::Kind::MissingTableProperty: + { + if (leaf.ty.has_value()) + { + if (!leaf.tableProperty.has_value()) + throw InternalCompilerError{"leaf.tableProperty is nullopt"}; + return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + if (!otherLeaf.tableProperty.has_value()) + throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"}; + return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty; + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::MissingUnionMember: + { + // TODO: do normal case + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::MissingIntersectionMember: + { + // TODO: better message for intersections + // An intersection of just functions is always an "overloaded function" + // An intersection of just tables is always a "joined table" + if (leaf.ty.has_value()) + { + if (!leaf.unionIndex.has_value()) + throw InternalCompilerError{"leaf.unionIndex is nullopt"}; + return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); + } + else if (otherLeaf.ty.has_value()) + { + return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent + + Luau::toString(*otherLeaf.ty); + } + throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; + } + case DiffError::Kind::LengthMismatchInFnArgs: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + if (!leaf.minLength.has_value()) + throw InternalCompilerError{"leaf.minLength is nullopt"}; + return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; + } + default: + { + throw InternalCompilerError{"DiffPath::toStringALeaf is not exhaustive"}; + } + } +} + +void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const +{ + if (left.tableProperty.has_value() || right.tableProperty.has_value()) + throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"}; +} + +std::string getDevFixFriendlyName(const std::optional& maybeSymbol, TypeId ty) +{ + if (maybeSymbol.has_value()) + return *maybeSymbol; + + if (auto table = get(ty)) + { + if (table->name.has_value()) + return *table->name; + else if (table->syntheticName.has_value()) + return *table->syntheticName; + } + if (auto metatable = get(ty)) + { + if (metatable->syntheticName.has_value()) + { + return *metatable->syntheticName; + } + } + return ""; +} + +std::string DifferEnvironment::getDevFixFriendlyNameLeft() const +{ + return getDevFixFriendlyName(externalSymbolLeft, rootLeft); +} + +std::string DifferEnvironment::getDevFixFriendlyNameRight() const +{ + return getDevFixFriendlyName(externalSymbolRight, rootRight); +} + +std::string DiffError::toString(bool multiLine) const +{ + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; + switch (kind) + { + case DiffError::Kind::IncompatibleGeneric: + { + std::string diffPathStr{diffPath.toString(true)}; + return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName + + diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline + + conditionalIndent + rightRootName + diffPathStr; + } + default: + { + return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent + + toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline + + conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine); + } + } +} + +void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right) +{ + if (!left.ty.has_value() || !right.ty.has_value()) + { + // TODO: think about whether this should be always thrown! + // For example, Kind::Primitive doesn't make too much sense to have a TypeId + // throw InternalCompilerError{"Left and Right fields are leaf nodes and must have a TypeId"}; + } +} + +void DifferResult::wrapDiffPath(DiffPathNode node) +{ + if (!diffError.has_value()) + { + throw InternalCompilerError{"Cannot wrap diffPath because there is no diffError"}; + } + + diffError->diffPath.path.push_back(node); +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right); +struct FindSeteqCounterexampleResult +{ + // nullopt if no counterexample found + std::optional mismatchIdx; + // true if counterexample is in the left, false if cex is in the right + bool inLeft; +}; +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, + const std::vector& left, + const std::vector& right +); +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right); +/** + * The last argument gives context info on which complex type contained the TypePack. + */ +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffCanonicalTpShape( + DifferEnvironment& env, + DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, + const std::pair, std::optional>& right +); +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right); + +static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const TableType* leftTable = get(left); + const TableType* rightTable = get(right); + LUAU_ASSERT(leftTable); + LUAU_ASSERT(rightTable); + + for (auto const& [field, value] : leftTable->props) + { + if (rightTable->props.find(field) == rightTable->props.end()) + { + // left has a field the right doesn't + return DifferResult{DiffError{ + DiffError::Kind::MissingTableProperty, + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + } + for (auto const& [field, value] : rightTable->props) + { + if (leftTable->props.find(field) == leftTable->props.end()) + { + // right has a field the left doesn't + return DifferResult{DiffError{ + DiffError::Kind::MissingTableProperty, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight() + }}; + } + } + // left and right have the same set of keys + for (auto const& [field, leftValue] : leftTable->props) + { + auto const& rightValue = rightTable->props.at(field); + DifferResult differResult = diffUsingEnv(env, leftValue.type(), rightValue.type()); + if (differResult.diffError.has_value()) + { + differResult.wrapDiffPath(DiffPathNode::constructWithTableProperty(field)); + return differResult; + } + } + return DifferResult{}; +} + +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const MetatableType* leftMetatable = get(left); + const MetatableType* rightMetatable = get(right); + LUAU_ASSERT(leftMetatable); + LUAU_ASSERT(rightMetatable); + + DifferResult diffRes = diffUsingEnv(env, leftMetatable->table, rightMetatable->table); + if (diffRes.diffError.has_value()) + { + return diffRes; + } + + diffRes = diffUsingEnv(env, leftMetatable->metatable, rightMetatable->metatable); + if (diffRes.diffError.has_value()) + { + diffRes.wrapDiffPath(DiffPathNode::constructWithTableProperty("__metatable")); + return diffRes; + } + return DifferResult{}; +} + +static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right) +{ + const PrimitiveType* leftPrimitive = get(left); + const PrimitiveType* rightPrimitive = get(right); + LUAU_ASSERT(leftPrimitive); + LUAU_ASSERT(rightPrimitive); + + if (leftPrimitive->type != rightPrimitive->type) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + return DifferResult{}; +} + +static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right) +{ + const SingletonType* leftSingleton = get(left); + const SingletonType* rightSingleton = get(right); + LUAU_ASSERT(leftSingleton); + LUAU_ASSERT(rightSingleton); + + if (*leftSingleton != *rightSingleton) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + return DifferResult{}; +} + +static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right) +{ + const FunctionType* leftFunction = get(left); + const FunctionType* rightFunction = get(right); + LUAU_ASSERT(leftFunction); + LUAU_ASSERT(rightFunction); + + DifferResult differResult = diffTpi(env, DiffError::Kind::LengthMismatchInFnArgs, leftFunction->argTypes, rightFunction->argTypes); + if (differResult.diffError.has_value()) + return differResult; + return diffTpi(env, DiffError::Kind::LengthMismatchInFnRets, leftFunction->retTypes, rightFunction->retTypes); +} + +static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericMatchedPairs.contains(left); + bool isRightFree = !env.genericMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericMatchedPairs[left] = right; + env.genericMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both generics are already paired up + if (*env.genericMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + const NegationType* rightNegation = get(right); + LUAU_ASSERT(leftNegation); + LUAU_ASSERT(rightNegation); + + DifferResult differResult = diffUsingEnv(env, leftNegation->ty, rightNegation->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::Negation)); + return differResult; +} + +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right) +{ + const ClassType* leftClass = get(left); + const ClassType* rightClass = get(right); + LUAU_ASSERT(leftClass); + LUAU_ASSERT(rightClass); + + if (leftClass == rightClass) + { + return DifferResult{}; + } + + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static FindSeteqCounterexampleResult findSeteqCounterexample( + DifferEnvironment& env, + const std::vector& left, + const std::vector& right +) +{ + std::unordered_set unmatchedRightIdxes; + for (size_t i = 0; i < right.size(); i++) + unmatchedRightIdxes.insert(i); + for (size_t leftIdx = 0; leftIdx < left.size(); leftIdx++) + { + bool leftIdxIsMatched = false; + auto unmatchedRightIdxIt = unmatchedRightIdxes.begin(); + while (unmatchedRightIdxIt != unmatchedRightIdxes.end()) + { + DifferResult differResult = diffUsingEnv(env, left[leftIdx], right[*unmatchedRightIdxIt]); + if (differResult.diffError.has_value()) + { + unmatchedRightIdxIt++; + continue; + } + // unmatchedRightIdxIt is matched with current leftIdx + env.recordProvenEqual(left[leftIdx], right[*unmatchedRightIdxIt]); + leftIdxIsMatched = true; + unmatchedRightIdxIt = unmatchedRightIdxes.erase(unmatchedRightIdxIt); + } + if (!leftIdxIsMatched) + { + return FindSeteqCounterexampleResult{leftIdx, true}; + } + } + if (unmatchedRightIdxes.empty()) + return FindSeteqCounterexampleResult{std::nullopt, false}; + return FindSeteqCounterexampleResult{*unmatchedRightIdxes.begin(), false}; +} + +static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + const UnionType* rightUnion = get(right); + LUAU_ASSERT(leftUnion); + LUAU_ASSERT(rightUnion); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftUnion->options, rightUnion->options); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingUnionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + +static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + const IntersectionType* rightIntersection = get(right); + LUAU_ASSERT(leftIntersection); + LUAU_ASSERT(rightIntersection); + + FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftIntersection->parts, rightIntersection->parts); + if (findSeteqCexResult.mismatchIdx.has_value()) + { + if (findSeteqCexResult.inLeft) + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + else + return DifferResult{DiffError{ + DiffError::Kind::MissingIntersectionMember, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // TODO: somehow detect mismatch index, likely using heuristics + + return DifferResult{}; +} + +static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both left and right are the same variant + + // Check cycles & caches + if (env.isAssumedEqual(left, right) || env.isProvenEqual(left, right)) + return DifferResult{}; + + if (isSimple(left)) + { + if (auto lp = get(left)) + return diffPrimitive(env, left, right); + else if (auto ls = get(left)) + { + return diffSingleton(env, left, right); + } + else if (auto la = get(left)) + { + // Both left and right must be Any if either is Any for them to be equal! + return DifferResult{}; + } + else if (auto lu = get(left)) + { + return DifferResult{}; + } + else if (auto ln = get(left)) + { + return DifferResult{}; + } + else if (auto ln = get(left)) + { + return diffNegation(env, left, right); + } + else if (auto lc = get(left)) + { + return diffClass(env, left, right); + } + + throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; + } + + // Both left and right are the same non-Simple + // Non-simple types must record visits in the DifferEnvironment + env.pushVisiting(left, right); + + if (auto lt = get(left)) + { + DifferResult diffRes = diffTable(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lm = get(left)) + { + env.popVisiting(); + return diffMetatable(env, left, right); + } + if (auto lf = get(left)) + { + DifferResult diffRes = diffFunction(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lg = get(left)) + { + DifferResult diffRes = diffGeneric(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lu = get(left)) + { + DifferResult diffRes = diffUnion(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto li = get(left)) + { + DifferResult diffRes = diffIntersection(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto le = get(left)) + { + // TODO: return debug-friendly result state + env.popVisiting(); + return DifferResult{}; + } + + throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"}; +} + +static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + // Canonicalize + std::pair, std::optional> leftFlatTpi = flatten(left); + std::pair, std::optional> rightFlatTpi = flatten(right); + + // Check for shape equality + DifferResult diffResult = diffCanonicalTpShape(env, possibleNonNormalErrorKind, leftFlatTpi, rightFlatTpi); + if (diffResult.diffError.has_value()) + { + return diffResult; + } + + // Left and Right have the same shape + for (size_t i = 0; i < leftFlatTpi.first.size(); i++) + { + DifferResult differResult = diffUsingEnv(env, leftFlatTpi.first[i], rightFlatTpi.first[i]); + if (!differResult.diffError.has_value()) + continue; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionArgument, i)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionReturn, i)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled Tpi diffing case with same shape"}; + } + } + } + if (!leftFlatTpi.second.has_value()) + return DifferResult{}; + + return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second); +} + +static DifferResult diffCanonicalTpShape( + DifferEnvironment& env, + DiffError::Kind possibleNonNormalErrorKind, + const std::pair, std::optional>& left, + const std::pair, std::optional>& right +) +{ + if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value()) + return DifferResult{}; + + return DifferResult{DiffError{ + possibleNonNormalErrorKind, + DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()), + DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right) +{ + left = follow(left); + right = follow(right); + + if (left->ty.index() != right->ty.index()) + { + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first), + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both left and right are the same variant + + if (auto lv = get(left)) + { + auto rv = get(right); + DifferResult differResult = diffUsingEnv(env, lv->ty, rv->ty); + if (!differResult.diffError.has_value()) + return DifferResult{}; + + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return differResult; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return differResult; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for VariadicTypePack"}; + } + } + } + if (auto lg = get(left)) + { + DifferResult diffRes = diffGenericTp(env, left, right); + if (!diffRes.diffError.has_value()) + return DifferResult{}; + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return diffRes; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return diffRes; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for GenericTypePack"}; + } + } + } + + throw InternalCompilerError{"Unhandled tail type pack variant for flattened tails"}; +} + +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericTpMatchedPairs.contains(left); + bool isRightFree = !env.genericTpMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericTpMatchedPairs[left] = right; + env.genericTpMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; + } + + // Both generics are already paired up + if (*env.genericTpMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), + }}; +} + +bool DifferEnvironment::isProvenEqual(TypeId left, TypeId right) const +{ + return provenEqual.find({left, right}) != provenEqual.end(); +} + +bool DifferEnvironment::isAssumedEqual(TypeId left, TypeId right) const +{ + return visiting.find({left, right}) != visiting.end(); +} + +void DifferEnvironment::recordProvenEqual(TypeId left, TypeId right) +{ + provenEqual.insert({left, right}); + provenEqual.insert({right, left}); +} + +void DifferEnvironment::pushVisiting(TypeId left, TypeId right) +{ + LUAU_ASSERT(visiting.find({left, right}) == visiting.end()); + LUAU_ASSERT(visiting.find({right, left}) == visiting.end()); + visitingStack.push_back({left, right}); + visiting.insert({left, right}); + visiting.insert({right, left}); +} + +void DifferEnvironment::popVisiting() +{ + auto tyPair = visitingStack.back(); + visiting.erase({tyPair.first, tyPair.second}); + visiting.erase({tyPair.second, tyPair.first}); + visitingStack.pop_back(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingBegin() const +{ + return visitingStack.crbegin(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingEnd() const +{ + return visitingStack.crend(); +} + +DifferResult diff(TypeId ty1, TypeId ty2) +{ + DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; + return diffUsingEnv(differEnv, ty1, ty2); +} + +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2) +{ + DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; + return diffUsingEnv(differEnv, ty1, ty2); +} + +bool isSimple(TypeId ty) +{ + ty = follow(ty); + // TODO: think about GenericType, etc. + return get(ty) || get(ty) || get(ty) || get(ty) || get(ty) || + get(ty) || get(ty); +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/third_party/luau/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 364244ad..e539661a 100644 --- a/third_party/luau/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/third_party/luau/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -4,68 +4,69 @@ namespace Luau { -static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( +static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( declare bit32: { - band: (...number) -> number, - bor: (...number) -> number, - bxor: (...number) -> number, - btest: (number, ...number) -> boolean, - rrotate: (x: number, disp: number) -> number, - lrotate: (x: number, disp: number) -> number, - lshift: (x: number, disp: number) -> number, - arshift: (x: number, disp: number) -> number, - rshift: (x: number, disp: number) -> number, - bnot: (x: number) -> number, - extract: (n: number, field: number, width: number?) -> number, - replace: (n: number, v: number, field: number, width: number?) -> number, - countlz: (n: number) -> number, - countrz: (n: number) -> number, + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, } declare math: { - frexp: (n: number) -> (number, number), - ldexp: (s: number, e: number) -> number, - fmod: (x: number, y: number) -> number, - modf: (n: number) -> (number, number), - pow: (x: number, y: number) -> number, - exp: (n: number) -> number, - - ceil: (n: number) -> number, - floor: (n: number) -> number, - abs: (n: number) -> number, - sqrt: (n: number) -> number, - - log: (n: number, base: number?) -> number, - log10: (n: number) -> number, - - rad: (n: number) -> number, - deg: (n: number) -> number, - - sin: (n: number) -> number, - cos: (n: number) -> number, - tan: (n: number) -> number, - sinh: (n: number) -> number, - cosh: (n: number) -> number, - tanh: (n: number) -> number, - atan: (n: number) -> number, - acos: (n: number) -> number, - asin: (n: number) -> number, - atan2: (y: number, x: number) -> number, - - min: (number, ...number) -> number, - max: (number, ...number) -> number, + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, pi: number, huge: number, - randomseed: (seed: number) -> (), - random: (number?, number?) -> number, + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, - sign: (n: number) -> number, - clamp: (n: number, min: number, max: number) -> number, - noise: (x: number, y: number?, z: number?) -> number, - round: (n: number) -> number, + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, } type DateTypeArg = { @@ -92,14 +93,14 @@ type DateTypeResult = { declare os: { time: (time: DateTypeArg?) -> number, - date: (formatString: string?, time: number?) -> DateTypeResult | string, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } -declare function require(target: any): any +@checked declare function require(target: any): any -declare function getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -141,18 +142,17 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) -> thread, resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (co: thread) -> "dead" | "running" | "normal" | "suspended", - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: (f: (A...) -> R...) -> any, + status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", + wrap: (f: (A...) -> R...) -> ((A...) -> R...), yield: (A...) -> R..., isyieldable: () -> boolean, - close: (co: thread) -> (boolean, any) + close: @checked (co: thread) -> (boolean, any) } declare table: { @@ -183,22 +183,51 @@ declare debug: { } declare utf8: { - char: (...number) -> string, + char: @checked (...number) -> string, charpattern: string, - codes: (str: string) -> ((string, number) -> (number, number), string, number), - codepoint: (str: string, i: number?, j: number?) -> ...number, - len: (s: string, i: number?, j: number?) -> (number?, number?), - offset: (s: string, n: number?, i: number?) -> number, + codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: @checked (str: string, i: number?, j: number?) -> ...number, + len: @checked (s: string, i: number?, j: number?) -> (number?, number?), + offset: @checked (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V + +--- Buffer API +declare buffer: { + create: @checked (size: number) -> buffer, + fromstring: @checked (str: string) -> buffer, + tostring: @checked (b: buffer) -> string, + len: @checked (b: buffer) -> number, + copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: @checked (b: buffer, offset: number) -> number, + readu8: @checked (b: buffer, offset: number) -> number, + readi16: @checked (b: buffer, offset: number) -> number, + readu16: @checked (b: buffer, offset: number) -> number, + readi32: @checked (b: buffer, offset: number) -> number, + readu32: @checked (b: buffer, offset: number) -> number, + readf32: @checked (b: buffer, offset: number) -> number, + readf64: @checked (b: buffer, offset: number) -> number, + writei8: @checked (b: buffer, offset: number, value: number) -> (), + writeu8: @checked (b: buffer, offset: number, value: number) -> (), + writei16: @checked (b: buffer, offset: number, value: number) -> (), + writeu16: @checked (b: buffer, offset: number, value: number) -> (), + writei32: @checked (b: buffer, offset: number, value: number) -> (), + writeu32: @checked (b: buffer, offset: number, value: number) -> (), + writef32: @checked (b: buffer, offset: number, value: number) -> (), + writef64: @checked (b: buffer, offset: number, value: number) -> (), + readstring: @checked (b: buffer, offset: number, count: number) -> string, + writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), +} + )BUILTIN_SRC"; std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrcChecked; return result; } diff --git a/third_party/luau/Analysis/src/Error.cpp b/third_party/luau/Analysis/src/Error.cpp index 1e037972..60058d99 100644 --- a/third_party/luau/Analysis/src/Error.cpp +++ b/third_party/luau/Analysis/src/Error.cpp @@ -4,17 +4,29 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/FileResolver.h" +#include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include #include +#include #include +#include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) -LUAU_FASTFLAGVARIABLE(LuauRequirePathTrueModuleName, false) +LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) + +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false) static std::string wrongNumberOfArgsString( - size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) + size_t expectedCount, + std::optional maximumCount, + size_t actualCount, + const char* argPrefix = nullptr, + bool isVariadic = false +) { std::string s = "expects "; @@ -57,6 +69,30 @@ static std::string wrongNumberOfArgsString( namespace Luau { +// this list of binary operator type functions is used for better stringification of type functions errors +static const std::unordered_map kBinaryOps{ + {"add", "+"}, + {"sub", "-"}, + {"mul", "*"}, + {"div", "/"}, + {"idiv", "//"}, + {"pow", "^"}, + {"mod", "%"}, + {"concat", ".."}, + {"and", "and"}, + {"or", "or"}, + {"lt", "< or >="}, + {"le", "<= or >"}, + {"eq", "== or ~="} +}; + +// this list of unary operator type functions is used for better stringification of type functions errors +static const std::unordered_map kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}}; + +// this list of type functions will receive a special error indicating that the user should file a bug on the GitHub repository +// putting a type function in this list indicates that it is expected to _always_ reduce +static const std::unordered_set kUnreachableTypeFunctions{"refine", "singleton", "union", "intersect"}; + struct ErrorConverter { FileResolver* fileResolver = nullptr; @@ -68,6 +104,23 @@ struct ErrorConverter std::string result; + auto quote = [&](std::string s) + { + return "'" + s + "'"; + }; + + auto constructErrorMessage = + [&](std::string givenType, std::string wantedType, std::optional givenModule, std::optional wantedModule + ) -> std::string + { + std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType); + std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType); + size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength); + if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength) + return "Type " + given + " could not be converted into " + wanted; + return "Type\n " + given + "\ncould not be converted into\n " + wanted; + }; + if (givenTypeName == wantedTypeName) { if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) @@ -78,20 +131,18 @@ struct ErrorConverter { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); - result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + - "' from '" + wantedModuleName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, givenModuleName, wantedModuleName); } else { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, *givenDefinitionModule, *wantedDefinitionModule); } } } } if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, std::nullopt, std::nullopt); if (tm.error) @@ -99,7 +150,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + " "; + result += tm.reason + "\n"; result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } @@ -107,7 +158,7 @@ struct ErrorConverter { result += "; " + tm.reason; } - else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext) + else if (tm.context == TypeMismatch::InvariantContext) { result += " in an invariant context"; } @@ -321,8 +372,70 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::ConstraintSolvingIncompleteError& e) const + { + return "Type inference failed to complete, you may see some confusing types and type errors."; + } + + std::optional findCallMetamethod(TypeId type) const + { + type = follow(type); + + std::optional metatable; + if (const MetatableType* mtType = get(type)) + metatable = mtType->metatable; + else if (const ClassType* classType = get(type)) + metatable = classType->metatable; + + if (!metatable) + return std::nullopt; + + TypeId unwrapped = follow(*metatable); + + if (get(unwrapped)) + return unwrapped; + + const TableType* mtt = getTableType(unwrapped); + if (!mtt) + return std::nullopt; + + auto it = mtt->props.find("__call"); + if (it != mtt->props.end()) + return it->second.type(); + else + return std::nullopt; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { + if (DFFlag::LuauImproveNonFunctionCallError) + { + if (auto unionTy = get(follow(e.ty))) + { + std::string err = "Cannot call a value of the union type:"; + + for (auto option : unionTy) + { + option = follow(option); + + if (get(option) || findCallMetamethod(option)) + { + err += "\n | " + toString(option); + continue; + } + + // early-exit if we find something that isn't callable in the union. + return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty); + } + + err += "\nWe are unable to determine the appropriate result type for such a call."; + + return err; + } + + return "Cannot call a value of type " + toString(e.ty); + } + return "Cannot call non-function " + toString(e.ty); } std::string operator()(const Luau::ExtraInformation& e) const @@ -350,7 +463,7 @@ struct ErrorConverter else s += " -> "; - if (FFlag::LuauRequirePathTrueModuleName && fileResolver != nullptr) + if (fileResolver != nullptr) s += fileResolver->getHumanReadableModuleName(name); else s += name; @@ -477,13 +590,227 @@ struct ErrorConverter std::string operator()(const TypePackMismatch& e) const { - return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + std::string ss = "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + + if (!e.reason.empty()) + ss += "; " + e.reason; + + return ss; } std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const { return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime"; } + + std::string operator()(const UninhabitedTypeFunction& e) const + { + auto tfit = get(e.ty); + LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type function. + if (!tfit) + return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type function."; + + // unary operators + if (auto unaryString = kUnaryOps.find(tfit->function->name); unaryString != kUnaryOps.end()) + { + std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to "; + + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + { + result += "operand of type " + Luau::toString(tfit->typeArguments[0]); + + if (tfit->function->name != "not") + result += "; there is no corresponding overload for __" + tfit->function->name; + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + result += "operands of types "; + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + return result; + } + + // binary operators + if (auto binaryString = kBinaryOps.find(tfit->function->name); binaryString != kBinaryOps.end()) + { + std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types "; + + if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty()) + { + // this is the expected case. + result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]); + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + result += "; there is no corresponding overload for __" + tfit->function->name; + + return result; + } + + // miscellaneous + + if ("keyof" == tfit->function->name || "rawkeyof" == tfit->function->name) + { + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid"; + else + return "Type function instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + } + + if ("index" == tfit->function->name || "rawget" == tfit->function->name) + { + if (tfit->typeArguments.size() != 2) + return "Type function instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + + if (auto errType = get(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type + return "Second argument to " + tfit->function->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Property `indexer` does not exist on type `indexee` + return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + + "'"; + } + + if (kUnreachableTypeFunctions.count(tfit->function->name)) + { + return "Type function instance " + Luau::toString(e.ty) + " is uninhabited\n" + + "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; + } + + // Everything should be specialized above to report a more descriptive error that hopefully does not mention "type functions" explicitly. + // If we produce this message, it's an indication that we've missed a specialization and it should be fixed! + return "Type function instance " + Luau::toString(e.ty) + " is uninhabited"; + } + + std::string operator()(const ExplicitFunctionAnnotationRecommended& r) const + { + std::string toReturn = toString(r.recommendedReturn); + std::string argAnnotations; + for (auto [arg, type] : r.recommendedArgs) + { + argAnnotations += arg + ": " + toString(type) + ", "; + } + if (argAnnotations.length() >= 2) + { + argAnnotations.pop_back(); + argAnnotations.pop_back(); + } + + if (argAnnotations.empty()) + return "Consider annotating the return with " + toReturn; + + return "Consider placing the following annotations on the arguments: " + argAnnotations + " or instead annotating the return as " + toReturn; + } + + std::string operator()(const UninhabitedTypePackFunction& e) const + { + return "Type pack function instance " + Luau::toString(e.tp) + " is uninhabited"; + } + + std::string operator()(const WhereClauseNeeded& e) const + { + return "Type function instance " + Luau::toString(e.ty) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; + } + + std::string operator()(const PackWhereClauseNeeded& e) const + { + return "Type pack function instance " + Luau::toString(e.tp) + + " depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this " + "time"; + } + + std::string operator()(const CheckedFunctionCallError& e) const + { + // TODO: What happens if checkedFunctionName cannot be found?? + return "Function '" + e.checkedFunctionName + "' expects '" + toString(e.expected) + "' at argument #" + std::to_string(e.argumentIndex) + + ", but got '" + Luau::toString(e.passed) + "'"; + } + + std::string operator()(const NonStrictFunctionDefinitionError& e) const + { + return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' in function '" + e.functionName + + "' is used in a way that will run time error"; + } + + std::string operator()(const PropertyAccessViolation& e) const + { + const std::string stringKey = isIdentifier(e.key) ? e.key : "\"" + e.key + "\""; + switch (e.context) + { + case PropertyAccessViolation::CannotRead: + return "Property " + stringKey + " of table '" + toString(e.table) + "' is write-only"; + case PropertyAccessViolation::CannotWrite: + return "Property " + stringKey + " of table '" + toString(e.table) + "' is read-only"; + } + + LUAU_UNREACHABLE(); + return ""; + } + + std::string operator()(const CheckedFunctionIncorrectArgs& e) const + { + return "Checked Function " + e.functionName + " expects " + std::to_string(e.expected) + " arguments, but received " + + std::to_string(e.actual); + } + + std::string operator()(const UnexpectedTypeInSubtyping& e) const + { + return "Encountered an unexpected type in subtyping: " + toString(e.ty); + } + + std::string operator()(const UnexpectedTypePackInSubtyping& e) const + { + return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); + } + + std::string operator()(const CannotAssignToNever& e) const + { + std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; + + switch (e.reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + if (!e.cause.empty()) + { + result += "\ncaused by the property being given the following incompatible types:\n"; + for (auto ty : e.cause) + result += " " + toString(ty) + "\n"; + result += "There are no values that could safely satisfy all of these types at once."; + } + } + + return result; + } }; struct InvalidNameChecker @@ -576,6 +903,11 @@ bool UnknownProperty::operator==(const UnknownProperty& rhs) const return *table == *rhs.table && key == rhs.key; } +bool PropertyAccessViolation::operator==(const PropertyAccessViolation& rhs) const +{ + return *table == *rhs.table && key == rhs.key && context == rhs.context; +} + bool NotATable::operator==(const NotATable& rhs) const { return ty == rhs.ty; @@ -681,6 +1013,11 @@ bool InternalError::operator==(const InternalError& rhs) const return message == rhs.message; } +bool ConstraintSolvingIncompleteError::operator==(const ConstraintSolvingIncompleteError& rhs) const +{ + return true; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -786,6 +1123,72 @@ bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLooku return ty == rhs.ty; } +bool UninhabitedTypeFunction::operator==(const UninhabitedTypeFunction& rhs) const +{ + return ty == rhs.ty; +} + + +bool ExplicitFunctionAnnotationRecommended::operator==(const ExplicitFunctionAnnotationRecommended& rhs) const +{ + return recommendedReturn == rhs.recommendedReturn && recommendedArgs == rhs.recommendedArgs; +} + +bool UninhabitedTypePackFunction::operator==(const UninhabitedTypePackFunction& rhs) const +{ + return tp == rhs.tp; +} + +bool WhereClauseNeeded::operator==(const WhereClauseNeeded& rhs) const +{ + return ty == rhs.ty; +} + +bool PackWhereClauseNeeded::operator==(const PackWhereClauseNeeded& rhs) const +{ + return tp == rhs.tp; +} + +bool CheckedFunctionCallError::operator==(const CheckedFunctionCallError& rhs) const +{ + return *expected == *rhs.expected && *passed == *rhs.passed && checkedFunctionName == rhs.checkedFunctionName && + argumentIndex == rhs.argumentIndex; +} + +bool NonStrictFunctionDefinitionError::operator==(const NonStrictFunctionDefinitionError& rhs) const +{ + return functionName == rhs.functionName && argument == rhs.argument && argumentType == rhs.argumentType; +} + +bool CheckedFunctionIncorrectArgs::operator==(const CheckedFunctionIncorrectArgs& rhs) const +{ + return functionName == rhs.functionName && expected == rhs.expected && actual == rhs.actual; +} + +bool UnexpectedTypeInSubtyping::operator==(const UnexpectedTypeInSubtyping& rhs) const +{ + return ty == rhs.ty; +} + +bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtyping& rhs) const +{ + return tp == rhs.tp; +} + +bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const +{ + if (cause.size() != rhs.cause.size()) + return false; + + for (size_t i = 0; i < cause.size(); ++i) + { + if (*cause[i] != *rhs.cause[i]) + return false; + } + + return *rhsType == *rhs.rhsType && reason == rhs.reason; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -803,13 +1206,15 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState& cloneState) { - auto clone = [&](auto&& ty) { + auto clone = [&](auto&& ty) + { return ::Luau::clone(ty, destArena, cloneState); }; - auto visitErrorData = [&](auto&& e) { + auto visitErrorData = [&](auto&& e) + { copyError(e, destArena, cloneState); }; @@ -884,6 +1289,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.ty = clone(e.ty); @@ -944,15 +1352,55 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) } else if constexpr (std::is_same_v) e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + { + e.recommendedReturn = clone(e.recommendedReturn); + for (auto& [_, t] : e.recommendedArgs) + t = clone(t); + } + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + e.expected = clone(e.expected); + e.passed = clone(e.passed); + } + else if constexpr (std::is_same_v) + { + e.argumentType = clone(e.argumentType); + } + else if constexpr (std::is_same_v) + e.table = clone(e.table); + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + e.rhsType = clone(e.rhsType); + + for (auto& ty : e.cause) + ty = clone(ty); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } -void copyErrors(ErrorVec& errors, TypeArena& destArena) +void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull builtinTypes) { - CloneState cloneState; + CloneState cloneState{builtinTypes}; - auto visitErrorData = [&](auto&& e) { + auto visitErrorData = [&](auto&& e) + { copyError(e, destArena, cloneState); }; diff --git a/third_party/luau/Analysis/src/Frontend.cpp b/third_party/luau/Analysis/src/Frontend.cpp index 486ef696..a8ae99d5 100644 --- a/third_party/luau/Analysis/src/Frontend.cpp +++ b/third_party/luau/Analysis/src/Frontend.cpp @@ -1,26 +1,34 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/AnyTypeSummary.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" +#include "Luau/NonStrictTypeChecker.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Transpiler.h" +#include "Luau/TypeArena.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" +#include "Luau/VisitType.h" #include #include +#include +#include +#include #include #include @@ -29,15 +37,46 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) -LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAGVARIABLE(LuauCancelFromProgress, false) +LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles, false) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) -LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode, false) +LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode, false) +LUAU_FASTFLAGVARIABLE(LuauSourceModuleUpdatedWithSelectedMode, false) + +LUAU_FASTFLAG(StudioReportLuauAny) namespace Luau { +struct BuildQueueItem +{ + ModuleName name; + ModuleName humanReadableName; + + // Parameters + std::shared_ptr sourceNode; + std::shared_ptr sourceModule; + Config config; + ScopePtr environmentScope; + std::vector requireCycles; + FrontendOptions options; + bool recordJsonLog = false; + + // Queue state + std::vector reverseDeps; + int dirtyDependencies = 0; + bool processing = false; + + // Result + std::exception_ptr exception; + ModulePtr module; + Frontend::Stats stats; +}; + std::optional parseMode(const std::vector& hotcomments) { for (const HotComment& hc : hotcomments) @@ -94,12 +133,19 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; + + if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) + { + sourceModule.hotcomments = parseResult.hotcomments; + sourceModule.commentLocations = parseResult.commentLocations; + } + return parseResult; } static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) { - CloneState cloneState; + CloneState cloneState{globals.builtinTypes}; std::vector typesToPersist; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); @@ -130,12 +176,21 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, - const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) +LoadDefinitionFileResult Frontend::loadDefinitionFile( + GlobalTypes& globals, + ScopePtr targetScope, + std::string_view source, + const std::string& packageName, + bool captureComments, + bool typeCheckForAutocomplete +) { LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; + sourceModule.name = packageName; + sourceModule.humanReadableName = packageName; + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); if (parseResult.errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; @@ -220,9 +275,12 @@ namespace { static ErrorVec accumulateErrors( - const std::unordered_map& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) + const std::unordered_map>& sourceNodes, + ModuleResolver& moduleResolver, + const ModuleName& name +) { - std::unordered_set seen; + DenseHashSet seen{{}}; std::vector queue{name}; ErrorVec result; @@ -232,7 +290,7 @@ static ErrorVec accumulateErrors( ModuleName next = std::move(queue.back()); queue.pop_back(); - if (seen.count(next)) + if (seen.contains(next)) continue; seen.insert(next); @@ -240,7 +298,7 @@ static ErrorVec accumulateErrors( if (it == sourceNodes.end()) continue; - const SourceNode& sourceNode = it->second; + const SourceNode& sourceNode = *it->second; queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. @@ -252,9 +310,14 @@ static ErrorVec accumulateErrors( Module& module = *modulePtr; - std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool { - return e1.location.begin > e2.location.begin; - }); + std::sort( + module.errors.begin(), + module.errors.end(), + [](const TypeError& e1, const TypeError& e2) -> bool + { + return e1.location.begin > e2.location.begin; + } + ); result.insert(result.end(), module.errors.begin(), module.errors.end()); } @@ -286,7 +349,11 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( - const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, + const std::unordered_map>& sourceNodes, + const SourceNode* start, + bool stopAtFirst = false +) { std::vector result; @@ -302,7 +369,7 @@ std::vector getRequireCycles( if (dit == sourceNodes.end()) continue; - stack.push_back(&dit->second); + stack.push_back(dit->second.get()); while (!stack.empty()) { @@ -320,9 +387,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? node->name : node->humanReadableName); + cycle.push_back(node->name); - cycle.push_back(FFlag::LuauRequirePathTrueModuleName ? top->name : top->humanReadableName); + cycle.push_back(top->name); break; } } @@ -343,7 +410,7 @@ std::vector getRequireCycles( auto rit = sourceNodes.find(reqName); if (rit != sourceNodes.end()) - stack.push_back(&rit->second); + stack.push_back(rit->second.get()); } } } @@ -388,202 +455,376 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c { } +void Frontend::parse(const ModuleName& name) +{ + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (getCheckResult(name, false, false)) + return; + + std::vector buildQueue; + parseGraph(buildQueue, name, false); +} + CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); FrontendOptions frontendOptions = optionOverride.value_or(options); - CheckResult checkResult; + if (FFlag::DebugLuauDeferredConstraintResolution) + frontendOptions.forAutocomplete = false; - FrontendModuleResolver& resolver = frontendOptions.forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; + if (std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) + return std::move(*result); - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) + std::vector buildQueue; + bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + + DenseHashSet seen{{}}; + std::vector buildQueueItems; + addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions); + LUAU_ASSERT(!buildQueueItems.empty()); + + if (FFlag::DebugLuauLogSolverToJson) { - // No recheck required. - ModulePtr module = resolver.getModule(name); + LUAU_ASSERT(buildQueueItems.back().name == name); + buildQueueItems.back().recordJsonLog = true; + } - if (!module) - throw InternalCompilerError("Frontend::modules does not have data for " + name, name); + checkBuildQueueItems(buildQueueItems); - checkResult.errors = accumulateErrors(sourceNodes, resolver, name); + // Collect results only for checked modules, 'getCheckResult' produces a different result + CheckResult checkResult; + + for (const BuildQueueItem& item : buildQueueItems) + { + if (item.module->timeout) + checkResult.timeoutHits.push_back(item.name); + + // If check was manually cancelled, do not return partial results + if (item.module->cancelled) + return {}; + + checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); - // Get lint result only for top checked module - checkResult.lintResult = module->lintResult; + if (item.name == name) + checkResult.lintResult = item.module->lintResult; - return checkResult; + if (FFlag::StudioReportLuauAny && item.options.retainFullTypeGraphs) + { + if (item.module) + { + const SourceModule& sourceModule = *item.sourceModule; + if (sourceModule.mode == Luau::Mode::Strict) + { + item.module->ats.root = toString(sourceModule.root); + } + item.module->ats.rootSrc = sourceModule.root; + item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_}); + } + } } - std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + return checkResult; +} - for (const ModuleName& moduleName : buildQueue) - { - LUAU_ASSERT(sourceNodes.count(moduleName)); - SourceNode& sourceNode = sourceNodes[moduleName]; +void Frontend::queueModuleCheck(const std::vector& names) +{ + moduleQueue.insert(moduleQueue.end(), names.begin(), names.end()); +} - if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) - continue; +void Frontend::queueModuleCheck(const ModuleName& name) +{ + moduleQueue.push_back(name); +} - LUAU_ASSERT(sourceModules.count(moduleName)); - SourceModule& sourceModule = sourceModules[moduleName]; +std::vector Frontend::checkQueuedModules( + std::optional optionOverride, + std::function task)> executeTask, + std::function progress +) +{ + FrontendOptions frontendOptions = optionOverride.value_or(options); + if (FFlag::DebugLuauDeferredConstraintResolution) + frontendOptions.forAutocomplete = false; - const Config& config = configResolver->getConfig(moduleName); + // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown + std::vector currModuleQueue; + std::swap(currModuleQueue, moduleQueue); - Mode mode = sourceModule.mode.value_or(config.mode); + DenseHashSet seen{{}}; + std::vector buildQueueItems; - ScopePtr environmentScope = getModuleEnvironment(sourceModule, config, frontendOptions.forAutocomplete); + for (const ModuleName& name : currModuleQueue) + { + if (seen.contains(name)) + continue; - double timestamp = getTimestamp(); + if (!isDirty(name, frontendOptions.forAutocomplete)) + { + seen.insert(name); + continue; + } - std::vector requireCycles; + std::vector queue; + bool cycleDetected = parseGraph( + queue, + name, + frontendOptions.forAutocomplete, + [&seen](const ModuleName& name) + { + return seen.contains(name); + } + ); - // in NoCheck mode we only need to compute the value of .cyclic for typeck - // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself - // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term - // all correct programs must be acyclic so this code triggers rarely - if (cycleDetected) - requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); + addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions); + } - // This is used by the type checker to replace the resulting type of cyclic modules with any - sourceModule.cyclic = !requireCycles.empty(); + if (buildQueueItems.empty()) + return {}; + + // We need a mapping from modules to build queue slots + std::unordered_map moduleNameToQueue; - if (frontendOptions.forAutocomplete) + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + moduleNameToQueue[item.name] = i; + } + + // Default task execution is single-threaded and immediate + if (!executeTask) + { + executeTask = [](std::function task) { - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + task(); + }; + } - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; + std::mutex mtx; + std::condition_variable cv; + std::vector readyQueueItems; - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckLimits.finishTime = std::nullopt; + size_t processing = 0; + size_t remaining = buildQueueItems.size(); - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; + auto itemTask = [&](size_t i) + { + BuildQueueItem& item = buildQueueItems[i]; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; + try + { + checkBuildQueueItem(item); + } + catch (...) + { + item.exception = std::current_exception(); + } + + { + std::unique_lock guard(mtx); + readyQueueItems.push_back(i); + } - ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, - /*recordJsonLog*/ false, typeCheckLimits); + cv.notify_one(); + }; - resolver.setModule(moduleName, moduleForAutocomplete); + auto sendItemTask = [&](size_t i) + { + BuildQueueItem& item = buildQueueItems[i]; - double duration = getTimestamp() - timestamp; + item.processing = true; + processing++; - if (moduleForAutocomplete->timeout) + executeTask( + [&itemTask, i]() { - checkResult.timeoutHits.push_back(moduleName); - - sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + itemTask(i); } - else if (duration < autocompleteTimeLimit / 2.0) + ); + }; + + auto sendCycleItemTask = [&] + { + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + + if (!item.processing) { - sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + sendItemTask(i); + break; } + } + }; - stats.timeCheck += duration; - stats.filesStrict += 1; + // In a first pass, check modules that have no dependencies and record info of those modules that wait + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; - sourceNode.dirtyModuleForAutocomplete = false; - continue; - } + for (const ModuleName& dep : item.sourceNode->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + { + if (it->second->hasDirtyModule(frontendOptions.forAutocomplete)) + { + item.dirtyDependencies++; - const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson && moduleName == name; - ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, recordJsonLog, {}); + buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i); + } + } + } - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; + if (item.dirtyDependencies == 0) + sendItemTask(i); + } - if (module == nullptr) - throw InternalCompilerError("Frontend::check produced a nullptr module for " + moduleName, moduleName); + // Not a single item was found, a cycle in the graph was hit + if (processing == 0) + sendCycleItemTask(); - if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck) - module->errors.clear(); + std::vector nextItems; + std::optional itemWithException; + bool cancelled = false; - if (frontendOptions.runLintChecks) + while (remaining != 0) + { { - LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + std::unique_lock guard(mtx); + + // If nothing is ready yet, wait + cv.wait( + guard, + [&readyQueueItems] + { + return !readyQueueItems.empty(); + } + ); - LintOptions lintOptions = frontendOptions.enabledLintWarnings.value_or(config.enabledLint); - filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + // Handle checked items + for (size_t i : readyQueueItems) + { + const BuildQueueItem& item = buildQueueItems[i]; - double timestamp = getTimestamp(); + // If exception was thrown, stop adding new items and wait for processing items to complete + if (item.exception) + itemWithException = i; - std::vector warnings = - Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + if (item.module && item.module->cancelled) + cancelled = true; - stats.timeLint += getTimestamp() - timestamp; + if (itemWithException || cancelled) + break; - module->lintResult = classifyLints(warnings, config); + recordItemResult(item); + + // Notify items that were waiting for this dependency + for (size_t reverseDep : item.reverseDeps) + { + BuildQueueItem& reverseDepItem = buildQueueItems[reverseDep]; + + LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0); + reverseDepItem.dirtyDependencies--; + + // In case of a module cycle earlier, check if unlocked an item that was already processed + if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0) + nextItems.push_back(reverseDep); + } + } + + LUAU_ASSERT(processing >= readyQueueItems.size()); + processing -= readyQueueItems.size(); + + LUAU_ASSERT(remaining >= readyQueueItems.size()); + remaining -= readyQueueItems.size(); + readyQueueItems.clear(); } - if (!frontendOptions.retainFullTypeGraphs) + if (progress) { - // copyErrors needs to allocate into interfaceTypes as it copies - // types out of internalTypes, so we unfreeze it here. - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - - module->internalTypes.clear(); - - module->astTypes.clear(); - module->astTypePacks.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astOverloadResolvedTypes.clear(); - module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astScopes.clear(); - - module->scopes.clear(); + if (FFlag::LuauCancelFromProgress) + { + if (!progress(buildQueueItems.size() - remaining, buildQueueItems.size())) + cancelled = true; + } + else + { + progress(buildQueueItems.size() - remaining, buildQueueItems.size()); + } } - if (mode != Mode::NoCheck) + // Items cannot be submitted while holding the lock + for (size_t i : nextItems) + sendItemTask(i); + nextItems.clear(); + + if (processing == 0) { - for (const RequireCycle& cyc : requireCycles) - { - TypeError te{cyc.location, moduleName, ModuleHasCyclicDependency{cyc.path}}; + // Typechecking might have been cancelled by user, don't return partial results + if (cancelled) + return {}; - module->errors.push_back(te); - } + // We might have stopped because of a pending exception + if (itemWithException) + recordItemResult(buildQueueItems[*itemWithException]); } - ErrorVec parseErrors; + // If we aren't done, but don't have anything processing, we hit a cycle + if (remaining != 0 && processing == 0) + sendCycleItemTask(); + } - for (const ParseError& pe : sourceModule.parseErrors) - parseErrors.push_back(TypeError{pe.getLocation(), moduleName, SyntaxError{pe.what()}}); + std::vector checkedModules; + checkedModules.reserve(buildQueueItems.size()); - module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + for (size_t i = 0; i < buildQueueItems.size(); i++) + checkedModules.push_back(std::move(buildQueueItems[i].name)); - checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); + return checkedModules; +} - resolver.setModule(moduleName, std::move(module)); - sourceNode.dirtyModule = false; - } +std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete) +{ + if (FFlag::DebugLuauDeferredConstraintResolution) + forAutocomplete = false; + + auto it = sourceNodes.find(name); + + if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) + return std::nullopt; + + auto& resolver = forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; + + ModulePtr module = resolver.getModule(name); + + if (module == nullptr) + throw InternalCompilerError("Frontend does not have module: " + name, name); + + CheckResult checkResult; + + if (module->timeout) + checkResult.timeoutHits.push_back(name); + + if (accumulateNested) + checkResult.errors = accumulateErrors(sourceNodes, resolver, name); + else + checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); // Get lint result only for top checked module - if (ModulePtr module = resolver.getModule(name)) - checkResult.lintResult = module->lintResult; + checkResult.lintResult = module->lintResult; return checkResult; } -bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete) +bool Frontend::parseGraph( + std::vector& buildQueue, + const ModuleName& root, + bool forAutocomplete, + std::function canSkip +) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -654,14 +895,18 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.hasDirtyModule(forAutocomplete)) + if (!it->second->hasDirtyModule(forAutocomplete)) + continue; + + // This module might already be in the outside build queue + if (canSkip && canSkip(dep)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set - if (seen.contains(&it->second)) + if (seen.contains(it->second.get())) { - stack.push_back(&it->second); + stack.push_back(it->second.get()); continue; } } @@ -681,6 +926,248 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } +void Frontend::addBuildQueueItems( + std::vector& items, + std::vector& buildQueue, + bool cycleDetected, + DenseHashSet& seen, + const FrontendOptions& frontendOptions +) +{ + for (const ModuleName& moduleName : buildQueue) + { + if (seen.contains(moduleName)) + continue; + seen.insert(moduleName); + + LUAU_ASSERT(sourceNodes.count(moduleName)); + std::shared_ptr& sourceNode = sourceNodes[moduleName]; + + if (!sourceNode->hasDirtyModule(frontendOptions.forAutocomplete)) + continue; + + LUAU_ASSERT(sourceModules.count(moduleName)); + std::shared_ptr& sourceModule = sourceModules[moduleName]; + + BuildQueueItem data{moduleName, fileResolver->getHumanReadableModuleName(moduleName), sourceNode, sourceModule}; + + data.config = configResolver->getConfig(moduleName); + data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); + data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; + + Mode mode = sourceModule->mode.value_or(data.config.mode); + + // in NoCheck mode we only need to compute the value of .cyclic for typeck + // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself + // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term + // all correct programs must be acyclic so this code triggers rarely + if (cycleDetected) + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); + + data.options = frontendOptions; + + // This is used by the type checker to replace the resulting type of cyclic modules with any + sourceModule->cyclic = !data.requireCycles.empty(); + + items.push_back(std::move(data)); + } +} + +static void applyInternalLimitScaling(SourceNode& sourceNode, const ModulePtr module, double limit) +{ + if (module->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (module->checkDurationSec < limit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); +} + +void Frontend::checkBuildQueueItem(BuildQueueItem& item) +{ + SourceNode& sourceNode = *item.sourceNode; + const SourceModule& sourceModule = *item.sourceModule; + const Config& config = item.config; + Mode mode; + if (FFlag::DebugLuauForceStrictMode) + mode = Mode::Strict; + else if (FFlag::DebugLuauForceNonStrictMode) + mode = Mode::Nonstrict; + else + mode = sourceModule.mode.value_or(config.mode); + + if (FFlag::LuauSourceModuleUpdatedWithSelectedMode) + item.sourceModule->mode = {mode}; + ScopePtr environmentScope = item.environmentScope; + double timestamp = getTimestamp(); + const std::vector& requireCycles = item.requireCycles; + + TypeCheckLimits typeCheckLimits; + + if (item.options.moduleTimeLimitSec) + typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (item.options.applyInternalLimitScaling) + { + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + } + + typeCheckLimits.cancellationToken = item.options.cancellationToken; + + if (item.options.forAutocomplete) + { + // The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features + ModulePtr moduleForAutocomplete = check( + sourceModule, + Mode::Strict, + requireCycles, + environmentScope, + /*forAutocomplete*/ true, + /*recordJsonLog*/ false, + typeCheckLimits + ); + + double duration = getTimestamp() - timestamp; + + moduleForAutocomplete->checkDurationSec = duration; + + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec); + + item.stats.timeCheck += duration; + item.stats.filesStrict += 1; + + item.module = moduleForAutocomplete; + return; + } + + ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits); + + double duration = getTimestamp() - timestamp; + + module->checkDurationSec = duration; + + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec); + + item.stats.timeCheck += duration; + item.stats.filesStrict += mode == Mode::Strict; + item.stats.filesNonstrict += mode == Mode::Nonstrict; + + if (module == nullptr) + throw InternalCompilerError("Frontend::check produced a nullptr module for " + item.name, item.name); + + if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck) + module->errors.clear(); + + if (item.options.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LintOptions lintOptions = item.options.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + item.stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + + if (!item.options.retainFullTypeGraphs) + { + // copyErrors needs to allocate into interfaceTypes as it copies + // types out of internalTypes, so we unfreeze it here. + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); + freeze(module->interfaceTypes); + + module->internalTypes.clear(); + + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astForInNextTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astCompoundAssignResultTypes.clear(); + module->astScopes.clear(); + module->upperBoundContributors.clear(); + module->scopes.clear(); + } + + if (mode != Mode::NoCheck) + { + for (const RequireCycle& cyc : requireCycles) + { + TypeError te{cyc.location, item.name, ModuleHasCyclicDependency{cyc.path}}; + + module->errors.push_back(te); + } + } + + ErrorVec parseErrors; + + for (const ParseError& pe : sourceModule.parseErrors) + parseErrors.push_back(TypeError{pe.getLocation(), item.name, SyntaxError{pe.what()}}); + + module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + + item.module = module; +} + +void Frontend::checkBuildQueueItems(std::vector& items) +{ + for (BuildQueueItem& item : items) + { + checkBuildQueueItem(item); + + if (item.module && item.module->cancelled) + break; + + recordItemResult(item); + } +} + +void Frontend::recordItemResult(const BuildQueueItem& item) +{ + if (item.exception) + std::rethrow_exception(item.exception); + + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + stats.timeCheck += item.stats.timeCheck; + stats.timeLint += item.stats.timeLint; + + stats.filesStrict += item.stats.filesStrict; + stats.filesNonstrict += item.stats.filesNonstrict; +} + ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const { ScopePtr result; @@ -711,7 +1198,7 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); - return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); + return it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete); } /* @@ -722,13 +1209,13 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) + if (sourceNodes.count(name) == 0) return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requireSet) + for (const auto& dep : module.second->requireSet) reverseDeps[dep].push_back(module.first); } @@ -740,7 +1227,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked queue.pop_back(); LUAU_ASSERT(sourceNodes.count(next) > 0); - SourceNode& sourceNode = sourceNodes[next]; + SourceNode& sourceNode = *sourceNodes[next]; if (markedDirty) markedDirty->push_back(next); @@ -766,7 +1253,7 @@ SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) { auto it = sourceModules.find(moduleName); if (it != sourceModules.end()) - return &it->second; + return it->second.get(); else return nullptr; } @@ -776,24 +1263,111 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options) +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& parentScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + std::function writeJsonLog +) { const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; - return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, std::move(prepareModuleScope), - options, recordJsonLog); + return check( + sourceModule, + mode, + requireCycles, + builtinTypes, + iceHandler, + moduleResolver, + fileResolver, + parentScope, + std::move(prepareModuleScope), + options, + limits, + recordJsonLog, + writeJsonLog + ); } -ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, - NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options, - bool recordJsonLog) +struct InternalTypeFinder : TypeOnceVisitor +{ + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypeId, const BlockedType&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypeId, const FreeType&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypePackId, const BlockedTypePack&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypePackId, const FreeTypePack&) override + { + LUAU_ASSERT(false); + return false; + } + + bool visit(TypePackId, const TypeFunctionInstanceTypePack&) override + { + LUAU_ASSERT(false); + return false; + } +}; + +ModulePtr check( + const SourceModule& sourceModule, + Mode mode, + const std::vector& requireCycles, + NotNull builtinTypes, + NotNull iceHandler, + NotNull moduleResolver, + NotNull fileResolver, + const ScopePtr& parentScope, + std::function prepareModuleScope, + FrontendOptions options, + TypeCheckLimits limits, + bool recordJsonLog, + std::function writeJsonLog +) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Typechecking"); + LUAU_TIMETRACE_ARGUMENT("module", sourceModule.name.c_str()); + LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str()); + ModulePtr result = std::make_shared(); result->name = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; - result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); + result->mode = mode; + result->internalTypes.owningModule = result.get(); + result->interfaceTypes.owningModule = result.get(); + + iceHandler->moduleName = sourceModule.name; std::unique_ptr logger; if (recordJsonLog) @@ -810,13 +1384,13 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes, builtinTypes, NotNull{&unifierState}}; - ConstraintGraphBuilder cgb{ + ConstraintGenerator cg{ result, - &result->internalTypes, + NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, @@ -824,67 +1398,189 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorerrors = std::move(cgb.errors); + cg.visitModuleRoot(sourceModule.root); + result->errors = std::move(cg.errors); ConstraintSolver cs{ - NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), result->name, moduleResolver, requireCycles, logger.get()}; + NotNull{&normalizer}, + NotNull(cg.rootScope), + borrowConstraints(cg.constraints), + result->name, + moduleResolver, + requireCycles, + logger.get(), + limits + }; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); - cs.run(); + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + result->timeout = true; + } + catch (const UserCancelError&) + { + result->cancelled = true; + } + + if (recordJsonLog) + { + std::string output = logger->compileOutput(); + if (FFlag::DebugLuauLogSolverToJsonFile && writeJsonLog) + writeJsonLog(sourceModule.name, std::move(output)); + else + printf("%s\n", output.c_str()); + } for (TypeError& e : cs.errors) result->errors.emplace_back(std::move(e)); - result->scopes = std::move(cgb.scopes); + result->scopes = std::move(cg.scopes); result->type = sourceModule.type; + result->upperBoundContributors = std::move(cs.upperBoundContributors); + + if (result->timeout || result->cancelled) + { + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types + ScopePtr moduleScope = result->getModuleScope(); + moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); + for (auto& [name, ty] : result->declaredGlobals) + ty = builtinTypes->errorRecoveryType(); + + for (auto& [name, tf] : result->exportedTypeBindings) + tf.type = builtinTypes->errorRecoveryType(); + } + else + { + switch (mode) + { + case Mode::Nonstrict: + Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); + break; + case Mode::Definition: + // fallthrough intentional + case Mode::Strict: + Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); + break; + case Mode::NoCheck: + break; + }; + } + + unfreeze(result->interfaceTypes); result->clonePublicInterface(builtinTypes, *iceHandler); - Luau::check(builtinTypes, NotNull{&unifierState}, logger.get(), sourceModule, result.get()); + if (FFlag::DebugLuauForbidInternalTypes) + { + InternalTypeFinder finder; + + finder.traverse(result->returnType); - // Ideally we freeze the arenas before the call into Luau::check, but TypeReduction - // needs to allocate new types while Luau::check is in progress, so here we are. + for (const auto& [_, binding] : result->exportedTypeBindings) + finder.traverse(binding.type); + + for (const auto& [_, ty] : result->astTypes) + finder.traverse(ty); + + for (const auto& [_, ty] : result->astExpectedTypes) + finder.traverse(ty); + + for (const auto& [_, tp] : result->astTypePacks) + finder.traverse(tp); + + for (const auto& [_, ty] : result->astResolvedTypes) + finder.traverse(ty); + + for (const auto& [_, ty] : result->astOverloadResolvedTypes) + finder.traverse(ty); + + for (const auto& [_, tp] : result->astResolvedTypePacks) + finder.traverse(tp); + } + + // It would be nice if we could freeze the arenas before doing type + // checking, but we'll have to do some work to get there. // - // It does mean that mutations to the type graph can happen after the constraints - // have been solved, which will cause hard-to-debug problems. We should revisit this. + // TypeChecker2 sometimes needs to allocate TypePacks via extendTypePack() + // in order to do its thing. We can rework that code to instead allocate + // into a temporary arena as long as we can prove that the allocated types + // and packs can never find their way into an error. + // + // Notably, we would first need to get to a place where TypeChecker2 is + // never in the position of dealing with a FreeType. They should all be + // bound to something by the time constraints are solved. freeze(result->internalTypes); freeze(result->interfaceTypes); - if (recordJsonLog) - { - std::string output = logger->compileOutput(); - printf("%s\n", output.c_str()); - } - return result; } -ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, - std::optional environmentScope, bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits) +ModulePtr Frontend::check( + const SourceModule& sourceModule, + Mode mode, + std::vector requireCycles, + std::optional environmentScope, + bool forAutocomplete, + bool recordJsonLog, + TypeCheckLimits typeCheckLimits +) { - if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::Strict) + if (FFlag::DebugLuauDeferredConstraintResolution) { - auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) { + auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) + { if (prepareModuleScope) prepareModuleScope(name, scope, forAutocomplete); }; - return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, recordJsonLog); + try + { + return Luau::check( + sourceModule, + mode, + requireCycles, + builtinTypes, + NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, + NotNull{fileResolver}, + environmentScope ? *environmentScope : globals.globalScope, + prepareModuleScopeWrap, + options, + typeCheckLimits, + recordJsonLog, + writeJsonLog + ); + } + catch (const InternalCompilerError& err) + { + InternalCompilerError augmented = err.location.has_value() ? InternalCompilerError{err.message, sourceModule.name, *err.location} + : InternalCompilerError{err.message, sourceModule.name}; + throw augmented; + } } else { - TypeChecker typeChecker(globals.globalScope, forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); + TypeChecker typeChecker( + forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, + forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, + builtinTypes, + &iceHandler + ); if (prepareModuleScope) { - typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) { + typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) + { prepareModuleScope(name, scope, forAutocomplete); }; } @@ -893,6 +1589,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect typeChecker.finishTime = typeCheckLimits.finishTime; typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; + typeChecker.cancellationToken = typeCheckLimits.cancellationToken; return typeChecker.check(sourceModule, mode, environmentScope); } @@ -901,22 +1598,22 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(const ModuleName& name) { - LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) + if (it != sourceNodes.end() && !it->second->hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) - return {&it->second, &moduleIt->second}; + return {it->second.get(), moduleIt->second.get()}; else { LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); - return {&it->second, nullptr}; + return {it->second.get(), nullptr}; } } + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + double timestamp = getTimestamp(); std::optional source = fileResolver->readSource(name); @@ -939,30 +1636,37 @@ std::pair Frontend::getSourceNode(const ModuleName& RequireTraceResult& require = requireTrace[name]; require = traceRequires(fileResolver, result.root, name); - SourceNode& sourceNode = sourceNodes[name]; - SourceModule& sourceModule = sourceModules[name]; + std::shared_ptr& sourceNode = sourceNodes[name]; - sourceModule = std::move(result); - sourceModule.environmentName = environmentName; + if (!sourceNode) + sourceNode = std::make_shared(); - sourceNode.name = sourceModule.name; - sourceNode.humanReadableName = sourceModule.humanReadableName; - sourceNode.requireSet.clear(); - sourceNode.requireLocations.clear(); - sourceNode.dirtySourceModule = false; + std::shared_ptr& sourceModule = sourceModules[name]; + + if (!sourceModule) + sourceModule = std::make_shared(); + + *sourceModule = std::move(result); + sourceModule->environmentName = environmentName; + + sourceNode->name = sourceModule->name; + sourceNode->humanReadableName = sourceModule->humanReadableName; + sourceNode->requireSet.clear(); + sourceNode->requireLocations.clear(); + sourceNode->dirtySourceModule = false; if (it == sourceNodes.end()) { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; + sourceNode->dirtyModule = true; + sourceNode->dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : require.requireList) - sourceNode.requireSet.insert(moduleName); + sourceNode->requireSet.insert(moduleName); - sourceNode.requireLocations = require.requireList; + sourceNode->requireLocations = require.requireList; - return {&sourceNode, &sourceModule}; + return {sourceNode.get(), sourceModule.get()}; } /** Try to parse a source file into a SourceModule. diff --git a/third_party/luau/Analysis/src/Generalization.cpp b/third_party/luau/Analysis/src/Generalization.cpp new file mode 100644 index 00000000..ea736642 --- /dev/null +++ b/third_party/luau/Analysis/src/Generalization.cpp @@ -0,0 +1,922 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" + +#include "Luau/Scope.h" +#include "Luau/Type.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/VisitType.h" + +namespace Luau +{ + +struct MutatingGeneralizer : TypeOnceVisitor +{ + NotNull builtinTypes; + + NotNull scope; + NotNull> cachedTypes; + DenseHashMap positiveTypes; + DenseHashMap negativeTypes; + std::vector generics; + std::vector genericPacks; + + bool isWithinFunction = false; + bool avoidSealingTables = false; + + MutatingGeneralizer( + NotNull builtinTypes, + NotNull scope, + NotNull> cachedTypes, + DenseHashMap positiveTypes, + DenseHashMap negativeTypes, + bool avoidSealingTables + ) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , builtinTypes(builtinTypes) + , scope(scope) + , cachedTypes(cachedTypes) + , positiveTypes(std::move(positiveTypes)) + , negativeTypes(std::move(negativeTypes)) + , avoidSealingTables(avoidSealingTables) + { + } + + static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + { + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + seen.insert(haystack); + + if (UnionType* ut = getMutable(haystack)) + { + for (auto iter = ut->options.begin(); iter != ut->options.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId option = follow(*iter); + + if (option == needle && get(replacement)) + { + iter = ut->options.erase(iter); + continue; + } + + if (option == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(option)) + continue; + seen.insert(option); + + if (get(option)) + replace(seen, option, needle, haystack); + else if (get(option)) + replace(seen, option, needle, haystack); + } + + if (ut->options.size() == 1) + { + TypeId onlyType = ut->options[0]; + LUAU_ASSERT(onlyType != haystack); + emplaceType(asMutable(haystack), onlyType); + } + + return; + } + + if (IntersectionType* it = getMutable(needle)) + { + for (auto iter = it->parts.begin(); iter != it->parts.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId part = follow(*iter); + + if (part == needle && get(replacement)) + { + iter = it->parts.erase(iter); + continue; + } + + if (part == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(part)) + continue; + seen.insert(part); + + if (get(part)) + replace(seen, part, needle, haystack); + else if (get(part)) + replace(seen, part, needle, haystack); + } + + if (it->parts.size() == 1) + { + TypeId onlyType = it->parts[0]; + LUAU_ASSERT(onlyType != needle); + emplaceType(asMutable(needle), onlyType); + } + + return; + } + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty)) + return false; + + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + LUAU_ASSERT(!cachedTypes->contains(ty)); + + const FreeType* ft = get(ty); + LUAU_ASSERT(ft); + + traverse(ft->lowerBound); + traverse(ft->upperBound); + + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reacquire ft if this happens. + ty = follow(ty); + ft = get(ty); + if (!ft) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + if (!positiveCount && !negativeCount) + return false; + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + DenseHashSet seen{nullptr}; + seen.insert(ty); + + if (!hasLowerBound && !hasUpperBound) + { + if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. + else if (positiveCount && !hasUpperBound) + { + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) + lowerFree->upperBound = builtinTypes->unknownType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + + if (lb != ty) + emplaceType(asMutable(ty), lb); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the lower bound is the type in question, we don't actually have a lower bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + else + { + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) + upperFree->lowerBound = builtinTypes->neverType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + + if (ub != ty) + emplaceType(asMutable(ty), ub); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the upper bound is the type in question, we don't actually have an upper bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + return false; + } + + size_t getCount(const DenseHashMap& map, const void* ty) + { + if (const size_t* count = map.find(ty)) + return *count; + else + return 0; + } + + bool visit(TypeId ty, const TableType&) override + { + if (cachedTypes->contains(ty)) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + // FIXME: Free tables should probably just be replaced by upper bounds on free types. + // + // eg never <: 'a <: {x: number} & {z: boolean} + + if (!positiveCount && !negativeCount) + return true; + + TableType* tt = getMutable(ty); + LUAU_ASSERT(tt); + + if (!avoidSealingTables) + tt->state = TableState::Sealed; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (!subsumes(scope, ftp.scope)) + return true; + + tp = follow(tp); + + const size_t positiveCount = getCount(positiveTypes, tp); + const size_t negativeCount = getCount(negativeTypes, tp); + + if (1 == positiveCount + negativeCount) + emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); + else + { + emplaceTypePack(asMutable(tp), scope); + genericPacks.push_back(tp); + } + + return true; + } +}; + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + NotNull> cachedTypes; + + explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + , cachedTypes(cachedTypes) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. + DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; + +// We keep a running set of types that will not change under generalization and +// only have outgoing references to types that are the same. We use this to +// short circuit generalization. It improves performance quite a lot. +// +// We do this by tracing through the type and searching for types that are +// uncacheable. If a type has a reference to an uncacheable type, it is itself +// uncacheable. +// +// If a type has no outbound references to uncacheable types, we add it to the +// cache. +struct TypeCacher : TypeOnceVisitor +{ + NotNull> cachedTypes; + + DenseHashSet uncacheable{nullptr}; + DenseHashSet uncacheablePacks{nullptr}; + + explicit TypeCacher(NotNull> cachedTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , cachedTypes(cachedTypes) + { + } + + void cache(TypeId ty) + { + cachedTypes->insert(ty); + } + + bool isCached(TypeId ty) const + { + return cachedTypes->contains(ty); + } + + void markUncacheable(TypeId ty) + { + uncacheable.insert(ty); + } + + void markUncacheable(TypePackId tp) + { + uncacheablePacks.insert(tp); + } + + bool isUncacheable(TypeId ty) const + { + return uncacheable.contains(ty); + } + + bool isUncacheable(TypePackId tp) const + { + return uncacheablePacks.contains(tp); + } + + bool visit(TypeId ty) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + // Free types are never cacheable. + LUAU_ASSERT(!isCached(ty)); + + if (!isUncacheable(ty)) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + + markUncacheable(ty); + } + + return false; + } + + bool visit(TypeId ty, const GenericType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const PrimitiveType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const SingletonType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + traverse(ft.argTypes); + traverse(ft.retTypes); + for (TypeId gen : ft.generics) + traverse(gen); + + bool uncacheable = false; + + if (isUncacheable(ft.argTypes)) + uncacheable = true; + + else if (isUncacheable(ft.retTypes)) + uncacheable = true; + + for (TypeId argTy : ft.argTypes) + { + if (isUncacheable(argTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId retTy : ft.retTypes) + { + if (isUncacheable(retTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId g : ft.generics) + { + if (isUncacheable(g)) + { + uncacheable = true; + break; + } + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + if (tt.boundTo) + { + traverse(*tt.boundTo); + if (isUncacheable(*tt.boundTo)) + { + markUncacheable(ty); + return false; + } + } + + bool uncacheable = false; + + // This logic runs immediately after generalization, so any remaining + // unsealed tables are assuredly not cacheable. They may yet have + // properties added to them. + if (tt.state == TableState::Free || tt.state == TableState::Unsealed) + uncacheable = true; + + for (const auto& [_name, prop] : tt.props) + { + if (prop.readTy) + { + traverse(*prop.readTy); + + if (isUncacheable(*prop.readTy)) + uncacheable = true; + } + if (prop.writeTy && prop.writeTy != prop.readTy) + { + traverse(*prop.writeTy); + + if (isUncacheable(*prop.writeTy)) + uncacheable = true; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + if (isUncacheable(tt.indexer->indexType)) + uncacheable = true; + + traverse(tt.indexer->indexResultType); + if (isUncacheable(tt.indexer->indexResultType)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const AnyType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const UnionType& ut) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : ut.options) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const IntersectionType& it) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : it.parts) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const UnknownType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NeverType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NegationType& nt) override + { + if (!isCached(ty) && !isUncacheable(ty)) + { + traverse(nt.ty); + + if (isUncacheable(nt.ty)) + markUncacheable(ty); + else + cache(ty); + } + + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + bool uncacheable = false; + + for (TypeId argTy : tfit.typeArguments) + { + traverse(argTy); + + if (isUncacheable(argTy)) + uncacheable = true; + } + + for (TypePackId argPack : tfit.packArguments) + { + traverse(argPack); + + if (isUncacheable(argPack)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const VariadicTypePack& vtp) override + { + if (isUncacheable(tp)) + return false; + + traverse(vtp.ty); + + if (isUncacheable(vtp.ty)) + markUncacheable(tp); + + return false; + } + + bool visit(TypePackId tp, const BlockedTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + markUncacheable(tp); + return false; + } +}; + +std::optional generalize( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull> cachedTypes, + TypeId ty, + bool avoidSealingTables +) +{ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) + return ty; + + FreeTypeSearcher fts{scope, cachedTypes}; + fts.traverse(ty); + + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + + gen.traverse(ty); + + /* MutatingGeneralizer mutates types in place, so it is possible that ty has + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + + FunctionType* ftv = getMutable(ty); + if (ftv) + { + ftv->generics = std::move(gen.generics); + ftv->genericPacks = std::move(gen.genericPacks); + } + + return ty; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/GlobalTypes.cpp b/third_party/luau/Analysis/src/GlobalTypes.cpp new file mode 100644 index 00000000..9dd60caa --- /dev/null +++ b/third_party/luau/Analysis/src/GlobalTypes.cpp @@ -0,0 +1,30 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/GlobalTypes.h" + +namespace Luau +{ + +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("buffer", TypeFun{{}, builtinTypes->bufferType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); + + unfreeze(*builtinTypes->arena); + TypeId stringMetatableTy = makeStringMetatable(builtinTypes); + asMutable(builtinTypes->stringType)->ty.emplace(PrimitiveType::String, stringMetatableTy); + persist(stringMetatableTy); + freeze(*builtinTypes->arena); +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/Instantiation.cpp b/third_party/luau/Analysis/src/Instantiation.cpp index 7d0f0f72..8422d8c4 100644 --- a/third_party/luau/Analysis/src/Instantiation.cpp +++ b/third_party/luau/Analysis/src/Instantiation.cpp @@ -1,19 +1,38 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Common.h" #include "Luau/Instantiation.h" + +#include "Luau/Common.h" +#include "Luau/Instantiation2.h" // including for `Replacer` which was stolen since it will be kept in the new solver +#include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +#include + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { +void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; +} + bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return false; return true; @@ -33,7 +52,7 @@ bool Instantiation::ignoreChildren(TypeId ty) { if (log->getMutable(ty)) return true; - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else return false; @@ -52,13 +71,26 @@ TypeId Instantiation::clean(TypeId ty) clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks}; + if (FFlag::LuauReusableSubstitutions) + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks); - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = reusableReplaceGenerics.substitute(result).value_or(result); + } + else + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -70,11 +102,34 @@ TypePackId Instantiation::clean(TypePackId tp) return tp; } +void ReplaceGenerics::resetState( + const TxnLog* log, + TypeArena* arena, + NotNull builtinTypes, + TypeLevel level, + Scope* scope, + const std::vector& generics, + const std::vector& genericPacks +) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; + + this->generics = generics; + this->genericPacks = genericPacks; +} + bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return true; // We aren't recursing in the case of a generic function which @@ -84,7 +139,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); } - else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + else if (get(ty)) return true; else { @@ -120,8 +175,16 @@ TypeId ReplaceGenerics::clean(TypeId ty) clone.definitionLocation = ttv->definitionLocation; return addType(std::move(clone)); } + else if (FFlag::DebugLuauDeferredConstraintResolution) + { + TypeId res = freshType(NotNull{arena}, builtinTypes, scope); + getMutable(res)->level = level; + return res; + } else + { return addType(FreeType{scope, level}); + } } TypePackId ReplaceGenerics::clean(TypePackId tp) @@ -130,4 +193,48 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) return addTypePack(TypePackVar(FreeTypePack{scope, level})); } +std::optional instantiate( + NotNull builtinTypes, + NotNull arena, + NotNull limits, + NotNull scope, + TypeId ty +) +{ + ty = follow(ty); + + const FunctionType* ft = get(ty); + if (!ft) + return ty; + + if (ft->generics.empty() && ft->genericPacks.empty()) + return ty; + + DenseHashMap replacements{nullptr}; + DenseHashMap replacementPacks{nullptr}; + + for (TypeId g : ft->generics) + replacements[g] = freshType(arena, builtinTypes, scope); + + for (TypePackId g : ft->genericPacks) + replacementPacks[g] = arena->freshTypePack(scope); + + Replacer r{arena, std::move(replacements), std::move(replacementPacks)}; + + if (limits->instantiationChildLimit) + r.childLimit = *limits->instantiationChildLimit; + + std::optional res = r.substitute(ty); + if (!res) + return res; + + FunctionType* ft2 = getMutable(*res); + LUAU_ASSERT(ft != ft2); + + ft2->generics.clear(); + ft2->genericPacks.clear(); + + return res; +} + } // namespace Luau diff --git a/third_party/luau/Analysis/src/Instantiation2.cpp b/third_party/luau/Analysis/src/Instantiation2.cpp new file mode 100644 index 00000000..106ad870 --- /dev/null +++ b/third_party/luau/Analysis/src/Instantiation2.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Instantiation2.h" + +namespace Luau +{ + +bool Instantiation2::ignoreChildren(TypeId ty) +{ + if (get(ty)) + return true; + + if (auto ftv = get(ty)) + { + if (ftv->hasNoFreeOrGenericTypes) + return false; + + // If this function type quantifies over these generics, we don't want substitution to + // go any further into them because it's being shadowed in this case. + for (auto generic : ftv->generics) + if (genericSubstitutions.contains(generic)) + return true; + + for (auto generic : ftv->genericPacks) + if (genericPackSubstitutions.contains(generic)) + return true; + } + + return false; +} + +bool Instantiation2::isDirty(TypeId ty) +{ + return get(ty) && genericSubstitutions.contains(ty); +} + +bool Instantiation2::isDirty(TypePackId tp) +{ + return get(tp) && genericPackSubstitutions.contains(tp); +} + +TypeId Instantiation2::clean(TypeId ty) +{ + TypeId substTy = follow(genericSubstitutions[ty]); + const FreeType* ft = get(substTy); + + // violation of the substitution invariant if this is not a free type. + LUAU_ASSERT(ft); + + // if we didn't learn anything about the lower bound, we pick the upper bound instead. + // we default to the lower bound which represents the most specific type for the free type. + TypeId res = get(ft->lowerBound) ? ft->upperBound : ft->lowerBound; + + // Instantiation should not traverse into the type that we are substituting for. + dontTraverseInto(res); + + return res; +} + +TypePackId Instantiation2::clean(TypePackId tp) +{ + TypePackId res = genericPackSubstitutions[tp]; + dontTraverseInto(res); + return res; +} + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypeId ty +) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(ty); +} + +std::optional instantiate2( + TypeArena* arena, + DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, + TypePackId tp +) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(tp); +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/IostreamHelpers.cpp b/third_party/luau/Analysis/src/IostreamHelpers.cpp index 43580da4..a3d8b4e3 100644 --- a/third_party/luau/Analysis/src/IostreamHelpers.cpp +++ b/third_party/luau/Analysis/src/IostreamHelpers.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +#include "Luau/TypePath.h" namespace Luau { @@ -113,6 +114,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "GenericError { " << err.message << " }"; else if constexpr (std::is_same_v) stream << "InternalError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "ConstraintSolvingIncompleteError {}"; else if constexpr (std::is_same_v) stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) @@ -192,13 +195,74 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else if constexpr (std::is_same_v) stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "UninhabitedTypeFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + { + std::string recArgs = "["; + for (auto [s, t] : err.recommendedArgs) + recArgs += " " + s + ": " + toString(t); + recArgs += " ]"; + stream << "ExplicitFunctionAnnotationRecommended { recommmendedReturn = '" + toString(err.recommendedReturn) + + "', recommmendedArgs = " + recArgs + "}"; + } + else if constexpr (std::is_same_v) + stream << "UninhabitedTypePackFunction { " << toString(err.tp) << " }"; + else if constexpr (std::is_same_v) + stream << "WhereClauseNeeded { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "PackWhereClauseNeeded { " << toString(err.tp) << " }"; + else if constexpr (std::is_same_v) + stream << "CheckedFunctionCallError { expected = '" << toString(err.expected) << "', passed = '" << toString(err.passed) + << "', checkedFunctionName = " << err.checkedFunctionName << ", argumentIndex = " << std::to_string(err.argumentIndex) << " }"; + else if constexpr (std::is_same_v) + stream << "NonStrictFunctionDefinitionError { functionName = '" + err.functionName + "', argument = '" + err.argument + + "', argumentType = '" + toString(err.argumentType) + "' }"; + else if constexpr (std::is_same_v) + stream << "PropertyAccessViolation { table = " << toString(err.table) << ", prop = '" << err.key << "', context = " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "CheckedFunction { functionName = '" + err.functionName + ", expected = " + std::to_string(err.expected) + + ", actual = " + std::to_string(err.actual) + "}"; + else if constexpr (std::is_same_v) + stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; + else if constexpr (std::is_same_v) + stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; + + bool first = true; + for (TypeId ty : err.cause) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " } } "; + } else static_assert(always_false_v, "Non-exhaustive type switch"); } +std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason& reason) +{ + switch (reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + return stream << "PropertyNarrowed"; + default: + return stream << "UnknownReason"; + } +} + std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) { - auto cb = [&](const auto& e) { + auto cb = [&](const auto& e) + { return errorToString(stream, e); }; visit(cb, data); @@ -225,4 +289,34 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } +std::ostream& operator<<(std::ostream& stream, TypeId ty) +{ + // we commonly use a null pointer when a type may not be present; we need to + // account for that here. + if (!ty) + return stream << ""; + + return stream << toString(ty); +} + +std::ostream& operator<<(std::ostream& stream, TypePackId tp) +{ + // we commonly use a null pointer when a type may not be present; we need to + // account for that here. + if (!tp) + return stream << ""; + + return stream << toString(tp); +} + +namespace TypePath +{ + +std::ostream& operator<<(std::ostream& stream, const Path& path) +{ + return stream << toString(path); +} + +} // namespace TypePath + } // namespace Luau diff --git a/third_party/luau/Analysis/src/Linter.cpp b/third_party/luau/Analysis/src/Linter.cpp index d6aafda6..23457f4c 100644 --- a/third_party/luau/Analysis/src/Linter.cpp +++ b/third_party/luau/Analysis/src/Linter.cpp @@ -14,45 +14,14 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) -namespace Luau -{ +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -// clang-format off -static const char* kWarningNames[] = { - "Unknown", - - "UnknownGlobal", - "DeprecatedGlobal", - "GlobalUsedAsLocal", - "LocalShadow", - "SameLineStatement", - "MultiLineStatement", - "LocalUnused", - "FunctionUnused", - "ImportUnused", - "BuiltinGlobalWrite", - "PlaceholderRead", - "UnreachableCode", - "UnknownType", - "ForRange", - "UnbalancedAssignment", - "ImplicitReturn", - "DuplicateLocal", - "FormatString", - "TableLiteral", - "UninitializedLocal", - "DuplicateFunction", - "DeprecatedApi", - "TableOperations", - "DuplicateCondition", - "MisleadingAndOr", - "CommentDirective", - "IntegerParsing", - "ComparisonPrecedence", -}; -// clang-format on +LUAU_FASTFLAG(LuauAttribute) +LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false) -static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?"); +namespace Luau +{ struct LintContext { @@ -306,8 +275,14 @@ class LintGlobalLocal : AstVisitor else if (g->deprecated) { if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, replacement); + emitWarning( + *context, + LintWarning::Code_DeprecatedGlobal, + gv->location, + "Global '%s' is deprecated, use '%s' instead", + gv->name.value, + replacement + ); else emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); } @@ -322,18 +297,33 @@ class LintGlobalLocal : AstVisitor AstExprFunction* top = g.functionRef.back(); if (top->debugname.value) - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, - "Global '%s' is only used in the enclosing function '%s'; consider changing it to local", g.firstRef->name.value, - top->debugname.value); + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, + "Global '%s' is only used in the enclosing function '%s'; consider changing it to local", + g.firstRef->name.value, + top->debugname.value + ); else - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", - g.firstRef->name.value, top->location.begin.line + 1); + g.firstRef->name.value, + top->location.begin.line + 1 + ); } else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder) { - emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, - "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); + emitWarning( + *context, + LintWarning::Code_GlobalUsedAsLocal, + g.firstRef->location, + "Global '%s' is never read before being written. Consider changing it to local", + g.firstRef->name.value + ); } } } @@ -360,7 +350,8 @@ class LintGlobalLocal : AstVisitor if (node->name == context->placeholder) emitWarning( - *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable" + ); return true; } @@ -369,7 +360,8 @@ class LintGlobalLocal : AstVisitor { if (node->local->name == context->placeholder) emitWarning( - *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable" + ); return true; } @@ -397,8 +389,13 @@ class LintGlobalLocal : AstVisitor } if (g.builtin) - emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, - "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + emitWarning( + *context, + LintWarning::Code_BuiltinGlobalWrite, + gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", + gv->name.value + ); else g.assigned = true; @@ -427,8 +424,13 @@ class LintGlobalLocal : AstVisitor Global& g = globals[gv->name]; if (g.builtin) - emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, - "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + emitWarning( + *context, + LintWarning::Code_BuiltinGlobalWrite, + gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", + gv->name.value + ); else { g.assigned = true; @@ -596,8 +598,12 @@ class LintSameLineStatement : AstVisitor if (node->body.data[i - 1]->hasSemicolon) continue; - emitWarning(*context, LintWarning::Code_SameLineStatement, location, - "A new statement is on the same line; add semi-colon on previous statement to silence"); + emitWarning( + *context, + LintWarning::Code_SameLineStatement, + location, + "A new statement is on the same line; add semi-colon on previous statement to silence" + ); lastLine = location.begin.line; } @@ -644,7 +650,8 @@ class LintMultiLineStatement : AstVisitor if (location.begin.column <= top.start.begin.column) { emitWarning( - *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"); + *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence" + ); top.flagged = true; } @@ -758,8 +765,14 @@ class LintLocalHygiene : AstVisitor // don't warn on inter-function shadowing since it is much more fragile wrt refactoring if (shadow->functionDepth == local->functionDepth) - emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows previous declaration at line %d", - local->name.value, shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_LocalShadow, + local->location, + "Variable '%s' shadows previous declaration at line %d", + local->name.value, + shadow->location.begin.line + 1 + ); } else if (Global* global = globals.find(local->name)) { @@ -767,8 +780,14 @@ class LintLocalHygiene : AstVisitor ; // there are many builtins with common names like 'table'; some of them are deprecated as well else if (global->firstRef) { - emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable used at line %d", - local->name.value, global->firstRef->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_LocalShadow, + local->location, + "Variable '%s' shadows a global variable used at line %d", + local->name.value, + global->firstRef->location.begin.line + 1 + ); } else { @@ -783,14 +802,21 @@ class LintLocalHygiene : AstVisitor return; if (info.function) - emitWarning(*context, LintWarning::Code_FunctionUnused, local->location, "Function '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, + LintWarning::Code_FunctionUnused, + local->location, + "Function '%s' is never used; prefix with '_' to silence", + local->name.value + ); else if (info.import) - emitWarning(*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", local->name.value + ); else - emitWarning(*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", - local->name.value); + emitWarning( + *context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", local->name.value + ); } bool isRequireCall(AstExpr* expr) @@ -944,8 +970,13 @@ class LintUnusedFunction : AstVisitor for (auto& g : globals) { if (g.second.function && !g.second.used && g.first.value[0] != '_') - emitWarning(*context, LintWarning::Code_FunctionUnused, g.second.location, "Function '%s' is never used; prefix with '_' to silence", - g.first.value); + emitWarning( + *context, + LintWarning::Code_FunctionUnused, + g.second.location, + "Function '%s' is never used; prefix with '_' to silence", + g.first.value + ); } } @@ -1044,8 +1075,13 @@ class LintUnreachableCode : AstVisitor if (step == Error && si->is() && next->is() && i + 2 == stat->body.size) return Error; - emitWarning(*context, LintWarning::Code_UnreachableCode, next->location, "Unreachable code (previous statement always %ss)", - getReason(step)); + emitWarning( + *context, + LintWarning::Code_UnreachableCode, + next->location, + "Unreachable code (previous statement always %ss)", + getReason(step) + ); return step; } } @@ -1142,7 +1178,7 @@ class LintUnknownType : AstVisitor TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || - name == "function" || name == "thread") + name == "function" || name == "thread" || name == "buffer") return Kind_Primitive; if (name == "vector") @@ -1240,22 +1276,34 @@ class LintForRange : AstVisitor // for i=#t,1 do if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0) emitWarning( - *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?" + ); // for i=8,1 do else if (fc && tc && fc->value > tc->value) emitWarning( - *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?" + ); // for i=1,8.75 do else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value) - emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop ends at %g instead of %g; did you forget to specify step?", - getLoopEnd(fc->value, tc->value), tc->value); + emitWarning( + *context, + LintWarning::Code_ForRange, + rangeLocation, + "For loop ends at %g instead of %g; did you forget to specify step?", + getLoopEnd(fc->value, tc->value), + tc->value + ); // for i=0,#t do else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len) emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1"); // for i=#t,0 do else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0) - emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, - "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); + emitWarning( + *context, + LintWarning::Code_ForRange, + rangeLocation, + "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1" + ); } return true; @@ -1283,16 +1331,27 @@ class LintUnbalancedAssignment : AstVisitor AstExpr* last = values.data[values.size - 1]; if (vars < values.size) - emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, - "Assigning %d values to %d variables leaves some values unused", int(values.size), int(vars)); + emitWarning( + *context, + LintWarning::Code_UnbalancedAssignment, + location, + "Assigning %d values to %d variables leaves some values unused", + int(values.size), + int(vars) + ); else if (last->is() || last->is()) ; // we don't know how many values the last expression returns else if (last->is()) ; // last expression is nil which explicitly silences the nil-init warning else - emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, - "Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", int(values.size), - int(vars)); + emitWarning( + *context, + LintWarning::Code_UnbalancedAssignment, + location, + "Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", + int(values.size), + int(vars) + ); } } @@ -1375,13 +1434,22 @@ class LintImplicitReturn : AstVisitor Location location = getEndLocation(bodyf); if (node->debugname.value) - emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + emitWarning( + *context, + LintWarning::Code_ImplicitReturn, + location, "Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", - node->debugname.value, vret->location.begin.line + 1); + node->debugname.value, + vret->location.begin.line + 1 + ); else - emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + emitWarning( + *context, + LintWarning::Code_ImplicitReturn, + location, "Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", - vret->location.begin.line + 1); + vret->location.begin.line + 1 + ); } return true; @@ -1852,23 +1920,41 @@ class LintTableLiteral : AstVisitor int& line = names[&expr->value]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table field '%.*s' is a duplicate; previously defined at line %d", int(expr->value.size), expr->value.data, line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table field '%.*s' is a duplicate; previously defined at line %d", + int(expr->value.size), + expr->value.data, + line + ); else line = expr->location.begin.line + 1; } else if (AstExprConstantNumber* expr = item.key->as()) { if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table index %d is a duplicate; previously defined as a list entry", int(expr->value)); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table index %d is a duplicate; previously defined as a list entry", + int(expr->value) + ); else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value) { int& line = indices[int(expr->value)]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, - "Table index %d is a duplicate; previously defined at line %d", int(expr->value), line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + expr->location, + "Table index %d is a duplicate; previously defined at line %d", + int(expr->value), + line + ); else line = expr->location.begin.line + 1; } @@ -1885,6 +1971,72 @@ class LintTableLiteral : AstVisitor bool visit(AstTypeTable* node) override { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + struct Rec + { + AstTableAccess access; + Location location; + }; + DenseHashMap names(AstName{}); + + for (const AstTableProp& item : node->props) + { + Rec* rec = names.find(item.name); + if (!rec) + { + names[item.name] = Rec{item.access, item.location}; + continue; + } + + if (int(rec->access) & int(item.access)) + { + if (rec->access == item.access) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is a duplicate; previously defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::ReadWrite) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is already read-write; previously defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::Read) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + rec->location, + "Table type field '%s' already has a read type defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else if (rec->access == AstTableAccess::Write) + emitWarning( + *context, + LintWarning::Code_TableLiteral, + rec->location, + "Table type field '%s' already has a write type defined at line %d", + item.name.value, + rec->location.begin.line + 1 + ); + else + LUAU_ASSERT(!"Unreachable"); + } + else + rec->access = AstTableAccess(int(rec->access) | int(item.access)); + } + + return true; + } + DenseHashMap names(AstName{}); for (const AstTableProp& item : node->props) @@ -1892,8 +2044,14 @@ class LintTableLiteral : AstVisitor int& line = names[item.name]; if (line) - emitWarning(*context, LintWarning::Code_TableLiteral, item.location, - "Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, line); + emitWarning( + *context, + LintWarning::Code_TableLiteral, + item.location, + "Table type field '%s' is a duplicate; previously defined at line %d", + item.name.value, + line + ); else line = item.location.begin.line + 1; } @@ -1954,9 +2112,14 @@ class LintUninitializedLocal : AstVisitor if (l.defined && !l.initialized && !l.assigned && l.firstUse) { - emitWarning(*context, LintWarning::Code_UninitializedLocal, l.firstUse->location, - "Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", local->name.value, - local->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_UninitializedLocal, + l.firstUse->location, + "Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", + local->name.value, + local->location.begin.line + 1 + ); } } } @@ -2090,8 +2253,14 @@ class LintDuplicateFunction : AstVisitor void report(const std::string& name, Location location, Location otherLocation) { - emitWarning(*context, LintWarning::Code_DuplicateFunction, location, "Duplicate function definition: '%s' also defined on line %d", - name.c_str(), otherLocation.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateFunction, + location, + "Duplicate function definition: '%s' also defined on line %d", + name.c_str(), + otherLocation.begin.line + 1 + ); } }; @@ -2122,6 +2291,33 @@ class LintDeprecatedApi : AstVisitor return true; } + bool visit(AstExprCall* node) override + { + // getfenv/setfenv are deprecated, however they are still used in some test frameworks and don't have a great general replacement + // for now we warn about the deprecation only when they are used with a numeric first argument; this produces fewer warnings and makes use + // of getfenv/setfenv a little more localized + if (!node->self && node->args.size >= 1) + { + if (AstExprGlobal* fenv = node->func->as(); fenv && (fenv->name == "getfenv" || fenv->name == "setfenv")) + { + AstExpr* level = node->args.data[0]; + std::optional ty = context->getType(level); + + if ((ty && isNumber(*ty)) || level->is()) + { + // some common uses of getfenv(n) can be replaced by debug.info if the goal is to get the caller's identity + const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : ""; + + emitWarning( + *context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion + ); + } + } + } + + return true; + } + void check(AstExprIndexName* node, TypeId ty) { if (const ClassType* cty = get(ty)) @@ -2191,16 +2387,50 @@ class LintTableOperations : AstVisitor { } + bool visit(AstExprUnary* node) override + { + if (node->op == AstExprUnary::Len) + checkIndexer(node, node->expr, "#"); + + return true; + } + bool visit(AstExprCall* node) override { - AstExprIndexName* func = node->func->as(); - if (!func) - return true; + if (AstExprGlobal* func = node->func->as()) + { + if (func->name == "ipairs" && node->args.size == 1) + checkIndexer(node, node->args.data[0], "ipairs"); + } + else if (AstExprIndexName* func = node->func->as()) + { + if (AstExprGlobal* tablib = func->expr->as(); tablib && tablib->name == "table") + checkTableCall(node, func); + } - AstExprGlobal* tablib = func->expr->as(); - if (!tablib || tablib->name != "table") - return true; + return true; + } + + void checkIndexer(AstExpr* node, AstExpr* expr, const char* op) + { + std::optional ty = context->getType(expr); + if (!ty) + return; + + const TableType* tty = get(follow(*ty)); + if (!tty) + return; + + if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic) + emitWarning( + *context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op + ); + else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string + emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op); + } + void checkTableCall(AstExprCall* node, AstExprIndexName* func) + { AstExpr** args = node->args.data; if (func->index == "insert" && node->args.size == 2) @@ -2212,9 +2442,13 @@ class LintTableOperations : AstVisitor size_t ret = getReturnCount(follow(*funty)); if (ret > 1) - emitWarning(*context, LintWarning::Code_TableOperations, tail->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + tail->location, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second " - "argument"); + "argument" + ); } } } @@ -2223,28 +2457,44 @@ class LintTableOperations : AstVisitor { // table.insert(t, 0, ?) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // table.insert(t, #t, ?) if (isLength(args[1], args[0])) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, "table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or " - "wrap it in parentheses to silence"); + "wrap it in parentheses to silence" + ); // table.insert(t, #t+1, ?) if (AstExprBinary* add = args[1]->as(); add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.insert will append the value to the table; consider removing the second argument for efficiency"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.insert will append the value to the table; consider removing the second argument for efficiency" + ); } if (func->index == "remove" && node->args.size >= 2) { // table.remove(t, 0) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently, // and also reads better. @@ -2252,38 +2502,56 @@ class LintTableOperations : AstVisitor // table.remove(t, #t-1) if (AstExprBinary* sub = args[1]->as(); sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, "table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or " - "wrap it in parentheses to silence"); + "wrap it in parentheses to silence" + ); } if (func->index == "move" && node->args.size >= 4) { // table.move(t, 0, _, _) if (isConstant(args[1], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); // table.move(t, _, _, 0) else if (isConstant(args[3], 0.0)) - emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, - "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?" + ); } if (func->index == "create" && node->args.size == 2) { // table.create(n, {...}) if (args[1]->is()) - emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, - "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + args[1]->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); // table.create(n, {...} :: ?) if (AstExprTypeAssertion* as = args[1]->as(); as && as->expr->is()) - emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, - "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + emitWarning( + *context, + LintWarning::Code_TableOperations, + as->expr->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead" + ); } - - return true; } bool isConstant(AstExpr* expr, double value) @@ -2474,11 +2742,21 @@ class LintDuplicateCondition : AstVisitor if (similar(conditions[j], conditions[i])) { if (conditions[i]->location.begin.line == conditions[j]->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, - "Condition has already been checked on column %d", conditions[j]->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateCondition, + conditions[i]->location, + "Condition has already been checked on column %d", + conditions[j]->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, - "Condition has already been checked on line %d", conditions[j]->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateCondition, + conditions[i]->location, + "Condition has already been checked on line %d", + conditions[j]->location.begin.line + 1 + ); break; } } @@ -2523,11 +2801,23 @@ class LintDuplicateLocal : AstVisitor if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local)) { if (local->shadow->location.begin.line == local->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on column %d", - local->name.value, local->shadow->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Variable '%s' already defined on column %d", + local->name.value, + local->shadow->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on line %d", - local->name.value, local->shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Variable '%s' already defined on line %d", + local->name.value, + local->shadow->location.begin.line + 1 + ); } } @@ -2551,11 +2841,23 @@ class LintDuplicateLocal : AstVisitor if (local->shadow == node->self) emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly"); else if (local->shadow->location.begin.line == local->location.begin.line) - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on column %d", - local->name.value, local->shadow->location.begin.column + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Function parameter '%s' already defined on column %d", + local->name.value, + local->shadow->location.begin.column + 1 + ); else - emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on line %d", - local->name.value, local->shadow->location.begin.line + 1); + emitWarning( + *context, + LintWarning::Code_DuplicateLocal, + local->location, + "Function parameter '%s' already defined on line %d", + local->name.value, + local->shadow->location.begin.line + 1 + ); } } @@ -2599,10 +2901,14 @@ class LintMisleadingAndOr : AstVisitor alt = "false"; if (alt) - emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, + emitWarning( + *context, + LintWarning::Code_MisleadingAndOr, + node->location, "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " "expression instead", - alt); + alt + ); return true; } @@ -2629,13 +2935,29 @@ class LintIntegerParsing : AstVisitor case ConstantNumberParseResult::Ok: case ConstantNumberParseResult::Malformed: break; + case ConstantNumberParseResult::Imprecise: + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Number literal exceeded available precision and was truncated to closest representable number" + ); + break; case ConstantNumberParseResult::BinOverflow: - emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, - "Binary number literal exceeded available precision and has been truncated to 2^64"); + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Binary number literal exceeded available precision and was truncated to 2^64" + ); break; case ConstantNumberParseResult::HexOverflow: - emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, - "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); + emitWarning( + *context, + LintWarning::Code_IntegerParsing, + node->location, + "Hexadecimal number literal exceeded available precision and was truncated to 2^64" + ); break; } @@ -2686,12 +3008,24 @@ class LintComparisonPrecedence : AstVisitor std::string op = toString(node->op); if (isEquality(node->op)) - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(), - node->op == AstExprBinary::CompareEq ? "~=" : "=="); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", + op.c_str(), + op.c_str(), + node->op == AstExprBinary::CompareEq ? "~=" : "==" + ); else - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", + op.c_str(), + op.c_str() + ); } else if (AstExprBinary* left = node->left->as(); left && isComparison(left->op)) { @@ -2699,12 +3033,29 @@ class LintComparisonPrecedence : AstVisitor std::string rop = toString(node->op); if (isEquality(left->op) || isEquality(node->op)) - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str() + ); else - emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, - "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(), - lop.c_str(), rop.c_str()); + emitWarning( + *context, + LintWarning::Code_ComparisonPrecedence, + node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str(), + lop.c_str(), + rop.c_str() + ); } return true; @@ -2770,8 +3121,12 @@ static void lintComments(LintContext& context, const std::vector& ho if (!hc.header) { - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive is ignored because it is placed after the first non-comment token"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive is ignored because it is placed after the first non-comment token" + ); } else { @@ -2792,21 +3147,36 @@ static void lintComments(LintContext& context, const std::vector& ho // skip Unknown if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", + rule, + suggestion + ); else emitWarning( - context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); + context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule + ); } } else if (first == "nocheck" || first == "nonstrict" || first == "strict") { if (space != std::string::npos) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive with the type checking mode has extra symbols at the end of the line"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive with the type checking mode has extra symbols at the end of the line" + ); else if (seenMode) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "Comment directive with the type checking mode has already been used"); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Comment directive with the type checking mode has already been used" + ); else seenMode = true; } @@ -2821,10 +3191,22 @@ static void lintComments(LintContext& context, const std::vector& ho const char* level = hc.content.c_str() + notspace; if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2")) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "optimize directive uses unknown optimization level '%s', 0..2 expected", level); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "optimize directive uses unknown optimization level '%s', 0..2 expected", + level + ); } } + else if (first == "native") + { + if (space != std::string::npos) + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line" + ); + } else { static const char* kHotComments[] = { @@ -2833,27 +3215,96 @@ static void lintComments(LintContext& context, const std::vector& ho "nonstrict", "strict", "optimize", + "native", }; if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", - int(first.size()), first.data(), suggestion); + emitWarning( + context, + LintWarning::Code_CommentDirective, + hc.location, + "Unknown comment directive '%.*s'; did you mean '%s'?", + int(first.size()), + first.data(), + suggestion + ); else - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), - first.data()); + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), first.data() + ); } } } } -void LintOptions::setDefaults() +static bool hasNativeCommentDirective(const std::vector& hotcomments) { - // By default, we enable all warnings - warningMask = ~0ull; + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + for (const HotComment& hc : hotcomments) + { + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (hc.header) + { + size_t space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "native") + return true; + } + } + + return false; } -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, - const std::vector& hotcomments, const LintOptions& options) +struct LintRedundantNativeAttribute : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + LintRedundantNativeAttribute pass; + pass.context = &context; + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprFunction* node) override + { + node->body->visit(this); + + for (const auto attribute : node->attributes) + { + if (attribute->type == AstAttr::Type::Native) + { + emitWarning( + *context, + LintWarning::Code_RedundantNativeAttribute, + attribute->location, + "native attribute on a function is redundant in a native module; consider removing it" + ); + } + } + + return false; + } +}; + +std::vector lint( + AstStat* root, + const AstNameTable& names, + const ScopePtr& env, + const Module* module, + const std::vector& hotcomments, + const LintOptions& options +) { LintContext context; @@ -2938,57 +3389,15 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); - std::sort(context.result.begin(), context.result.end(), WarningComparator()); - - return context.result; -} - -const char* LintWarning::getName(Code code) -{ - LUAU_ASSERT(unsigned(code) < Code__Count); - - return kWarningNames[code]; -} - -LintWarning::Code LintWarning::parseName(const char* name) -{ - for (int code = Code_Unknown; code < Code__Count; ++code) - if (strcmp(name, getName(Code(code))) == 0) - return Code(code); - - return Code_Unknown; -} - -uint64_t LintWarning::parseMask(const std::vector& hotcomments) -{ - uint64_t result = 0; - - for (const HotComment& hc : hotcomments) + if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) { - if (!hc.header) - continue; - - if (hc.content.compare(0, 6, "nolint") != 0) - continue; - - size_t name = hc.content.find_first_not_of(" \t", 6); - - // --!nolint disables everything - if (name == std::string::npos) - return ~0ull; - - // --!nolint needs to be followed by a whitespace character - if (name == 6) - continue; - - // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); - - if (code != LintWarning::Code_Unknown) - result |= 1ull << int(code); + if (hasNativeCommentDirective(hotcomments)) + LintRedundantNativeAttribute::process(context); } - return result; + std::sort(context.result.begin(), context.result.end(), WarningComparator()); + + return context.result; } std::vector getDeprecatedGlobals(const AstNameTable& names) diff --git a/third_party/luau/Analysis/src/Module.cpp b/third_party/luau/Analysis/src/Module.cpp index 830aaf75..f9a3f67a 100644 --- a/third_party/luau/Analysis/src/Module.cpp +++ b/third_party/luau/Analysis/src/Module.cpp @@ -3,24 +3,19 @@ #include "Luau/Clone.h" #include "Luau/Common.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintGenerator.h" #include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); -LUAU_FASTFLAG(LuauSubstitutionReentrant); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); -LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); -LUAU_FASTFLAGVARIABLE(LuauCopyExportedTypes, false); +LUAU_FASTFLAGVARIABLE(LuauSkipEmptyInstantiations, false); namespace Luau { @@ -29,8 +24,8 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (comment.type == Lexeme::BrokenComment && - comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't + // have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) return true; @@ -41,9 +36,14 @@ static bool contains(Position pos, Comment comment) static bool isWithinComment(const std::vector& commentLocations, Position pos) { auto iter = std::lower_bound( - commentLocations.begin(), commentLocations.end(), Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + commentLocations.begin(), + commentLocations.end(), + Comment{Lexeme::Comment, Location{pos, pos}}, + [](const Comment& a, const Comment& b) + { return a.location.end < b.location.end; - }); + } + ); if (iter == commentLocations.end()) return false; @@ -100,14 +100,43 @@ struct ClonePublicInterface : Substitution return tp->owningArena == &module->internalTypes; } + bool ignoreChildrenVisit(TypeId ty) override + { + if (ty->owningArena != &module->internalTypes) + return true; + + return false; + } + + bool ignoreChildrenVisit(TypePackId tp) override + { + if (tp->owningArena != &module->internalTypes) + return true; + + return false; + } + TypeId clean(TypeId ty) override { TypeId result = clone(ty); if (FunctionType* ftv = getMutable(result)) + { + if (FFlag::LuauSkipEmptyInstantiations && ftv->generics.empty() && ftv->genericPacks.empty()) + { + GenericTypeFinder marker; + marker.traverse(result); + + if (!marker.found) + ftv->hasNoFreeOrGenericTypes = true; + } + ftv->level = TypeLevel{0, 0}; + } else if (TableType* ttv = getMutable(result)) + { ttv->level = TypeLevel{0, 0}; + } return result; } @@ -119,8 +148,6 @@ struct ClonePublicInterface : Substitution TypeId cloneType(TypeId ty) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(ty); if (result) { @@ -135,8 +162,6 @@ struct ClonePublicInterface : Substitution TypePackId cloneTypePack(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::optional result = substitute(tp); if (result) { @@ -151,8 +176,6 @@ struct ClonePublicInterface : Substitution TypeFun cloneTypeFun(const TypeFun& tf) { - LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - std::vector typeParams; std::vector typePackParams; @@ -192,10 +215,7 @@ Module::~Module() void Module::clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice) { - LUAU_ASSERT(interfaceTypes.types.empty()); - LUAU_ASSERT(interfaceTypes.typePacks.empty()); - - CloneState cloneState; + CloneState cloneState{builtinTypes}; ScopePtr moduleScope = getModuleScope(); @@ -205,43 +225,28 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; - if (FFlag::LuauClonePublicInterfaceLess2) - returnType = clonePublicInterface.cloneTypePack(returnType); - else - returnType = clone(returnType, interfaceTypes, cloneState); + returnType = clonePublicInterface.cloneTypePack(returnType); moduleScope->returnType = returnType; if (varargPack) { - if (FFlag::LuauClonePublicInterfaceLess2) - varargPack = clonePublicInterface.cloneTypePack(*varargPack); - else - varargPack = clone(*varargPack, interfaceTypes, cloneState); + varargPack = clonePublicInterface.cloneTypePack(*varargPack); moduleScope->varargPack = varargPack; } for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - if (FFlag::LuauClonePublicInterfaceLess2) - tf = clonePublicInterface.cloneTypeFun(tf); - else - tf = clone(tf, interfaceTypes, cloneState); + tf = clonePublicInterface.cloneTypeFun(tf); } for (auto& [name, ty] : declaredGlobals) { - if (FFlag::LuauClonePublicInterfaceLess2) - ty = clonePublicInterface.cloneType(ty); - else - ty = clone(ty, interfaceTypes, cloneState); + ty = clonePublicInterface.cloneType(ty); } // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - if (FFlag::DebugLuauDeferredConstraintResolution || FFlag::LuauCopyExportedTypes) - this->exportedTypeBindings = moduleScope->exportedTypeBindings; - else - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + this->exportedTypeBindings = moduleScope->exportedTypeBindings; } bool Module::hasModuleScope() const diff --git a/third_party/luau/Analysis/src/NonStrictTypeChecker.cpp b/third_party/luau/Analysis/src/NonStrictTypeChecker.cpp new file mode 100644 index 00000000..16225e96 --- /dev/null +++ b/third_party/luau/Analysis/src/NonStrictTypeChecker.cpp @@ -0,0 +1,772 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/NonStrictTypeChecker.h" + +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" +#include "Luau/Normalize.h" +#include "Luau/Error.h" +#include "Luau/TimeTrace.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFunction.h" +#include "Luau/Def.h" +#include "Luau/ToString.h" +#include "Luau/TypeFwd.h" + +#include +#include + +namespace Luau +{ + +/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. + * NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every + * given AstNode. + */ +struct StackPusher +{ + std::vector>* stack; + NotNull scope; + + explicit StackPusher(std::vector>& stack, Scope* scope) + : stack(&stack) + , scope(scope) + { + stack.push_back(NotNull{scope}); + } + + ~StackPusher() + { + if (stack) + { + LUAU_ASSERT(stack->back() == scope); + stack->pop_back(); + } + } + + StackPusher(const StackPusher&) = delete; + StackPusher&& operator=(const StackPusher&) = delete; + + StackPusher(StackPusher&& other) + : stack(std::exchange(other.stack, nullptr)) + , scope(other.scope) + { + } +}; + + +struct NonStrictContext +{ + NonStrictContext() = default; + + NonStrictContext(const NonStrictContext&) = delete; + NonStrictContext& operator=(const NonStrictContext&) = delete; + + NonStrictContext(NonStrictContext&&) = default; + NonStrictContext& operator=(NonStrictContext&&) = default; + + static NonStrictContext disjunction( + NotNull builtinTypes, + NotNull arena, + const NonStrictContext& left, + const NonStrictContext& right + ) + { + // disjunction implements union over the domain of keys + // if the default value for a defId not in the map is `never` + // then never | T is T + NonStrictContext disj{}; + + for (auto [def, leftTy] : left.context) + { + if (std::optional rightTy = right.find(def)) + disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result; + else + disj.context[def] = leftTy; + } + + for (auto [def, rightTy] : right.context) + { + if (!left.find(def).has_value()) + disj.context[def] = rightTy; + } + + return disj; + } + + static NonStrictContext conjunction( + NotNull builtins, + NotNull arena, + const NonStrictContext& left, + const NonStrictContext& right + ) + { + NonStrictContext conj{}; + + for (auto [def, leftTy] : left.context) + { + if (std::optional rightTy = right.find(def)) + conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result; + } + + return conj; + } + + // Returns true if the removal was successful + bool remove(const DefId& def) + { + std::vector defs; + collectOperands(def, &defs); + bool result = true; + for (DefId def : defs) + result = result && context.erase(def.get()) == 1; + return result; + } + + std::optional find(const DefId& def) const + { + const Def* d = def.get(); + return find(d); + } + + void addContext(const DefId& def, TypeId ty) + { + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + context[def.get()] = ty; + } + +private: + std::optional find(const Def* d) const + { + auto it = context.find(d); + if (it != context.end()) + return {it->second}; + return {}; + } + + std::unordered_map context; +}; + +struct NonStrictTypeChecker +{ + + NotNull builtinTypes; + const NotNull ice; + NotNull arena; + Module* module; + Normalizer normalizer; + Subtyping subtyping; + NotNull dfg; + DenseHashSet noTypeFunctionErrors{nullptr}; + std::vector> stack; + DenseHashMap cachedNegations{nullptr}; + + const NotNull limits; + + NonStrictTypeChecker( + NotNull arena, + NotNull builtinTypes, + const NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + Module* module + ) + : builtinTypes(builtinTypes) + , ice(ice) + , arena(arena) + , module(module) + , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} + , subtyping{builtinTypes, arena, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}} + , dfg(dfg) + , limits(limits) + { + } + + std::optional pushStack(AstNode* node) + { + if (Scope** scope = module->astScopes.find(node)) + return StackPusher{stack, *scope}; + else + return std::nullopt; + } + + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false)) + return *fst; + else if (auto ftp = get(pack)) + { + TypeId result = arena->addType(FreeType{ftp->scope}); + TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); + + TypePack* resultPack = emplaceTypePack(asMutable(pack)); + resultPack->head.assign(1, result); + resultPack->tail = freeTail; + + return result; + } + else if (get(pack)) + return builtinTypes->errorRecoveryType(); + else if (finite(pack) && size(pack) == 0) + return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` + else + ice->ice("flattenPack got a weird pack!"); + } + + + TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location) + { + if (noTypeFunctionErrors.find(instance)) + return instance; + + ErrorVec errors = + reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) + .errors; + + if (errors.empty()) + noTypeFunctionErrors.insert(instance); + // TODO?? + // if (!isErrorSuppressing(location, instance)) + // reportErrors(std::move(errors)); + return instance; + } + + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + if (ty) + return checkForTypeFunctionInhabitance(follow(*ty), expr->location); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return checkForTypeFunctionInhabitance(flattenPack(*tp), expr->location); + return builtinTypes->anyType; + } + + NonStrictContext visit(AstStat* stat) + { + auto pusher = pushStack(stat); + if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto f = stat->as()) + return visit(f); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else + { + LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type"); + ice->ice("NonStrictTypeChecker encountered an unknown statement type"); + } + } + + NonStrictContext visit(AstStatBlock* block) + { + auto StackPusher = pushStack(block); + NonStrictContext ctx; + + + for (auto it = block->body.rbegin(); it != block->body.rend(); it++) + { + AstStat* stat = *it; + if (AstStatLocal* local = stat->as()) + { + // Iterating in reverse order + // local x ; B generates the context of B without x + visit(local); + for (auto local : local->vars) + ctx.remove(dfg->getDef(local)); + } + else + ctx = NonStrictContext::disjunction(builtinTypes, arena, visit(stat), ctx); + } + return ctx; + } + + NonStrictContext visit(AstStatIf* ifStatement) + { + NonStrictContext condB = visit(ifStatement->condition); + NonStrictContext branchContext; + // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error + if (ifStatement->elsebody) + { + NonStrictContext thenBody = visit(ifStatement->thenbody); + NonStrictContext elseBody = visit(ifStatement->elsebody); + branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); + } + return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); + } + + NonStrictContext visit(AstStatWhile* whileStatement) + { + return {}; + } + + NonStrictContext visit(AstStatRepeat* repeatStatement) + { + return {}; + } + + NonStrictContext visit(AstStatBreak* breakStatement) + { + return {}; + } + + NonStrictContext visit(AstStatContinue* continueStatement) + { + return {}; + } + + NonStrictContext visit(AstStatReturn* returnStatement) + { + return {}; + } + + NonStrictContext visit(AstStatExpr* expr) + { + return visit(expr->expr); + } + + NonStrictContext visit(AstStatLocal* local) + { + for (AstExpr* rhs : local->values) + visit(rhs); + return {}; + } + + NonStrictContext visit(AstStatFor* forStatement) + { + return {}; + } + + NonStrictContext visit(AstStatForIn* forInStatement) + { + return {}; + } + + NonStrictContext visit(AstStatAssign* assign) + { + return {}; + } + + NonStrictContext visit(AstStatCompoundAssign* compoundAssign) + { + return {}; + } + + NonStrictContext visit(AstStatFunction* statFn) + { + return visit(statFn->func); + } + + NonStrictContext visit(AstStatLocalFunction* localFn) + { + return visit(localFn->func); + } + + NonStrictContext visit(AstStatTypeAlias* typeAlias) + { + return {}; + } + + NonStrictContext visit(AstStatTypeFunction* typeFunc) + { + reportError(GenericError{"This syntax is not supported"}, typeFunc->location); + return {}; + } + + NonStrictContext visit(AstStatDeclareFunction* declFn) + { + return {}; + } + + NonStrictContext visit(AstStatDeclareGlobal* declGlobal) + { + return {}; + } + + NonStrictContext visit(AstStatDeclareClass* declClass) + { + return {}; + } + + NonStrictContext visit(AstStatError* error) + { + return {}; + } + + NonStrictContext visit(AstExpr* expr) + { + auto pusher = pushStack(expr); + if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else + { + LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type"); + ice->ice("NonStrictTypeChecker encountered an unknown expression type"); + } + } + + NonStrictContext visit(AstExprGroup* group) + { + return {}; + } + + NonStrictContext visit(AstExprConstantNil* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantBool* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantNumber* expr) + { + return {}; + } + + NonStrictContext visit(AstExprConstantString* expr) + { + return {}; + } + + NonStrictContext visit(AstExprLocal* local) + { + return {}; + } + + NonStrictContext visit(AstExprGlobal* global) + { + return {}; + } + + NonStrictContext visit(AstExprVarargs* global) + { + return {}; + } + + + NonStrictContext visit(AstExprCall* call) + { + NonStrictContext fresh{}; + TypeId* originalCallTy = module->astOriginalCallTypes.find(call); + if (!originalCallTy) + return fresh; + + TypeId fnTy = *originalCallTy; + if (auto fn = get(follow(fnTy))) + { + if (fn->isCheckedFunction) + { + // We know fn is a checked function, which means it looks like: + // (S1, ... SN) -> T & + // (~S1, unknown^N-1) -> error & + // (unknown, ~S2, unknown^N-2) -> error + // ... + // ... + // (unknown^N-1, ~S_N) -> error + std::vector argTypes; + argTypes.reserve(call->args.size); + // Pad out the arg types array with the types you would expect to see + TypePackIterator curr = begin(fn->argTypes); + TypePackIterator fin = end(fn->argTypes); + while (curr != fin) + { + argTypes.push_back(*curr); + ++curr; + } + if (auto argTail = curr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*argTail))) + { + while (argTypes.size() < call->args.size) + { + argTypes.push_back(vtp->ty); + } + } + } + + std::string functionName = getFunctionNameAsString(*call->func).value_or(""); + if (call->args.size > argTypes.size()) + { + // We are passing more arguments than we expect, so we should error + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + return fresh; + } + + for (size_t i = 0; i < call->args.size; i++) + { + // For example, if the arg is "hi" + // The actual arg type is string + // The expected arg type is number + // The type of the argument in the overload is ~number + // We will compare arg and ~number + AstExpr* arg = call->args.data[i]; + TypeId expectedArgType = argTypes[i]; + std::shared_ptr norm = normalizer.normalize(expectedArgType); + DefId def = dfg->getDef(arg); + TypeId runTimeErrorTy; + // If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any + // However, when someone calls this function, they're going to want to be able to pass it anything, + // for that reason, we manually inject never into the context so that the runtime test will always pass. + if (!norm) + reportError(NormalizationTooComplex{}, arg->location); + + if (norm && get(norm->tops)) + runTimeErrorTy = builtinTypes->neverType; + else + runTimeErrorTy = getOrCreateNegation(expectedArgType); + fresh.addContext(def, runTimeErrorTy); + } + + // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types + for (size_t i = 0; i < call->args.size; i++) + { + AstExpr* arg = call->args.data[i]; + if (auto runTimeFailureType = willRunTimeError(arg, fresh)) + reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + } + + if (call->args.size < argTypes.size()) + { + // We are passing fewer arguments than we expect + // so we need to ensure that the rest of the args are optional. + bool remainingArgsOptional = true; + for (size_t i = call->args.size; i < argTypes.size(); i++) + remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); + if (!remainingArgsOptional) + { + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + return fresh; + } + } + } + } + + return fresh; + } + + NonStrictContext visit(AstExprIndexName* indexName) + { + return {}; + } + + NonStrictContext visit(AstExprIndexExpr* indexExpr) + { + return {}; + } + + NonStrictContext visit(AstExprFunction* exprFn) + { + // TODO: should a function being used as an expression generate a context without the arguments? + auto pusher = pushStack(exprFn); + NonStrictContext remainder = visit(exprFn->body); + for (AstLocal* local : exprFn->args) + { + if (std::optional ty = willRunTimeErrorFunctionDefinition(local, remainder)) + reportError(NonStrictFunctionDefinitionError{exprFn->debugname.value, local->name.value, *ty}, local->location); + remainder.remove(dfg->getDef(local)); + } + return remainder; + } + + NonStrictContext visit(AstExprTable* table) + { + return {}; + } + + NonStrictContext visit(AstExprUnary* unary) + { + return {}; + } + + NonStrictContext visit(AstExprBinary* binary) + { + return {}; + } + + NonStrictContext visit(AstExprTypeAssertion* typeAssertion) + { + return {}; + } + + NonStrictContext visit(AstExprIfElse* ifElse) + { + NonStrictContext condB = visit(ifElse->condition); + NonStrictContext thenB = visit(ifElse->trueExpr); + NonStrictContext elseB = visit(ifElse->falseExpr); + return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); + } + + NonStrictContext visit(AstExprInterpString* interpString) + { + return {}; + } + + NonStrictContext visit(AstExprError* error) + { + return {}; + } + + void reportError(TypeErrorData data, const Location& location) + { + module->errors.emplace_back(location, module->name, std::move(data)); + // TODO: weave in logger here? + } + + // If this fragment of the ast will run time error, return the type that causes this + std::optional willRunTimeError(AstExpr* fragment, const NonStrictContext& context) + { + DefId def = dfg->getDef(fragment); + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + { + if (std::optional contextTy = context.find(def)) + { + + TypeId actualType = lookupType(fragment); + SubtypingResult r = subtyping.isSubtype(actualType, *contextTy); + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, fragment->location); + if (r.isSubtype) + return {actualType}; + } + } + + return {}; + } + + std::optional willRunTimeErrorFunctionDefinition(AstLocal* fragment, const NonStrictContext& context) + { + DefId def = dfg->getDef(fragment); + std::vector defs; + collectOperands(def, &defs); + for (DefId def : defs) + { + if (std::optional contextTy = context.find(def)) + { + SubtypingResult r1 = subtyping.isSubtype(builtinTypes->unknownType, *contextTy); + SubtypingResult r2 = subtyping.isSubtype(*contextTy, builtinTypes->unknownType); + if (r1.normalizationTooComplex || r2.normalizationTooComplex) + reportError(NormalizationTooComplex{}, fragment->location); + bool isUnknown = r1.isSubtype && r2.isSubtype; + if (isUnknown) + return {builtinTypes->unknownType}; + } + } + return {}; + } + +private: + TypeId getOrCreateNegation(TypeId baseType) + { + TypeId& cachedResult = cachedNegations[baseType]; + if (!cachedResult) + cachedResult = arena->addType(NegationType{baseType}); + return cachedResult; + }; +}; + +void checkNonStrict( + NotNull builtinTypes, + NotNull ice, + NotNull unifierState, + NotNull dfg, + NotNull limits, + const SourceModule& sourceModule, + Module* module +) +{ + LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); + + NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, ice, unifierState, dfg, limits, module}; + typeChecker.visit(sourceModule.root); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); + freeze(module->interfaceTypes); +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/Normalize.cpp b/third_party/luau/Analysis/src/Normalize.cpp index 29f8b2e6..088b74b1 100644 --- a/third_party/luau/Analysis/src/Normalize.cpp +++ b/third_party/luau/Analysis/src/Normalize.cpp @@ -8,31 +8,68 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Set.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" #include "Luau/Type.h" +#include "Luau/TypeFwd.h" #include "Luau/Unifier.h" -LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) +LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) +LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); +LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); +LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false); // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAG(LuauTransitiveSubtyping) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + +static bool fixReduceStackPressure() +{ + return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution; +} + +static bool fixCyclicTablesBlowingStack() +{ + return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution; +} namespace Luau { + +// helper to make `FFlag::LuauNormalizeAwayUninhabitableTables` not explicitly required when DCR is enabled. +static bool normalizeAwayUninhabitableTables() +{ + return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution; +} + +static bool shouldEarlyExit(NormalizationResult res) +{ + // if res is hit limits, return control flow + if (res == NormalizationResult::HitLimits || res == NormalizationResult::False) + return true; + return false; +} + +TypeIds::TypeIds(std::initializer_list tys) +{ + for (TypeId ty : tys) + insert(ty); +} + void TypeIds::insert(TypeId ty) { ty = follow(ty); - auto [_, fresh] = types.insert(ty); - if (fresh) + + // get a reference to the slot for `ty` in `types` + bool& entry = types[ty]; + + // if `ty` is fresh, we can set it to `true`, add it to the order and hash and be done. + if (!entry) { + entry = true; order.push_back(ty); hash ^= std::hash{}(ty); } @@ -73,25 +110,35 @@ TypeIds::const_iterator TypeIds::end() const TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it) { TypeId ty = *it; - types.erase(ty); + types[ty] = false; hash ^= std::hash{}(ty); return order.erase(it); } +void TypeIds::erase(TypeId ty) +{ + const_iterator it = std::find(order.begin(), order.end(), ty); + if (it == order.end()) + return; + + erase(it); +} + size_t TypeIds::size() const { - return types.size(); + return order.size(); } bool TypeIds::empty() const { - return types.empty(); + return order.empty(); } size_t TypeIds::count(TypeId ty) const { ty = follow(ty); - return types.count(ty); + const bool* val = types.find(ty); + return (val && *val) ? 1 : 0; } void TypeIds::retain(const TypeIds& there) @@ -110,9 +157,44 @@ size_t TypeIds::getHash() const return hash; } +bool TypeIds::isNever() const +{ + return std::all_of( + begin(), + end(), + [&](TypeId i) + { + // If each typeid is never, then I guess typeid's is also never? + return get(i) != nullptr; + } + ); +} + bool TypeIds::operator==(const TypeIds& there) const { - return hash == there.hash && types == there.types; + // we can early return if the hashes don't match. + if (hash != there.hash) + return false; + + // we have to check equality of the sets themselves if not. + + // if the sets are unequal sizes, then they cannot possibly be equal. + // it is important to use `order` here and not `types` since the mappings + // may have different sizes since removal is not possible, and so erase + // simply writes `false` into the map. + if (order.size() != there.order.size()) + return false; + + // otherwise, we'll need to check that every element we have here is in `there`. + for (auto ty : order) + { + // if it's not, we'll return `false` + if (there.count(ty) == 0) + return false; + } + + // otherwise, we've proven the two equal! + return true; } NormalizedStringType::NormalizedStringType() {} @@ -171,7 +253,7 @@ const NormalizedStringType NormalizedStringType::never; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) { - if (subStr.isUnion() && superStr.isUnion()) + if (subStr.isUnion() && (superStr.isUnion() && !superStr.isNever())) { for (auto [name, ty] : subStr.singletons) { @@ -187,8 +269,10 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s void NormalizedClassType::pushPair(TypeId ty, TypeIds negations) { - ordering.push_back(ty); - classes.insert(std::make_pair(ty, std::move(negations))); + auto result = classes.insert(std::make_pair(ty, std::move(negations))); + if (result.second) + ordering.push_back(ty); + LUAU_ASSERT(ordering.size() == classes.size()); } void NormalizedClassType::resetToNever() @@ -227,75 +311,241 @@ NormalizedType::NormalizedType(NotNull builtinTypes) , numbers(builtinTypes->neverType) , strings{NormalizedStringType::never} , threads(builtinTypes->neverType) + , buffers(builtinTypes->neverType) { } -static bool isShallowInhabited(const NormalizedType& norm) +bool NormalizedType::isUnknown() const { - bool inhabitedClasses; + if (get(tops)) + return true; - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm.classes.isNever(); - else - inhabitedClasses = !norm.DEPRECATED_classes.empty(); + // Otherwise, we can still be unknown! + bool hasAllPrimitives = isPrim(booleans, PrimitiveType::Boolean) && isPrim(nils, PrimitiveType::NilType) && isNumber(numbers) && + strings.isString() && isPrim(threads, PrimitiveType::Thread) && isThread(threads); + + // Check is class + bool isTopClass = false; + for (auto [t, disj] : classes.classes) + { + if (auto ct = get(t)) + { + if (ct->name == "class" && disj.empty()) + { + isTopClass = true; + break; + } + } + } + // Check is table + bool isTopTable = false; + for (auto t : tables) + { + if (isPrim(t, PrimitiveType::Table)) + { + isTopTable = true; + break; + } + } + // any = unknown or error ==> we need to make sure we have all the unknown components, but not errors + return get(errors) && hasAllPrimitives && isTopClass && isTopTable && functions.isTop; +} + +bool NormalizedType::isExactlyNumber() const +{ + return hasNumbers() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::isSubtypeOfString() const +{ + return hasStrings() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} +bool NormalizedType::isSubtypeOfBooleans() const +{ + return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + +bool NormalizedType::shouldSuppressErrors() const +{ + return hasErrors() || get(tops); +} + +bool NormalizedType::hasTopTable() const +{ + return hasTables() && std::any_of( + tables.begin(), + tables.end(), + [&](TypeId ty) + { + auto primTy = get(ty); + return primTy && primTy->type == PrimitiveType::Type::Table; + } + ); +} + +bool NormalizedType::hasTops() const +{ + return !get(tops); +} + + +bool NormalizedType::hasBooleans() const +{ + return !get(booleans); +} + +bool NormalizedType::hasClasses() const +{ + return !classes.isNever(); +} + +bool NormalizedType::hasErrors() const +{ + return !get(errors); +} + +bool NormalizedType::hasNils() const +{ + return !get(nils); +} + +bool NormalizedType::hasNumbers() const +{ + return !get(numbers); +} + +bool NormalizedType::hasStrings() const +{ + return !strings.isNever(); +} + +bool NormalizedType::hasThreads() const +{ + return !get(threads); +} + +bool NormalizedType::hasBuffers() const +{ + return !get(buffers); +} + +bool NormalizedType::hasTables() const +{ + return !tables.isNever(); +} + +bool NormalizedType::hasFunctions() const +{ + return !functions.isNever(); +} + +bool NormalizedType::hasTyvars() const +{ + return !tyvars.empty(); +} + +bool NormalizedType::isFalsy() const +{ + + bool hasAFalse = false; + if (auto singleton = get(booleans)) + { + if (auto bs = singleton->variant.get_if()) + hasAFalse = !bs->value; + } + + return (hasAFalse || hasNils()) && (!hasTops() && !hasClasses() && !hasErrors() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars()); +} + +bool NormalizedType::isTruthy() const +{ + return !isFalsy(); +} + +static bool isShallowInhabited(const NormalizedType& norm) +{ // This test is just a shallow check, for example it returns `true` for `{ p : never }` - return !get(norm.tops) || !get(norm.booleans) || inhabitedClasses || !get(norm.errors) || + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.isNever() || !get(norm.errors) || !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || - !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.buffers) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } -bool isInhabited_DEPRECATED(const NormalizedType& norm) +NormalizationResult Normalizer::isInhabited(const NormalizedType* norm) { - LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything2); - return isShallowInhabited(norm); + Set seen{nullptr}; + + return isInhabited(norm, seen); } -bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) +NormalizationResult Normalizer::isInhabited(const NormalizedType* norm, Set& seen) { - // If normalization failed, the type is complex, and so is more likely than not to be inhabited. - if (!norm) - return true; - - bool inhabitedClasses; - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm->classes.isNever(); - else - inhabitedClasses = !norm->DEPRECATED_classes.empty(); + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits() || !norm) + return NormalizationResult::HitLimits; if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || - !get(norm->numbers) || !get(norm->threads) || inhabitedClasses || !norm->strings.isNever() || - !norm->functions.isNever()) - return true; + !get(norm->numbers) || !get(norm->threads) || !get(norm->buffers) || !norm->classes.isNever() || + !norm->strings.isNever() || !norm->functions.isNever()) + return NormalizationResult::True; for (const auto& [_, intersect] : norm->tyvars) { - if (isInhabited(intersect.get(), seen)) - return true; + NormalizationResult res = isInhabited(intersect.get(), seen); + if (res != NormalizationResult::False) + return res; } for (TypeId table : norm->tables) { - if (isInhabited(table, seen)) - return true; + NormalizationResult res = isInhabited(table, seen); + if (res != NormalizationResult::False) + return res; } - return false; + return NormalizationResult::False; +} + +NormalizationResult Normalizer::isInhabited(TypeId ty) +{ + if (cacheInhabitance) + { + if (bool* result = cachedIsInhabited.find(ty)) + return *result ? NormalizationResult::True : NormalizationResult::False; + } + + Set seen{nullptr}; + NormalizationResult result = isInhabited(ty, seen); + + if (cacheInhabitance && result == NormalizationResult::True) + cachedIsInhabited[ty] = true; + else if (cacheInhabitance && result == NormalizationResult::False) + cachedIsInhabited[ty] = false; + + return result; } -bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) +NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) { + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; + // TODO: use log.follow(ty), CLI-64291 ty = follow(ty); if (get(ty)) - return false; + return NormalizationResult::False; if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) - return true; + return NormalizationResult::True; if (seen.count(ty)) - return true; + return NormalizationResult::True; seen.insert(ty); @@ -303,31 +553,75 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) { for (const auto& [_, prop] : ttv->props) { - if (!isInhabited(prop.type(), seen)) - return false; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // A table enclosing a read property whose type is uninhabitable is also itself uninhabitable, + // but not its write property. That just means the write property doesn't exist, and so is readonly. + if (auto ty = prop.readTy) + { + NormalizationResult res = isInhabited(*ty, seen); + if (res != NormalizationResult::True) + return res; + } + } + else + { + NormalizationResult res = isInhabited(prop.type(), seen); + if (res != NormalizationResult::True) + return res; + } } - return true; + return NormalizationResult::True; } if (const MetatableType* mtv = get(ty)) - return isInhabited(mtv->table, seen) && isInhabited(mtv->metatable, seen); + { + NormalizationResult res = isInhabited(mtv->table, seen); + if (res != NormalizationResult::True) + return res; + return isInhabited(mtv->metatable, seen); + } - const NormalizedType* norm = normalize(ty); - return isInhabited(norm, seen); + std::shared_ptr norm = normalize(ty); + return isInhabited(norm.get(), seen); +} + +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +{ + Set seen{nullptr}; + return isIntersectionInhabited(left, right, seen); } -bool Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet) { left = follow(left); right = follow(right); - std::unordered_set seen = {}; - seen.insert(left); - seen.insert(right); + // We're asking if intersection is inahbited between left and right but we've already seen them .... + + if (cacheInhabitance) + { + if (bool* result = cachedIsInhabitedIntersection.find({left, right})) + return *result ? NormalizationResult::True : NormalizationResult::False; + } NormalizedType norm{builtinTypes}; - if (!normalizeIntersections({left, right}, norm)) - return false; - return isInhabited(&norm, seen); + NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); + if (res != NormalizationResult::True) + { + if (cacheInhabitance && res == NormalizationResult::False) + cachedIsInhabitedIntersection[{left, right}] = false; + + return res; + } + + NormalizationResult result = isInhabited(&norm, seenSet); + + if (cacheInhabitance && result == NormalizationResult::True) + cachedIsInhabitedIntersection[{left, right}] = true; + else if (cacheInhabitance && result == NormalizationResult::False) + cachedIsInhabitedIntersection[{left, right}] = false; + + return result; } static int tyvarIndex(TypeId ty) @@ -445,6 +739,16 @@ static bool isNormalizedThread(TypeId ty) return false; } +static bool isNormalizedBuffer(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveType* ptv = get(ty)) + return ptv->type == PrimitiveType::Buffer; + else + return false; +} + static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { for (TypeId ty : tys.parts) @@ -466,7 +770,7 @@ static bool areNormalizedTables(const TypeIds& tys) if (!pt) return false; - if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + if (pt->type == PrimitiveType::Table) continue; return false; @@ -475,14 +779,6 @@ static bool areNormalizedTables(const TypeIds& tys) return true; } -static bool areNormalizedClasses(const TypeIds& tys) -{ - for (TypeId ty : tys) - if (!get(ty)) - return false; - return true; -} - static bool areNormalizedClasses(const NormalizedClassType& tys) { for (const auto& [ty, negations] : tys.classes) @@ -520,7 +816,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) if (isSubclass(ctv, octv)) { - auto iss = [ctv](TypeId t) { + auto iss = [ctv](TypeId t) + { const ClassType* c = get(t); if (!c) return false; @@ -539,7 +836,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || get(ty)); + return (get(ty) || get(ty) || get(ty) || get(ty) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -567,13 +864,13 @@ static void assertInvariant(const NormalizedType& norm) LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); - LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedNil(norm.nils)); LUAU_ASSERT(isNormalizedNumber(norm.numbers)); LUAU_ASSERT(isNormalizedString(norm.strings)); LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(isNormalizedBuffer(norm.buffers)); LUAU_ASSERT(areNormalizedFunctions(norm.functions)); LUAU_ASSERT(areNormalizedTables(norm.tables)); LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); @@ -582,32 +879,102 @@ static void assertInvariant(const NormalizedType& norm) #endif } -Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState) +Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, NotNull sharedState, bool cacheInhabitance) : arena(arena) , builtinTypes(builtinTypes) , sharedState(sharedState) + , cacheInhabitance(cacheInhabitance) +{ +} + +static bool isCacheable(TypeId ty, Set& seen); + +static bool isCacheable(TypePackId tp, Set& seen) +{ + tp = follow(tp); + + auto it = begin(tp); + auto endIt = end(tp); + for (; it != endIt; ++it) + { + if (!isCacheable(*it, seen)) + return false; + } + + if (auto tail = it.tail()) + { + if (get(*tail) || get(*tail) || get(*tail)) + return false; + } + + return true; +} + +static bool isCacheable(TypeId ty, Set& seen) +{ + if (seen.contains(ty)) + return true; + seen.insert(ty); + + ty = follow(ty); + + if (get(ty) || get(ty) || get(ty)) + return false; + + if (auto tfi = get(ty)) + { + for (TypeId t : tfi->typeArguments) + { + if (!isCacheable(t, seen)) + return false; + } + + for (TypePackId tp : tfi->packArguments) + { + if (!isCacheable(tp, seen)) + return false; + } + } + + return true; +} + +static bool isCacheable(TypeId ty) { + Set seen{nullptr}; + return isCacheable(ty, seen); } -const NormalizedType* Normalizer::normalize(TypeId ty) +std::shared_ptr Normalizer::normalize(TypeId ty) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); auto found = cachedNormals.find(ty); if (found != cachedNormals.end()) - return found->second.get(); + return found->second; NormalizedType norm{builtinTypes}; - if (!unionNormalWithTy(norm, ty)) + Set seenSetTypes{nullptr}; + NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); + if (res != NormalizationResult::True) return nullptr; - std::unique_ptr uniq = std::make_unique(std::move(norm)); - const NormalizedType* result = uniq.get(); - cachedNormals[ty] = std::move(uniq); - return result; + + if (norm.isUnknown()) + { + clearNormal(norm); + norm.tops = builtinTypes->unknownType; + } + + std::shared_ptr shared = std::make_shared(std::move(norm)); + + if (shared->isCacheable) + cachedNormals[ty] = shared; + + return shared; } -bool Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType) +NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); @@ -615,13 +982,17 @@ bool Normalizer::normalizeIntersections(const std::vector& intersections norm.tops = builtinTypes->anyType; // Now we need to intersect the two types for (auto ty : intersections) - if (!intersectNormalWithTy(norm, ty)) - return false; + { + NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); + if (res != NormalizationResult::True) + return res; + } - if (!unionNormals(outType, norm)) - return false; + NormalizationResult res = unionNormals(outType, norm); + if (res != NormalizationResult::True) + return res; - return true; + return NormalizationResult::True; } void Normalizer::clearNormal(NormalizedType& norm) @@ -629,12 +1000,12 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.tops = builtinTypes->neverType; norm.booleans = builtinTypes->neverType; norm.classes.resetToNever(); - norm.DEPRECATED_classes.clear(); norm.errors = builtinTypes->neverType; norm.nils = builtinTypes->neverType; norm.numbers = builtinTypes->neverType; norm.strings.resetToNever(); norm.threads = builtinTypes->neverType; + norm.buffers = builtinTypes->neverType; norm.tables.clear(); norm.functions.resetToNever(); norm.tyvars.clear(); @@ -1056,8 +1427,9 @@ std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePack itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, - bool& thereSubHere) { + auto dealWithDifferentArities = + [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { if (ith != end(here)) { TypeId tty = builtinTypes->nilType; @@ -1246,6 +1618,11 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) { // TODO: remove unions of tables where possible + + // we can always skip `never` + if (normalizeAwayUninhabitableTables() && get(there)) + return; + heres.insert(there); } @@ -1253,18 +1630,11 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) { - if (FFlag::LuauNegatedTableTypes) + if (there == builtinTypes->tableType) { - if (there == builtinTypes->tableType) - { - heres.clear(); - heres.insert(there); - return; - } - else - { - unionTablesWithTable(heres, there); - } + heres.clear(); + heres.insert(there); + return; } else { @@ -1292,16 +1662,18 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) // // And yes, this is essentially a SAT solver hidden inside a typechecker. // That's what you get for having a type system with generics, intersection and union types. -bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + here.isCacheable &= there.isCacheable; + TypeId tops = unionOfTops(here.tops, there.tops); - if (FFlag::LuauTransitiveSubtyping && get(tops) && (get(here.errors) || get(there.errors))) + if (get(tops) && (get(here.errors) || get(there.errors))) tops = builtinTypes->anyType; if (!get(tops)) { clearNormal(here); here.tops = tops; - return true; + return NormalizationResult::True; } for (auto it = there.tyvars.begin(); it != there.tyvars.end(); it++) @@ -1313,26 +1685,29 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, continue; auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) - if (!unionNormals(*emplaced->second, here, index)) - return false; - if (!unionNormals(*emplaced->second, inter, index)) - return false; + { + NormalizationResult res = unionNormals(*emplaced->second, here, index); + if (res != NormalizationResult::True) + return res; + } + + NormalizationResult res = unionNormals(*emplaced->second, inter, index); + if (res != NormalizationResult::True) + return res; } here.booleans = unionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - unionClasses(here.classes, there.classes); - else - unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + unionClasses(here.classes, there.classes); here.errors = (get(there.errors) ? here.errors : there.errors); here.nils = (get(there.nils) ? here.nils : there.nils); here.numbers = (get(there.numbers) ? here.numbers : there.numbers); unionStrings(here.strings, there.strings); here.threads = (get(there.threads) ? here.threads : there.threads); + here.buffers = (get(there.buffers) ? here.buffers : there.buffers); unionFunctions(here.functions, there.functions); unionTables(here.tables, there.tables); - return true; + return NormalizationResult::True; } bool Normalizer::withinResourceLimits() @@ -1340,7 +1715,8 @@ bool Normalizer::withinResourceLimits() // If cache is too large, clear it if (FInt::LuauNormalizeCacheLimit > 0) { - size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size() + + cachedIsInhabited.size() + cachedIsInhabitedIntersection.size(); if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) { clearCaches(); @@ -1356,74 +1732,105 @@ bool Normalizer::withinResourceLimits() return true; } -// See above for an explaination of `ignoreSmallerTyvars`. -bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect) +{ + + std::optional negated; + + std::shared_ptr normal = normalize(toNegate); + negated = negateNormal(*normal); + + if (!negated) + return NormalizationResult::False; + intersectNormals(intersect, *negated); + return NormalizationResult::True; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) - return false; + return NormalizationResult::HitLimits; there = follow(there); + if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); - if (FFlag::LuauTransitiveSubtyping && get(tops) && get(here.errors)) + if (get(tops) && get(here.errors)) tops = builtinTypes->anyType; clearNormal(here); here.tops = tops; - return true; + return NormalizationResult::True; } - else if (!FFlag::LuauTransitiveSubtyping && (get(there) || !get(here.tops))) - return true; - else if (FFlag::LuauTransitiveSubtyping && (get(there) || get(here.tops))) - return true; - else if (FFlag::LuauTransitiveSubtyping && get(there) && get(here.tops)) + else if (get(there) || get(here.tops)) + return NormalizationResult::True; + else if (get(there) && get(here.tops)) { here.tops = builtinTypes->anyType; - return true; + return NormalizationResult::True; } else if (const UnionType* utv = get(there)) { + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); + for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(here, *it)) - return false; - return true; + { + NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes); + if (res != NormalizationResult::True) + { + seenSetTypes.erase(there); + return res; + } + } + + seenSetTypes.erase(there); + return NormalizationResult::True; } else if (const IntersectionType* itv = get(there)) { + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); + NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(norm, *it)) - return false; + { + NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); + if (res != NormalizationResult::True) + { + seenSetTypes.erase(there); + return res; + } + } + + seenSetTypes.erase(there); + return unionNormals(here, norm); } - else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) - return true; - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + else if (get(here.tops)) + return NormalizationResult::True; + else if (get(there) || get(there) || get(there) || get(there) || get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) - return true; + return NormalizationResult::True; NormalizedType inter{builtinTypes}; inter.tops = builtinTypes->unknownType; here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + + if (!isCacheable(there)) + here.isCacheable = false; } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) unionTablesWithTable(here.tables, there); else if (get(there)) - { - if (FFlag::LuauNegatedClassTypes) - { - unionClassesWithClass(here.classes, there); - } - else - { - unionClassesWithClass(here.DEPRECATED_classes, there); - } - } + unionClassesWithClass(here.classes, there); else if (get(there)) here.errors = there; else if (const PrimitiveType* ptv = get(there)) @@ -1438,11 +1845,13 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.strings.resetToString(); else if (ptv->type == PrimitiveType::Thread) here.threads = there; + else if (ptv->type == PrimitiveType::Buffer) + here.buffers = there; else if (ptv->type == PrimitiveType::Function) { here.functions.resetToTop(); } - else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + else if (ptv->type == PrimitiveType::Table) { here.tables.clear(); here.tables.insert(there); @@ -1470,17 +1879,19 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (const NegationType* ntv = get(there)) { - const NormalizedType* thereNormal = normalize(ntv->ty); - std::optional tn = negateNormal(*thereNormal); + std::optional tn; + + std::shared_ptr thereNormal = normalize(ntv->ty); + tn = negateNormal(*thereNormal); + if (!tn) - return false; + return NormalizationResult::False; - if (!unionNormals(here, *tn)) - return false; + NormalizationResult res = unionNormals(here, *tn); + if (res != NormalizationResult::True) + return res; } - else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) - LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); - else if (get(there)) + else if (get(there) || get(there)) { // nothing } @@ -1488,11 +1899,14 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor LUAU_ASSERT(!"Unreachable"); for (auto& [tyvar, intersect] : here.tyvars) - if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) - return false; + { + NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)); + if (res != NormalizationResult::True) + return res; + } assertInvariant(here); - return true; + return NormalizationResult::True; } // ------- Negations @@ -1500,6 +1914,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{builtinTypes}; + result.isCacheable = here.isCacheable; + if (!get(here.tops)) { // The negation of unknown or any is never. Easy. @@ -1527,36 +1943,29 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.booleans = builtinTypes->trueType; } - if (FFlag::LuauNegatedClassTypes) + if (here.classes.isNever()) { - if (here.classes.isNever()) - { - resetToTop(builtinTypes, result.classes); - } - else if (isTop(builtinTypes, result.classes)) - { - result.classes.resetToNever(); - } - else - { - TypeIds rootNegations{}; - - for (const auto& [hereParent, hereNegations] : here.classes.classes) - { - if (hereParent != builtinTypes->classType) - rootNegations.insert(hereParent); - - for (TypeId hereNegation : hereNegations) - unionClassesWithClass(result.classes, hereNegation); - } - - if (!rootNegations.empty()) - result.classes.pushPair(builtinTypes->classType, rootNegations); - } + resetToTop(builtinTypes, result.classes); + } + else if (isTop(builtinTypes, result.classes)) + { + result.classes.resetToNever(); } else { - result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes); + TypeIds rootNegations{}; + + for (const auto& [hereParent, hereNegations] : here.classes.classes) + { + if (hereParent != builtinTypes->classType) + rootNegations.insert(hereParent); + + for (TypeId hereNegation : hereNegations) + unionClassesWithClass(result.classes, hereNegation); + } + + if (!rootNegations.empty()) + result.classes.pushPair(builtinTypes->classType, rootNegations); } result.nils = get(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; @@ -1566,6 +1975,7 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.strings.isCofinite = !result.strings.isCofinite; result.threads = get(here.threads) ? builtinTypes->threadType : builtinTypes->neverType; + result.buffers = get(here.buffers) ? builtinTypes->bufferType : builtinTypes->neverType; /* * Things get weird and so, so complicated if we allow negations of @@ -1584,15 +1994,12 @@ std::optional Normalizer::negateNormal(const NormalizedType& her * types are not runtime-testable. Thus, we prohibit negation of anything * other than `table` and `never`. */ - if (FFlag::LuauNegatedTableTypes) - { - if (here.tables.empty()) - result.tables.insert(builtinTypes->tableType); - else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) - result.tables.clear(); - else - return std::nullopt; - } + if (here.tables.empty()) + result.tables.insert(builtinTypes->tableType); + else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) + result.tables.clear(); + else + return std::nullopt; // TODO: negating tables // TODO: negating tyvars? @@ -1658,11 +2065,13 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) case PrimitiveType::Thread: here.threads = builtinTypes->neverType; break; + case PrimitiveType::Buffer: + here.buffers = builtinTypes->neverType; + break; case PrimitiveType::Function: here.functions.resetToNever(); break; case PrimitiveType::Table: - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables.clear(); break; } @@ -1734,64 +2143,6 @@ TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) return there; } -void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres) -{ - TypeIds tmp; - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - bool keep = false; - for (TypeId there : theres) - { - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - if (isSubclass(hctv, tctv)) - { - keep = true; - break; - } - else if (isSubclass(tctv, hctv)) - { - keep = false; - tmp.insert(there); - break; - } - } - if (keep) - it++; - else - it = heres.erase(it); - } - heres.insert(tmp.begin(), tmp.end()); -} - -void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there) -{ - bool foundSuper = false; - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - if (isSubclass(hctv, tctv)) - it++; - else if (isSubclass(tctv, hctv)) - { - foundSuper = true; - break; - } - else - it = heres.erase(it); - } - if (foundSuper) - { - heres.clear(); - heres.insert(there); - } -} - void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) { if (theres.isNever()) @@ -1946,18 +2297,68 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { + /* There are 9 cases to worry about here + Normalized Left | Normalized Right + C1 string | string ===> trivial + C2 string - {u_1,..} | string ===> trivial + C3 {u_1, ..} | string ===> trivial + C4 string | string - {v_1, ..} ===> string - {v_1, ..} + C5 string - {u_1,..} | string - {v_1, ..} ===> string - ({u_s} U {v_s}) + C6 {u_1, ..} | string - {v_1, ..} ===> {u_s} - {v_s} + C7 string | {v_1, ..} ===> {v_s} + C8 string - {u_1,..} | {v_1, ..} ===> {v_s} - {u_s} + C9 {u_1, ..} | {v_1, ..} ===> {u_s} ∩ {v_s} + */ + // Case 1,2,3 if (there.isString()) return; - if (here.isString()) - here.resetToNever(); - - for (auto it = here.singletons.begin(); it != here.singletons.end();) + // Case 4, Case 7 + else if (here.isString()) { - if (there.singletons.count(it->first)) - it++; - else - it = here.singletons.erase(it); + here.singletons.clear(); + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + here.isCofinite = here.isCofinite && there.isCofinite; + } + // Case 5 + else if (here.isIntersection() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; } + // Case 6 + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = false; + for (const auto& [key, _] : there.singletons) + here.singletons.erase(key); + } + // Case 8 + else if (here.isIntersection() && there.isUnion()) + { + here.isCofinite = false; + std::map result(there.singletons); + for (const auto& [key, _] : here.singletons) + result.erase(key); + here.singletons = result; + } + // Case 9 + else if (here.isUnion() && there.isUnion()) + { + here.isCofinite = false; + std::map result; + result.insert(here.singletons.begin(), here.singletons.end()); + result.insert(there.singletons.begin(), there.singletons.end()); + for (auto it = result.begin(); it != result.end();) + if (!here.singletons.count(it->first) || !there.singletons.count(it->first)) + it = result.erase(it); + else + ++it; + here.singletons = result; + } + else + LUAU_ASSERT(0 && "Internal Error - unrecognized case"); } std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) @@ -1988,8 +2389,9 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T itt++; } - auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, - bool& thereSubHere) { + auto dealWithDifferentArities = + [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { if (ith != end(here)) { TypeId tty = builtinTypes->nilType; @@ -2080,7 +2482,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, Set& seenSet) { if (here == there) return here; @@ -2094,50 +2496,37 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; - if (FFlag::LuauNormalizeMetatableFixes) - { - if (get(here)) - return there; - else if (get(there)) - return here; - else if (get(here)) - return there; - else if (get(there)) - return here; - } + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) { - htable = hmtv->table; - hmtable = hmtv->metatable; + htable = follow(hmtv->table); + hmtable = follow(hmtv->metatable); } TypeId ttable = there; TypeId tmtable = nullptr; if (const MetatableType* tmtv = get(there)) { - ttable = tmtv->table; - tmtable = tmtv->metatable; + ttable = follow(tmtv->table); + tmtable = follow(tmtv->metatable); } const TableType* httv = get(htable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!httv) - return std::nullopt; - } - else - LUAU_ASSERT(httv); + if (!httv) + return std::nullopt; const TableType* tttv = get(ttable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!tttv) - return std::nullopt; - } - else - LUAU_ASSERT(tttv); + if (!tttv) + return std::nullopt; if (httv->state == TableState::Free || tttv->state == TableState::Free) @@ -2150,8 +2539,9 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); - TableType result{state, level}; + Scope* scope = max(httv->scope, tttv->scope); + std::unique_ptr result = nullptr; bool hereSubThere = true; bool thereSubHere = true; @@ -2165,19 +2555,114 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { const auto& [_name, tprop] = *tfound; // TODO: variance issues here, which can't be fixed until we have read/write property types - prop.setType(intersectionType(hprop.type(), tprop.type())); - hereSubThere &= (prop.type() == hprop.type()); - thereSubHere &= (prop.type() == tprop.type()); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (hprop.readTy.has_value()) + { + if (tprop.readTy.has_value()) + { + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + if (fixReduceStackPressure()) + { + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + if (fixCyclicTablesBlowingStack()) + { + if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + return {builtinTypes->neverType}; + } + else + { + seenSet.insert(*hprop.readTy); + seenSet.insert(*tprop.readTy); + } + } + + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); + + // Cleanup + if (fixCyclicTablesBlowingStack()) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + } + + if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res) + return {builtinTypes->neverType}; + } + else + { + if (normalizeAwayUninhabitableTables() && + NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + return {builtinTypes->neverType}; + } + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } + else + { + prop.readTy = *hprop.readTy; + thereSubHere = false; + } + } + else if (tprop.readTy.has_value()) + { + prop.readTy = *tprop.readTy; + hereSubThere = false; + } + + if (hprop.writeTy.has_value()) + { + if (tprop.writeTy.has_value()) + { + prop.writeTy = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.writeTy, *tprop.writeTy).result; + hereSubThere &= (prop.writeTy == hprop.writeTy); + thereSubHere &= (prop.writeTy == tprop.writeTy); + } + else + { + prop.writeTy = *hprop.writeTy; + thereSubHere = false; + } + } + else if (tprop.writeTy.has_value()) + { + prop.writeTy = *tprop.writeTy; + hereSubThere = false; + } + } + else + { + prop.setType(intersectionType(hprop.type(), tprop.type())); + hereSubThere &= (prop.type() == hprop.type()); + thereSubHere &= (prop.type() == tprop.type()); + } } + // TODO: string indexers - result.props[name] = prop; + + if (prop.readTy || prop.writeTy) + { + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->props[name] = prop; + } } for (const auto& [name, tprop] : tttv->props) { if (httv->props.count(name) == 0) { - result.props[name] = tprop; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + + result->props[name] = tprop; hereSubThere = false; } } @@ -2187,18 +2672,24 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: What should intersection of indexes be? TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); - result.indexer = {index, indexResult}; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = {index, indexResult}; hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); } else if (httv->indexer) { - result.indexer = httv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = httv->indexer; thereSubHere = false; } else if (tttv->indexer) { - result.indexer = tttv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->indexer = tttv->indexer; hereSubThere = false; } @@ -2208,12 +2699,17 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (thereSubHere) table = ttable; else - table = arena->addType(std::move(result)); + { + if (result.get()) + table = arena->addType(std::move(*result)); + else + table = arena->addType(TableType{state, level, scope}); + } if (tmtable && hmtable) { // NOTE: this assumes metatables are ivariant - if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenSet)) { if (table == htable && *mtable == hmtable) return here; @@ -2243,12 +2739,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there return table; } -void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes) { TypeIds tmp; for (TypeId here : heres) - if (std::optional inter = intersectionOfTables(here, there)) + { + if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) tmp.insert(*inter); + } heres.retain(tmp); heres.insert(tmp.begin(), tmp.end()); } @@ -2257,9 +2755,15 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) { TypeIds tmp; for (TypeId here : heres) + { for (TypeId there : theres) - if (std::optional inter = intersectionOfTables(here, there)) + { + Set seenSetTypes{nullptr}; + if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) tmp.insert(*inter); + } + } + heres.retain(tmp); heres.insert(tmp.begin(), tmp.end()); } @@ -2473,28 +2977,29 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } } -bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes) { for (auto it = here.begin(); it != here.end();) { NormalizedType& inter = *it->second; - if (!intersectNormalWithTy(inter, there)) - return false; + NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes); + if (res != NormalizationResult::True) + return res; if (isShallowInhabited(inter)) ++it; else it = here.erase(it); } - return true; + return NormalizationResult::True; } // See above for an explaination of `ignoreSmallerTyvars`. -bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { if (!get(there.tops)) { here.tops = intersectionOfTops(here.tops, there.tops); - return true; + return NormalizationResult::True; } else if (!get(here.tops)) { @@ -2504,20 +3009,13 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th here.booleans = intersectionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - { - intersectClasses(here.classes, there.classes); - } - else - { - DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); - } - + intersectClasses(here.classes, there.classes); here.errors = (get(there.errors) ? there.errors : here.errors); here.nils = (get(there.nils) ? there.nils : here.nils); here.numbers = (get(there.numbers) ? there.numbers : here.numbers); intersectStrings(here.strings, there.strings); here.threads = (get(there.threads) ? there.threads : here.threads); + here.buffers = (get(there.buffers) ? there.buffers : here.buffers); intersectFunctions(here.functions, there.functions); intersectTables(here.tables, there.tables); @@ -2529,8 +3027,9 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{builtinTypes})); if (fresh) { - if (!unionNormals(*found->second, here, index)) - return false; + NormalizationResult res = unionNormals(*found->second, here, index); + if (res != NormalizationResult::True) + return res; } } } @@ -2543,61 +3042,70 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th auto found = there.tyvars.find(tyvar); if (found == there.tyvars.end()) { - if (!intersectNormals(inter, there, index)) - return false; + NormalizationResult res = intersectNormals(inter, there, index); + if (res != NormalizationResult::True) + return res; } else { - if (!intersectNormals(inter, *found->second, index)) - return false; + NormalizationResult res = intersectNormals(inter, *found->second, index); + if (res != NormalizationResult::True) + return res; } if (isShallowInhabited(inter)) it++; else it = here.tyvars.erase(it); } - return true; + return NormalizationResult::True; } -bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) - return false; + return NormalizationResult::HitLimits; there = follow(there); + if (get(there) || get(there)) { here.tops = intersectionOfTops(here.tops, there); - return true; + return NormalizationResult::True; } else if (!get(here.tops)) { clearNormal(here); - return unionNormalWithTy(here, there); + return unionNormalWithTy(here, there, seenSetTypes); } else if (const UnionType* utv = get(there)) { NormalizedType norm{builtinTypes}; for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) - if (!unionNormalWithTy(norm, *it)) - return false; + { + NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes); + if (res != NormalizationResult::True) + return res; + } return intersectNormals(here, norm); } else if (const IntersectionType* itv = get(there)) { for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) - if (!intersectNormalWithTy(here, *it)) - return false; - return true; + { + NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes); + if (res != NormalizationResult::True) + return res; + } + return NormalizationResult::True; } - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + else if (get(there) || get(there) || get(there) || get(there) || get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; topNorm.tops = builtinTypes->unknownType; thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + here.isCacheable = false; return intersectNormals(here, thereNorm); } @@ -2614,25 +3122,15 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) { TypeIds tables = std::move(here.tables); clearNormal(here); - intersectTablesWithTable(tables, there); + intersectTablesWithTable(tables, there, seenSetTypes); here.tables = std::move(tables); } else if (get(there)) { - if (FFlag::LuauNegatedClassTypes) - { - NormalizedClassType nct = std::move(here.classes); - clearNormal(here); - intersectClassesWithClass(nct, there); - here.classes = std::move(nct); - } - else - { - TypeIds classes = std::move(here.DEPRECATED_classes); - clearNormal(here); - DEPRECATED_intersectClassesWithClass(classes, there); - here.DEPRECATED_classes = std::move(classes); - } + NormalizedClassType nct = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(nct, there); + here.classes = std::move(nct); } else if (get(there)) { @@ -2648,6 +3146,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) NormalizedStringType strings = std::move(here.strings); NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; + TypeId buffers = here.buffers; TypeIds tables = std::move(here.tables); clearNormal(here); @@ -2662,13 +3161,12 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveType::Thread) here.threads = threads; + else if (ptv->type == PrimitiveType::Buffer) + here.buffers = buffers; else if (ptv->type == PrimitiveType::Function) here.functions = std::move(functions); else if (ptv->type == PrimitiveType::Table) - { - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables = std::move(tables); - } else LUAU_ASSERT(!"Unreachable"); } @@ -2696,33 +3194,42 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) subtractPrimitive(here, ntv->ty); else if (const SingletonType* stv = get(t)) subtractSingleton(here, follow(ntv->ty)); - else if (get(t) && FFlag::LuauNegatedClassTypes) + else if (get(t)) { - const NormalizedType* normal = normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return false; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(t, here); + if (shouldEarlyExit(res)) + return res; } else if (const UnionType* itv = get(t)) { for (TypeId part : itv->options) { - const NormalizedType* normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return false; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else if (get(t)) { // HACK: Refinements sometimes intersect with ~any under the // assumption that it is the same as any. - return true; + return NormalizationResult::True; + } + else if (get(t)) + { + // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` + // this is a noop since an intersection with `unknown` is trivial. + return NormalizationResult::True; + } + else if ((FFlag::LuauNormalizeNotUnknownIntersection || FFlag::DebugLuauDeferredConstraintResolution) && get(t)) + { + // if we're intersecting with `~unknown`, this is equivalent to intersecting with `never` + // this means we should clear the type entirely. + clearNormal(here); + return NormalizationResult::True; } else if (auto nt = get(t)) - return intersectNormalWithTy(here, nt->ty); + return intersectNormalWithTy(here, nt->ty, seenSetTypes); else { // TODO negated unions, intersections, table, and function. @@ -2730,18 +3237,34 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(!"Unimplemented"); } } - else if (get(there) && FFlag::LuauNegatedClassTypes) + else if (get(there)) { here.classes.resetToNever(); } else LUAU_ASSERT(!"Unreachable"); - if (!intersectTyvarsWithTy(tyvars, there)) - return false; + NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes); + if (res != NormalizationResult::True) + return res; here.tyvars = std::move(tyvars); - return true; + return NormalizationResult::True; +} + +void makeTableShared(TypeId ty) +{ + ty = follow(ty); + if (auto tableTy = getMutable(ty)) + { + for (auto& [_, prop] : tableTy->props) + prop.makeShared(); + } + else if (auto metatableTy = get(ty)) + { + makeTableShared(metatableTy->metatable); + makeTableShared(metatableTy->table); + } } // -------- Convert back from a normalized type to a type @@ -2756,53 +3279,46 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.booleans)) result.push_back(norm.booleans); - if (FFlag::LuauNegatedClassTypes) + if (isTop(builtinTypes, norm.classes)) { - if (isTop(builtinTypes, norm.classes)) - { - result.push_back(builtinTypes->classType); - } - else if (!norm.classes.isNever()) + result.push_back(builtinTypes->classType); + } + else if (!norm.classes.isNever()) + { + std::vector parts; + parts.reserve(norm.classes.classes.size()); + + for (const TypeId normTy : norm.classes.ordering) { - std::vector parts; - parts.reserve(norm.classes.classes.size()); + const TypeIds& normNegations = norm.classes.classes.at(normTy); - for (const TypeId normTy : norm.classes.ordering) + if (normNegations.empty()) { - const TypeIds& normNegations = norm.classes.classes.at(normTy); + parts.push_back(normTy); + } + else + { + std::vector intersection; + intersection.reserve(normNegations.size() + 1); - if (normNegations.empty()) + intersection.push_back(normTy); + for (TypeId negation : normNegations) { - parts.push_back(normTy); + intersection.push_back(arena->addType(NegationType{negation})); } - else - { - std::vector intersection; - intersection.reserve(normNegations.size() + 1); - - intersection.push_back(normTy); - for (TypeId negation : normNegations) - { - intersection.push_back(arena->addType(NegationType{negation})); - } - parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); - } + parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); } + } - if (parts.size() == 1) - { - result.push_back(parts.at(0)); - } - else if (parts.size() > 1) - { - result.push_back(arena->addType(UnionType{std::move(parts)})); - } + if (parts.size() == 1) + { + result.push_back(parts.at(0)); + } + else if (parts.size() > 1) + { + result.push_back(arena->addType(UnionType{std::move(parts)})); } - } - else - { - result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end()); } if (!get(norm.errors)) @@ -2842,8 +3358,21 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) } if (!get(norm.threads)) result.push_back(builtinTypes->threadType); + if (!get(norm.buffers)) + result.push_back(builtinTypes->bufferType); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + result.reserve(result.size() + norm.tables.size()); + for (auto table : norm.tables) + { + makeTableShared(table); + result.push_back(table); + } + } + else + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); - result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { if (get(intersect->tops)) @@ -2865,36 +3394,56 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { - if (!FFlag::LuauTransitiveSubtyping) - return isConsistentSubtype(subTy, superTy, scope, builtinTypes, ice); UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.tryUnify(subTy, superTy); - return !u.failure; + // Subtyping under DCR is not implemented using unification! + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}, scope}; + + return subtyping.isSubtype(subTy, superTy).isSubtype; + } + else + { + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; + + u.tryUnify(subTy, superTy); + return !u.failure; + } } bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { - if (!FFlag::LuauTransitiveSubtyping) - return isConsistentSubtype(subPack, superPack, scope, builtinTypes, ice); UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.tryUnify(subPack, superPack); - return !u.failure; + // Subtyping under DCR is not implemented using unification! + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}, scope}; + + return subtyping.isSubtype(subPack, superPack).isSubtype; + } + else + { + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; + + u.tryUnify(subPack, superPack); + return !u.failure; + } } bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); + UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); @@ -2902,12 +3451,19 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, Not } bool isConsistentSubtype( - TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + InternalErrorReporter& ice +) { + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); + UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); diff --git a/third_party/luau/Analysis/src/OverloadResolution.cpp b/third_party/luau/Analysis/src/OverloadResolution.cpp new file mode 100644 index 00000000..9aad142d --- /dev/null +++ b/third_party/luau/Analysis/src/OverloadResolution.cpp @@ -0,0 +1,472 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/OverloadResolution.h" + +#include "Luau/Instantiation2.h" +#include "Luau/Subtyping.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" + +namespace Luau +{ + +OverloadResolver::OverloadResolver( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull scope, + NotNull reporter, + NotNull limits, + Location callLocation +) + : builtinTypes(builtinTypes) + , arena(arena) + , normalizer(normalizer) + , scope(scope) + , ice(reporter) + , limits(limits) + , subtyping({builtinTypes, arena, normalizer, ice, scope}) + , callLoc(callLocation) +{ +} + +std::pair OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack) +{ + auto tryOne = [&](TypeId f) + { + if (auto ftv = get(f)) + { + Subtyping::Variance variance = subtyping.variance; + subtyping.variance = Subtyping::Variance::Contravariant; + SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes); + subtyping.variance = variance; + + if (r.isSubtype) + return true; + } + + return false; + }; + + TypeId t = follow(ty); + + if (tryOne(ty)) + return {Analysis::Ok, ty}; + + if (auto it = get(t)) + { + for (TypeId component : it) + { + if (tryOne(component)) + return {Analysis::Ok, component}; + } + } + + return {Analysis::OverloadIsNonviable, ty}; +} + +void OverloadResolver::resolve(TypeId fnTy, const TypePack* args, AstExpr* selfExpr, const std::vector* argExprs) +{ + fnTy = follow(fnTy); + + auto it = get(fnTy); + if (!it) + { + auto [analysis, errors] = checkOverload(fnTy, args, selfExpr, argExprs); + add(analysis, fnTy, std::move(errors)); + return; + } + + for (TypeId ty : it) + { + if (resolution.find(ty) != resolution.end()) + continue; + + auto [analysis, errors] = checkOverload(ty, args, selfExpr, argExprs); + add(analysis, ty, std::move(errors)); + } +} + +std::optional OverloadResolver::testIsSubtype(const Location& location, TypeId subTy, TypeId superTy) +{ + auto r = subtyping.isSubtype(subTy, superTy); + ErrorVec errors; + + if (r.normalizationTooComplex) + errors.emplace_back(location, NormalizationTooComplex{}); + + if (!r.isSubtype) + { + switch (shouldSuppressErrors(normalizer, subTy).orElse(shouldSuppressErrors(normalizer, superTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(location, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + case ErrorSuppression::DoNotSuppress: + errors.emplace_back(location, TypeMismatch{superTy, subTy}); + break; + } + } + + if (errors.empty()) + return std::nullopt; + + return errors; +} + +std::optional OverloadResolver::testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy) +{ + auto r = subtyping.isSubtype(subTy, superTy); + ErrorVec errors; + + if (r.normalizationTooComplex) + errors.emplace_back(location, NormalizationTooComplex{}); + + if (!r.isSubtype) + { + switch (shouldSuppressErrors(normalizer, subTy).orElse(shouldSuppressErrors(normalizer, superTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(location, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + case ErrorSuppression::DoNotSuppress: + errors.emplace_back(location, TypePackMismatch{superTy, subTy}); + break; + } + } + + if (errors.empty()) + return std::nullopt; + + return errors; +} + +std::pair OverloadResolver::checkOverload( + TypeId fnTy, + const TypePack* args, + AstExpr* fnLoc, + const std::vector* argExprs, + bool callMetamethodOk +) +{ + fnTy = follow(fnTy); + + ErrorVec discard; + if (get(fnTy) || get(fnTy) || get(fnTy)) + return {Ok, {}}; + else if (auto fn = get(fnTy)) + return checkOverload_(fnTy, fn, args, fnLoc, argExprs); // Intentionally split to reduce the stack pressure of this function. + else if (auto callMm = findMetatableEntry(builtinTypes, discard, fnTy, "__call", callLoc); callMm && callMetamethodOk) + { + // Calling a metamethod forwards the `fnTy` as self. + TypePack withSelf = *args; + withSelf.head.insert(withSelf.head.begin(), fnTy); + + std::vector withSelfExprs = *argExprs; + withSelfExprs.insert(withSelfExprs.begin(), fnLoc); + + return checkOverload(*callMm, &withSelf, fnLoc, &withSelfExprs, /*callMetamethodOk=*/false); + } + else + return {TypeIsNotAFunction, {}}; // Intentionally empty. We can just fabricate the type error later on. +} + +bool OverloadResolver::isLiteral(AstExpr* expr) +{ + if (auto group = expr->as()) + return isLiteral(group->expr); + else if (auto assertion = expr->as()) + return isLiteral(assertion->expr); + + return expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() || expr->is(); +} + +std::pair OverloadResolver::checkOverload_( + TypeId fnTy, + const FunctionType* fn, + const TypePack* args, + AstExpr* fnExpr, + const std::vector* argExprs +) +{ + FunctionGraphReductionResult result = + reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); + if (!result.errors.empty()) + return {OverloadIsNonviable, result.errors}; + + ErrorVec argumentErrors; + TypePackId typ = arena->addTypePack(*args); + + TypeId prospectiveFunction = arena->addType(FunctionType{typ, builtinTypes->anyTypePack}); + SubtypingResult sr = subtyping.isSubtype(fnTy, prospectiveFunction); + + if (sr.isSubtype) + return {Analysis::Ok, {}}; + + if (1 == sr.reasoning.size()) + { + const SubtypingReasoning& reason = *sr.reasoning.begin(); + + const TypePath::Path justArguments{TypePath::PackField::Arguments}; + + if (reason.subPath == justArguments && reason.superPath == justArguments) + { + // If the subtype test failed only due to an arity mismatch, + // it is still possible that this function call is okay. + // Subtype testing does not know anything about optional + // function arguments. + // + // This can only happen if the actual function call has a + // finite set of arguments which is too short for the + // function being called. If all of those unsatisfied + // function arguments are options, then this function call + // is ok. + + const size_t firstUnsatisfiedArgument = argExprs->size(); + const auto [requiredHead, _requiredTail] = flatten(fn->argTypes); + + // If too many arguments were supplied, this overload + // definitely does not match. + if (args->head.size() > requiredHead.size()) + { + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + TypeError error{fnExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + + return {Analysis::ArityMismatch, {error}}; + } + + // If any of the unsatisfied arguments are not supertypes of + // nil, then this overload does not match. + for (size_t i = firstUnsatisfiedArgument; i < requiredHead.size(); ++i) + { + if (!subtyping.isSubtype(builtinTypes->nilType, requiredHead[i]).isSubtype) + { + auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), fn->argTypes); + TypeError error{fnExpr->location, CountMismatch{minParams, optMaxParams, args->head.size(), CountMismatch::Arg, false}}; + + return {Analysis::ArityMismatch, {error}}; + } + } + + return {Analysis::Ok, {}}; + } + } + + ErrorVec errors; + + for (const SubtypingReasoning& reason : sr.reasoning) + { + /* The return type of our prospective function is always + * any... so any subtype failures here can only arise from + * argument type mismatches. + */ + + Location argLocation; + if (reason.superPath.components.size() <= 1) + break; + + if (const Luau::TypePath::Index* pathIndexComponent = get_if(&reason.superPath.components.at(1))) + { + size_t nthArgument = pathIndexComponent->index; + // if the nth type argument to the function is less than the number of ast expressions we passed to the function + // we should be able to pull out the location of the argument + // If the nth type argument to the function is out of range of the ast expressions we passed to the function + // e.g. table.pack(functionThatReturnsMultipleArguments(arg1, arg2, ....)), default to the location of the last passed expression + // If we passed no expression arguments to the call, default to the location of the function expression. + argLocation = nthArgument < argExprs->size() ? argExprs->at(nthArgument)->location + : argExprs->size() != 0 ? argExprs->back()->location + : fnExpr->location; + + std::optional failedSubTy = traverseForType(fnTy, reason.subPath, builtinTypes); + std::optional failedSuperTy = traverseForType(prospectiveFunction, reason.superPath, builtinTypes); + + if (failedSubTy && failedSuperTy) + { + + switch (shouldSuppressErrors(normalizer, *failedSubTy).orElse(shouldSuppressErrors(normalizer, *failedSuperTy))) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + errors.emplace_back(argLocation, NormalizationTooComplex{}); + // intentionally fallthrough here since we couldn't prove this was error-suppressing + case ErrorSuppression::DoNotSuppress: + // TODO extract location from the SubtypingResult path and argExprs + switch (reason.variance) + { + case SubtypingVariance::Covariant: + case SubtypingVariance::Contravariant: + errors.emplace_back(argLocation, TypeMismatch{*failedSubTy, *failedSuperTy, TypeMismatch::CovariantContext}); + break; + case SubtypingVariance::Invariant: + errors.emplace_back(argLocation, TypeMismatch{*failedSubTy, *failedSuperTy, TypeMismatch::InvariantContext}); + break; + default: + LUAU_ASSERT(0); + break; + } + } + } + } + + std::optional failedSubPack = traverseForPack(fnTy, reason.subPath, builtinTypes); + std::optional failedSuperPack = traverseForPack(prospectiveFunction, reason.superPath, builtinTypes); + + if (failedSubPack && failedSuperPack) + { + // If a bug in type inference occurs, we may have a mismatch in the return packs. + // This happens when inference incorrectly leaves the result type of a function free. + // If this happens, we don't want to explode, so we'll use the function's location. + if (argExprs->empty()) + argLocation = fnExpr->location; + else + argLocation = argExprs->at(argExprs->size() - 1)->location; + + // TODO extract location from the SubtypingResult path and argExprs + switch (reason.variance) + { + case SubtypingVariance::Covariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSubPack, *failedSuperPack}); + break; + case SubtypingVariance::Contravariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSuperPack, *failedSubPack}); + break; + case SubtypingVariance::Invariant: + errors.emplace_back(argLocation, TypePackMismatch{*failedSubPack, *failedSuperPack}); + break; + default: + LUAU_ASSERT(0); + break; + } + } + } + + return {Analysis::OverloadIsNonviable, std::move(errors)}; +} + +size_t OverloadResolver::indexof(Analysis analysis) +{ + switch (analysis) + { + case Ok: + return ok.size(); + case TypeIsNotAFunction: + return nonFunctions.size(); + case ArityMismatch: + return arityMismatches.size(); + case OverloadIsNonviable: + return nonviableOverloads.size(); + } + + ice->ice("Inexhaustive switch in FunctionCallResolver::indexof"); +} + +void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) +{ + resolution.insert(ty, {analysis, indexof(analysis)}); + + switch (analysis) + { + case Ok: + LUAU_ASSERT(errors.empty()); + ok.push_back(ty); + break; + case TypeIsNotAFunction: + LUAU_ASSERT(errors.empty()); + nonFunctions.push_back(ty); + break; + case ArityMismatch: + LUAU_ASSERT(!errors.empty()); + arityMismatches.emplace_back(ty, std::move(errors)); + break; + case OverloadIsNonviable: + nonviableOverloads.emplace_back(ty, std::move(errors)); + break; + } +} + +// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. +// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. +std::optional selectOverload( + NotNull builtinTypes, + NotNull arena, + NotNull normalizer, + NotNull scope, + NotNull iceReporter, + NotNull limits, + const Location& location, + TypeId fn, + TypePackId argsPack +) +{ + OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; + auto [status, overload] = resolver.selectOverload(fn, argsPack); + + if (status == OverloadResolver::Analysis::Ok) + return overload; + + if (get(fn) || get(fn)) + return fn; + + return {}; +} + +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +) +{ + std::optional overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); + if (!overloadToUse) + return {SolveResult::NoMatchingOverload}; + + TypePackId resultPack = arena->freshTypePack(scope); + + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, scope.get(), argsPack, resultPack}); + Unifier2 u2{NotNull{arena}, builtinTypes, scope, iceReporter}; + + const bool occursCheckPassed = u2.unify(*overloadToUse, inferredTy); + + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) + { + Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)}; + + std::optional subst = instantiation.substitute(resultPack); + + if (!subst) + return {SolveResult::CodeTooComplex}; + else + resultPack = *subst; + } + + if (!occursCheckPassed) + return {SolveResult::OccursCheckFailed}; + + SolveResult result; + result.result = SolveResult::Ok; + result.typePackId = resultPack; + + LUAU_ASSERT(overloadToUse); + result.overloadToUse = *overloadToUse; + result.inferredTy = inferredTy; + result.expandedFreeTypes = std::move(u2.expandedFreeTypes); + + return result; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/Quantify.cpp b/third_party/luau/Analysis/src/Quantify.cpp index 0a7975f4..daa61fd5 100644 --- a/third_party/luau/Analysis/src/Quantify.cpp +++ b/third_party/luau/Analysis/src/Quantify.cpp @@ -8,10 +8,6 @@ #include "Luau/Type.h" #include "Luau/VisitType.h" -LUAU_FASTFLAG(DebugLuauSharedSelf) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) - namespace Luau { @@ -102,60 +98,20 @@ struct Quantifier final : TypeOnceVisitor void quantify(TypeId ty, TypeLevel level) { - if (FFlag::DebugLuauSharedSelf) - { - ty = follow(ty); - - if (auto ttv = getTableType(ty); ttv && ttv->selfTy) - { - Quantifier selfQ{level}; - selfQ.traverse(*ttv->selfTy); - - Quantifier q{level}; - q.traverse(ty); - - for (const auto& [_, prop] : ttv->props) - { - auto ftv = getMutable(follow(prop.type())); - if (!ftv || !ftv->hasSelf) - continue; - - if (Luau::first(ftv->argTypes) == ttv->selfTy) - { - ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); - } - } - } - else if (auto ftv = getMutable(ty)) - { - Quantifier q{level}; - q.traverse(ty); - - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; - } - } - else - { - Quantifier q{level}; - q.traverse(ty); + Quantifier q{level}; + q.traverse(ty); - FunctionType* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - } + FunctionType* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); } struct PureQuantifier : Substitution { Scope* scope; - std::vector insertedGenerics; - std::vector insertedGenericPacks; + OrderedMap insertedGenerics; + OrderedMap insertedGenericPacks; bool seenMutableType = false; bool seenGenericType = false; @@ -203,7 +159,7 @@ struct PureQuantifier : Substitution if (auto ftv = get(ty)) { TypeId result = arena->addType(GenericType{scope}); - insertedGenerics.push_back(result); + insertedGenerics.push(ty, result); return result; } else if (auto ttv = get(ty)) @@ -217,7 +173,10 @@ struct PureQuantifier : Substitution resultTable->scope = scope; if (ttv->state == TableState::Free) + { resultTable->state = TableState::Generic; + insertedGenerics.push(ty, result); + } else if (ttv->state == TableState::Unsealed) resultTable->state = TableState::Sealed; @@ -231,8 +190,8 @@ struct PureQuantifier : Substitution { if (auto ftp = get(tp)) { - TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); - insertedGenericPacks.push_back(result); + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{scope}}); + insertedGenericPacks.push(tp, result); return result; } @@ -241,7 +200,7 @@ struct PureQuantifier : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return ty->persistent; @@ -252,7 +211,7 @@ struct PureQuantifier : Substitution } }; -std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); @@ -262,11 +221,20 @@ std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); ftv->scope = scope; - ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; - return *result; + for (auto k : quantifier.insertedGenerics.keys) + { + TypeId g = quantifier.insertedGenerics.pairings[k]; + if (get(g)) + ftv->generics.push_back(g); + } + + for (auto k : quantifier.insertedGenericPacks.keys) + ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]); + + ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; + + return std::optional({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)}); } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Refinement.cpp b/third_party/luau/Analysis/src/Refinement.cpp index a81063c7..e98b6e5a 100644 --- a/third_party/luau/Analysis/src/Refinement.cpp +++ b/third_party/luau/Analysis/src/Refinement.cpp @@ -1,37 +1,60 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Refinement.h" +#include namespace Luau { RefinementId RefinementArena::variadic(const std::vector& refis) { + bool hasRefinements = false; + for (RefinementId r : refis) + hasRefinements |= bool(r); + + if (!hasRefinements) + return nullptr; + return NotNull{allocator.allocate(Variadic{refis})}; } RefinementId RefinementArena::negation(RefinementId refinement) { + if (!refinement) + return nullptr; + return NotNull{allocator.allocate(Negation{refinement})}; } RefinementId RefinementArena::conjunction(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; } RefinementId RefinementArena::disjunction(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; } RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs) { + if (!lhs && !rhs) + return nullptr; + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; } -RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy) +RefinementId RefinementArena::proposition(const RefinementKey* key, TypeId discriminantTy) { - return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})}; + if (!key) + return nullptr; + + return NotNull{allocator.allocate(Proposition{key, discriminantTy})}; } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Scope.cpp b/third_party/luau/Analysis/src/Scope.cpp index 2de381be..791167c8 100644 --- a/third_party/luau/Analysis/src/Scope.cpp +++ b/third_party/luau/Analysis/src/Scope.cpp @@ -2,6 +2,8 @@ #include "Luau/Scope.h" +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -36,6 +38,24 @@ std::optional Scope::lookup(Symbol sym) const return std::nullopt; } +std::optional> Scope::lookupEx(DefId def) +{ + Scope* s = this; + + while (true) + { + if (TypeId* it = s->lvalueTypes.find(def)) + return std::pair{*it, s}; + else if (TypeId* it = s->rvalueRefinements.find(def)) + return std::pair{*it, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + std::optional> Scope::lookupEx(Symbol sym) { Scope* s = this; @@ -53,12 +73,24 @@ std::optional> Scope::lookupEx(Symbol sym) } } -// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookupUnrefinedType(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->lvalueTypes.find(def)) + return *ty; + } + + return std::nullopt; +} + std::optional Scope::lookup(DefId def) const { for (const Scope* current = this; current; current = current->parent.get()) { - if (auto ty = current->dcrRefinements.find(def)) + if (auto ty = current->rvalueRefinements.find(def)) + return *ty; + if (auto ty = current->lvalueTypes.find(def)) return *ty; } @@ -149,25 +181,33 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +// Updates the `this` scope with the assignments from the `childScope` including ones that doesn't exist in `this`. +void Scope::inheritAssignments(const ScopePtr& childScope) +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + for (const auto& [k, a] : childScope->lvalueTypes) + lvalueTypes[k] = a; +} + // Updates the `this` scope with the refinements from the `childScope` excluding ones that doesn't exist in `this`. void Scope::inheritRefinements(const ScopePtr& childScope) { if (FFlag::DebugLuauDeferredConstraintResolution) { - for (const auto& [k, a] : childScope->dcrRefinements) + for (const auto& [k, a] : childScope->rvalueRefinements) { if (lookup(NotNull{k})) - dcrRefinements[k] = a; + rvalueRefinements[k] = a; } } - else + + for (const auto& [k, a] : childScope->refinements) { - for (const auto& [k, a] : childScope->refinements) - { - Symbol symbol = getBaseSymbol(k); - if (lookup(symbol)) - refinements[k] = a; - } + Symbol symbol = getBaseSymbol(k); + if (lookup(symbol)) + refinements[k] = a; } } diff --git a/third_party/luau/Analysis/src/Simplify.cpp b/third_party/luau/Analysis/src/Simplify.cpp new file mode 100644 index 00000000..ed953f63 --- /dev/null +++ b/third_party/luau/Analysis/src/Simplify.cpp @@ -0,0 +1,1444 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Simplify.h" + +#include "Luau/DenseHash.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Set.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePairHash.h" +#include "Luau/TypeUtils.h" + +#include + +LUAU_FASTINT(LuauTypeReductionRecursionLimit) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); + +namespace Luau +{ + +using SimplifierSeenSet = Set, TypePairHash>; + +struct TypeSimplifier +{ + NotNull builtinTypes; + NotNull arena; + + DenseHashSet blockedTypes{nullptr}; + + int recursionDepth = 0; + + TypeId mkNegation(TypeId ty); + + TypeId intersectFromParts(std::set parts); + + TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnions(TypeId left, TypeId right); + TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + + TypeId intersectTypeWithNegation(TypeId a, TypeId b); + TypeId intersectNegations(TypeId a, TypeId b); + + TypeId intersectIntersectionWithType(TypeId left, TypeId right); + + // Attempt to intersect the two types. Does not recurse. Does not handle + // unions, intersections, or negations. + std::optional basicIntersect(TypeId left, TypeId right); + + TypeId intersect(TypeId ty, TypeId discriminant); + TypeId union_(TypeId ty, TypeId discriminant); + + TypeId simplify(TypeId ty); + TypeId simplify(TypeId ty, DenseHashSet& seen); +}; + +// Match the exact type false|nil +static bool isFalsyType(TypeId ty) +{ + ty = follow(ty); + const UnionType* ut = get(ty); + if (!ut) + return false; + + bool hasFalse = false; + bool hasNil = false; + + auto it = begin(ut); + if (it == end(ut)) + return false; + + TypeId t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it == end(ut)) + return false; + + t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it != end(ut)) + return false; + + return hasFalse && hasNil; +} + +// Match the exact type ~(false|nil) +bool isTruthyType(TypeId ty) +{ + ty = follow(ty); + + const NegationType* nt = get(ty); + if (!nt) + return false; + + return isFalsyType(nt->ty); +} + +Relation flip(Relation rel) +{ + switch (rel) + { + case Relation::Subset: + return Relation::Superset; + case Relation::Superset: + return Relation::Subset; + default: + return rel; + } +} + +// FIXME: I'm not completely certain that this function is theoretically reasonable. +Relation combine(Relation a, Relation b) +{ + switch (a) + { + case Relation::Disjoint: + switch (b) + { + case Relation::Disjoint: + return Relation::Disjoint; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Coincident: + switch (b) + { + case Relation::Disjoint: + return Relation::Coincident; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Superset; + case Relation::Subset: + return Relation::Coincident; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Superset: + switch (b) + { + case Relation::Disjoint: + return Relation::Superset; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Superset; + } + case Relation::Subset: + switch (b) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Subset; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Intersects: + switch (b) + { + case Relation::Disjoint: + return Relation::Intersects; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +// Given A & B, what is A & ~B? +Relation invert(Relation r) +{ + switch (r) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Disjoint; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Disjoint; + case Relation::Superset: + return Relation::Intersects; + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +static bool isTypeVariable(TypeId ty) +{ + return get(ty) || get(ty) || get(ty) || get(ty); +} + +Relation relate(TypeId left, TypeId right, SimplifierSeenSet& seen); + +Relation relateTables(TypeId left, TypeId right, SimplifierSeenSet& seen) +{ + NotNull leftTable{get(left)}; + NotNull rightTable{get(right)}; + LUAU_ASSERT(1 == rightTable->props.size()); + // Disjoint props have nothing in common + // t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1 + bool foundPropFromLeftInRight = std::any_of( + begin(leftTable->props), + end(leftTable->props), + [&](auto prop) + { + return rightTable->props.count(prop.first) > 0; + } + ); + bool foundPropFromRightInLeft = std::any_of( + begin(rightTable->props), + end(rightTable->props), + [&](auto prop) + { + return leftTable->props.count(prop.first) > 0; + } + ); + + if (!foundPropFromLeftInRight && !foundPropFromRightInLeft && leftTable->props.size() >= 1 && rightTable->props.size() >= 1) + return Relation::Disjoint; + + const auto [propName, rightProp] = *begin(rightTable->props); + + auto it = leftTable->props.find(propName); + if (it == leftTable->props.end()) + { + // Every table lacking a property is a supertype of a table having that + // property but the reverse is not true. + return Relation::Superset; + } + + const Property leftProp = it->second; + + if (!leftProp.isShared() || !rightProp.isShared()) + return Relation::Intersects; + + Relation r = relate(leftProp.type(), rightProp.type(), seen); + if (r == Relation::Coincident && 1 != leftTable->props.size()) + { + // eg {tag: "cat", prop: string} & {tag: "cat"} + return Relation::Subset; + } + else + return r; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right, SimplifierSeenSet& seen) +{ + // TODO nice to have: Relate functions of equal argument and return arity + + left = follow(left); + right = follow(right); + + if (left == right) + return Relation::Coincident; + + std::pair typePair{left, right}; + if (!seen.insert(typePair)) + { + // TODO: is this right at all? + // The thinking here is that this is a cycle if we get here, and therefore its coincident. + return Relation::Coincident; + } + + if (get(left)) + { + if (get(right)) + return Relation::Subset; + else if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Disjoint; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left, seen)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left, seen)); + + // Type variables + // * FreeType + // * GenericType + // * BlockedType + // * PendingExpansionType + + // Tops and bottoms + // * ErrorType + // * AnyType + // * NeverType + // * UnknownType + + // Concrete + // * PrimitiveType + // * SingletonType + // * FunctionType + // * TableType + // * MetatableType + // * ClassType + // * UnionType + // * IntersectionType + // * NegationType + + if (isTypeVariable(left) || isTypeVariable(right)) + return Relation::Intersects; + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Subset; + else + return Relation::Disjoint; + } + if (get(right)) + return flip(relate(right, left, seen)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Subset; + } + if (get(right)) + return flip(relate(right, left, seen)); + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + { + std::vector opts; + for (TypeId part : ut) + { + Relation r = relate(left, part, seen); + + if (r == Relation::Subset || r == Relation::Coincident) + return Relation::Subset; + } + return Relation::Intersects; + } + + if (auto rnt = get(right)) + { + Relation a = relate(left, rnt->ty, seen); + switch (a) + { + case Relation::Coincident: + // number & ~number + return Relation::Disjoint; + case Relation::Disjoint: + if (get(left)) + { + // ~number & ~string + return Relation::Intersects; + } + else + { + // number & ~string + return Relation::Subset; + } + case Relation::Intersects: + // ~(false?) & ~boolean + return Relation::Intersects; + case Relation::Subset: + // "hello" & ~string + return Relation::Disjoint; + case Relation::Superset: + // ~function & ~(false?) -> ~function + // boolean & ~(false?) -> true + // string & ~"hello" -> string & ~"hello" + return Relation::Intersects; + } + } + else if (get(left)) + return flip(relate(right, left, seen)); + + if (auto lp = get(left)) + { + if (auto rp = get(right)) + { + if (lp->type == rp->type) + return Relation::Coincident; + else + return Relation::Disjoint; + } + + if (auto rs = get(right)) + { + if (lp->type == PrimitiveType::String && rs->variant.get_if()) + return Relation::Superset; + else if (lp->type == PrimitiveType::Boolean && rs->variant.get_if()) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (lp->type == PrimitiveType::Function) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + if (lp->type == PrimitiveType::Table) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + } + + if (auto ls = get(left)) + { + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + + if (get(right)) + return flip(relate(right, left, seen)); + if (auto rs = get(right)) + { + if (ls->variant == rs->variant) + return Relation::Coincident; + else + return Relation::Disjoint; + } + } + + if (get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Function) + return Relation::Subset; + else + return Relation::Disjoint; + } + else + return Relation::Intersects; + } + + if (auto lt = get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Table) + return Relation::Subset; + else + return Relation::Disjoint; + } + else if (auto rt = get(right)) + { + // TODO PROBABLY indexers and metatables. + if (1 == rt->props.size()) + { + Relation r = relateTables(left, right, seen); + /* + * A reduction of these intersections is certainly possible, but + * it would require minting new table types. Also, I don't think + * it's super likely for this to arise from a refinement. + * + * Time will tell! + * + * ex we simplify this + * {tag: string} & {tag: "cat"} + * but not this + * {tag: string, prop: number} & {tag: "cat"} + */ + if (lt->props.size() > 1 && r == Relation::Superset) + return Relation::Intersects; + else + return r; + } + else if (1 == lt->props.size()) + return flip(relate(right, left, seen)); + else + return Relation::Intersects; + } + // TODO metatables + + return Relation::Disjoint; + } + + if (auto ct = get(left)) + { + if (auto rct = get(right)) + { + if (isSubclass(ct, rct)) + return Relation::Subset; + else if (isSubclass(rct, ct)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + return Relation::Disjoint; + } + + return Relation::Intersects; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right) +{ + SimplifierSeenSet seen{{}}; + return relate(left, right, seen); +} + +TypeId TypeSimplifier::mkNegation(TypeId ty) +{ + TypeId result = nullptr; + + if (ty == builtinTypes->truthyType) + result = builtinTypes->falsyType; + else if (ty == builtinTypes->falsyType) + result = builtinTypes->truthyType; + else if (auto ntv = get(ty)) + result = follow(ntv->ty); + else + result = arena->addType(NegationType{ty}); + + return result; +} + +TypeId TypeSimplifier::intersectFromParts(std::set parts) +{ + if (0 == parts.size()) + return builtinTypes->neverType; + else if (1 == parts.size()) + return *begin(parts); + + { + auto it = begin(parts); + while (it != end(parts)) + { + TypeId t = follow(*it); + + auto copy = it; + ++it; + + if (auto ut = get(t)) + { + for (TypeId part : ut) + parts.insert(part); + parts.erase(copy); + } + } + } + + std::set newParts; + + /* + * It is possible that the parts of the passed intersection are themselves + * reducable. + * + * eg false & boolean + * + * We do a comparison between each pair of types and look for things that we + * can elide. + */ + for (TypeId part : parts) + { + if (newParts.empty()) + { + newParts.insert(part); + continue; + } + + auto it = begin(newParts); + while (it != end(newParts)) + { + TypeId p = *it; + + switch (relate(part, p)) + { + case Relation::Disjoint: + // eg boolean & string + return builtinTypes->neverType; + case Relation::Subset: + { + /* part is a subset of p. Remove p from the set and replace it + * with part. + * + * eg boolean & true + */ + auto saveIt = it; + ++it; + newParts.erase(saveIt); + continue; + } + case Relation::Coincident: + case Relation::Superset: + { + /* part is coincident or a superset of p. We do not need to + * include part in the final intersection. + * + * ex true & boolean + */ + ++it; + continue; + } + case Relation::Intersects: + { + /* It's complicated! A simplification may still be possible, + * but we have to pull the types apart to figure it out. + * + * ex boolean & ~false + */ + std::optional simplified = basicIntersect(part, p); + + auto saveIt = it; + ++it; + + if (simplified) + { + newParts.erase(saveIt); + newParts.insert(*simplified); + } + else + newParts.insert(part); + continue; + } + } + } + } + + if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(IntersectionType{std::vector{begin(newParts), end(newParts)}}); +} + +TypeId TypeSimplifier::intersectUnionWithType(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + bool changed = false; + std::set newParts; + + if (leftUnion->options.size() > (size_t)DFInt::LuauSimplificationComplexityLimit) + return arena->addType(IntersectionType{{left, right}}); + + for (TypeId part : leftUnion) + { + TypeId simplified = intersect(right, part); + changed |= simplified != part; + + if (get(simplified)) + { + changed = true; + continue; + } + + newParts.insert(simplified); + } + + if (!changed) + return left; + else if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectUnions(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + const UnionType* rightUnion = get(right); + LUAU_ASSERT(rightUnion); + + std::set newParts; + + // Combinatorial blowup moment!! + + // combination size + size_t optionSize = (int)leftUnion->options.size() * rightUnion->options.size(); + size_t maxSize = DFInt::LuauSimplificationComplexityLimit; + + if (optionSize > maxSize) + return arena->addType(IntersectionType{{left, right}}); + + for (TypeId leftPart : leftUnion) + { + for (TypeId rightPart : rightUnion) + { + TypeId simplified = intersect(leftPart, rightPart); + if (get(simplified)) + continue; + + newParts.insert(simplified); + } + } + + if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectNegatedUnion(TypeId left, TypeId right) +{ + // ~(A | B) & C + // (~A & C) & (~B & C) + + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + const UnionType* negatedUnion = get(negatedTy); + LUAU_ASSERT(negatedUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : negatedUnion) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + // If A is disjoint from B, then ~A & B is just B. + // + // ~(false?) & true + // (~false & true) & (~nil & true) + // true & true + newParts.insert(right); + break; + case Relation::Coincident: + // If A is coincident with or a superset of B, then ~A & B is never. + // + // ~(false?) & false + // (~false & false) & (~nil & false) + // never & false + // + // fallthrough + case Relation::Superset: + // If A is a superset of B, then ~A & B is never. + // + // ~(boolean | nil) & true + // (~boolean & true) & (~boolean & nil) + // never & nil + return builtinTypes->neverType; + case Relation::Subset: + case Relation::Intersects: + // If A is a subset of B, then ~A & B is a bit more complicated. We need to think harder. + // + // ~(false?) & boolean + // (~false & boolean) & (~nil & boolean) + // true & boolean + TypeId simplified = intersectTypeWithNegation(mkNegation(part), right); + changed |= simplified != right; + if (get(simplified)) + changed = true; + else + newParts.insert(simplified); + break; + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); +} + +TypeId TypeSimplifier::intersectTypeWithNegation(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + if (negatedTy == right) + return builtinTypes->neverType; + + if (auto ut = get(negatedTy)) + { + // ~(A | B) & C + // (~A & C) & (~B & C) + + bool changed = false; + std::set newParts; + + for (TypeId part : ut) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + // ~(false?) & nil + // (~false & nil) & (~nil & nil) + // nil & never + // + // fallthrough + case Relation::Superset: + // ~(boolean | string) & true + // (~boolean & true) & (~boolean & string) + // never & string + + return builtinTypes->neverType; + + case Relation::Disjoint: + // ~nil & boolean + newParts.insert(right); + break; + + case Relation::Subset: + // ~false & boolean + // fallthrough + case Relation::Intersects: + // FIXME: The mkNegation here is pretty unfortunate. + // Memoizing this will probably be important. + changed = true; + newParts.insert(right); + newParts.insert(mkNegation(part)); + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); + } + + if (auto rightUnion = get(right)) + { + // ~A & (B | C) + bool changed = false; + std::set newParts; + + for (TypeId part : rightUnion) + { + Relation r = relate(negatedTy, part); + switch (r) + { + case Relation::Coincident: + changed = true; + continue; + case Relation::Disjoint: + newParts.insert(part); + break; + case Relation::Superset: + changed = true; + continue; + case Relation::Subset: + // fallthrough + case Relation::Intersects: + changed = true; + newParts.insert(arena->addType(IntersectionType{{left, part}})); + } + } + + if (!changed) + return right; + else if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + + if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(negatedTy)) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else if (st->variant == BooleanSingleton{false}) + return builtinTypes->trueType; + else + // boolean & ~"hello" + return builtinTypes->booleanType; + } + } + + Relation r = relate(negatedTy, right); + + switch (r) + { + case Relation::Disjoint: + // ~boolean & string + return right; + case Relation::Coincident: + // ~string & string + // fallthrough + case Relation::Superset: + // ~string & "hello" + return builtinTypes->neverType; + case Relation::Subset: + // ~string & unknown + // ~"hello" & string + // fallthrough + case Relation::Intersects: + // ~("hello" | boolean) & string + // fallthrough + default: + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectNegations(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + if (get(follow(leftNegation->ty))) + return intersectNegatedUnion(left, right); + + const NegationType* rightNegation = get(right); + LUAU_ASSERT(rightNegation); + + if (get(follow(rightNegation->ty))) + return intersectNegatedUnion(right, left); + + Relation r = relate(leftNegation->ty, rightNegation->ty); + + switch (r) + { + case Relation::Coincident: + // ~true & ~true + return left; + case Relation::Subset: + // ~true & ~boolean + return right; + case Relation::Superset: + // ~boolean & ~true + return left; + case Relation::Intersects: + case Relation::Disjoint: + default: + // ~boolean & ~string + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + LUAU_ASSERT(leftIntersection); + + if (leftIntersection->parts.size() > (size_t)DFInt::LuauSimplificationComplexityLimit) + return arena->addType(IntersectionType{{left, right}}); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftIntersection) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + newParts.insert(part); + continue; + case Relation::Subset: + newParts.insert(part); + continue; + case Relation::Superset: + newParts.insert(right); + changed = true; + continue; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + continue; + } + } + + // It is sometimes the case that an intersection operation will result in + // clipping a free type from the result. + // + // eg (number & 'a) & string --> never + // + // We want to only report the free types that are part of the result. + for (TypeId part : newParts) + { + if (isTypeVariable(part)) + blockedTypes.insert(part); + } + + if (!changed) + return left; + return intersectFromParts(std::move(newParts)); +} + +std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) +{ + if (get(left) && get(right)) + return right; + if (get(right) && get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); + if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto pt = get(left); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(right); st && st->variant.get_if()) + return right; + if (auto nt = get(right)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + else if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(left); st && st->variant.get_if()) + return left; + if (auto nt = get(left)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + + if (const auto [lt, rt] = get2(left, right); lt && rt) + { + if (1 == lt->props.size()) + { + const auto [propName, leftProp] = *begin(lt->props); + + auto it = rt->props.find(propName); + if (it != rt->props.end() && leftProp.isShared() && it->second.isShared()) + { + Relation r = relate(leftProp.type(), it->second.type()); + + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Superset: + case Relation::Coincident: + return right; + case Relation::Subset: + if (1 == rt->props.size()) + return left; + break; + default: + break; + } + } + } + else if (1 == rt->props.size()) + return basicIntersect(right, left); + + // If two tables have disjoint properties and indexers, we can combine them. + if (!lt->indexer && !rt->indexer && lt->state == TableState::Sealed && rt->state == TableState::Sealed) + { + if (rt->props.empty()) + return left; + + bool areDisjoint = true; + for (const auto& [name, leftProp]: lt->props) + { + if (rt->props.count(name)) + { + areDisjoint = false; + break; + } + } + + if (areDisjoint) + { + TableType::Props mergedProps = lt->props; + for (const auto& [name, rightProp]: rt->props) + mergedProps[name] = rightProp; + + return arena->addType(TableType{ + mergedProps, + std::nullopt, + TypeLevel{}, + lt->scope, + TableState::Sealed + }); + } + } + + return std::nullopt; + } + + Relation relation = relate(left, right); + if (left == right || Relation::Coincident == relation) + return left; + + if (relation == Relation::Disjoint) + return builtinTypes->neverType; + else if (relation == Relation::Subset) + return left; + else if (relation == Relation::Superset) + return right; + + return std::nullopt; +} + +TypeId TypeSimplifier::intersect(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (left == right) + return left; + + if (get(left) && get(right)) + return right; + if (get(right) && get(left)) + return left; + if (get(left) && !get(right)) + return right; + if (get(right) && !get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); + if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto lf = get(left)) + { + Relation r = relate(lf->upperBound, right); + if (r == Relation::Subset || r == Relation::Coincident) + return left; + } + else if (auto rf = get(right)) + { + Relation r = relate(left, rf->upperBound); + if (r == Relation::Superset || r == Relation::Coincident) + return right; + } + + if (isTypeVariable(left)) + { + blockedTypes.insert(left); + return arena->addType(IntersectionType{{left, right}}); + } + + if (isTypeVariable(right)) + { + blockedTypes.insert(right); + return arena->addType(IntersectionType{{left, right}}); + } + + if (auto ut = get(left)) + { + if (get(right)) + return intersectUnions(left, right); + else + return intersectUnionWithType(left, right); + } + else if (auto ut = get(right)) + return intersectUnionWithType(right, left); + + if (auto it = get(left)) + return intersectIntersectionWithType(left, right); + else if (auto it = get(right)) + return intersectIntersectionWithType(right, left); + + if (get(left)) + { + if (get(right)) + return intersectNegations(left, right); + else + return intersectTypeWithNegation(left, right); + } + else if (get(right)) + return intersectTypeWithNegation(right, left); + + std::optional res = basicIntersect(left, right); + if (res) + return *res; + else + return arena->addType(IntersectionType{{left, right}}); +} + +TypeId TypeSimplifier::union_(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (get(left)) + return right; + if (get(right)) + return left; + + if (auto leftUnion = get(left)) + { + bool changed = false; + std::set newParts; + for (TypeId part : leftUnion) + { + if (get(part)) + { + changed = true; + continue; + } + + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + case Relation::Superset: + return left; + case Relation::Subset: + newParts.insert(right); + changed = true; + break; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + break; + } + } + + if (!changed) + return left; + if (0 == newParts.size()) + { + // If the left-side is changed but has no parts, then the left-side union is uninhabited. + return right; + } + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + else if (get(right)) + return union_(right, left); + + Relation r = relate(left, right); + if (left == right || r == Relation::Coincident || r == Relation::Superset) + return left; + + if (r == Relation::Subset) + return right; + + if (auto as = get(left)) + { + if (auto abs = as->variant.get_if()) + { + if (auto bs = get(right)) + { + if (auto bbs = bs->variant.get_if()) + { + if (abs->value != bbs->value) + return builtinTypes->booleanType; + } + } + } + } + + return arena->addType(UnionType{{left, right}}); +} + +TypeId TypeSimplifier::simplify(TypeId ty) +{ + DenseHashSet seen{nullptr}; + return simplify(ty, seen); +} + +TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) +{ + RecursionLimiter limiter(&recursionDepth, 60); + + ty = follow(ty); + + if (seen.find(ty)) + return ty; + seen.insert(ty); + + if (auto nt = get(ty)) + { + TypeId negatedTy = follow(nt->ty); + if (get(negatedTy)) + return arena->addType(UnionType{{builtinTypes->neverType, builtinTypes->errorType}}); + else if (get(negatedTy)) + return builtinTypes->neverType; + else if (get(negatedTy)) + return builtinTypes->unknownType; + if (auto nnt = get(negatedTy)) + return simplify(nnt->ty, seen); + } + + // Promote {x: never} to never + if (auto tt = get(ty)) + { + if (1 == tt->props.size()) + { + if (std::optional readTy = begin(tt->props)->second.readTy) + { + TypeId propTy = simplify(*readTy, seen); + if (get(propTy)) + return builtinTypes->neverType; + } + } + } + + return ty; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + TypeSimplifier s{builtinTypes, arena}; + + // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); + + TypeId res = s.intersect(left, right); + + // fprintf(stderr, "Intersect %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.intersectFromParts(std::move(parts)); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.union_(left, right); + + // fprintf(stderr, "Union %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/Substitution.cpp b/third_party/luau/Analysis/src/Substitution.cpp index 6a600b62..8b8f22e8 100644 --- a/third_party/luau/Analysis/src/Substitution.cpp +++ b/third_party/luau/Analysis/src/Substitution.cpp @@ -8,91 +8,22 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { -static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) -{ - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionType* ftv = get(ty)) - { - FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = dest.addType(std::move(clone)); - } - else if (const TableType* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.definitionLocation = ttv->definitionLocation; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = dest.addType(std::move(clone)); - } - else if (const MetatableType* mtv = get(ty)) - { - MetatableType clone = MetatableType{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = dest.addType(std::move(clone)); - } - else if (const UnionType* utv = get(ty)) - { - UnionType clone; - clone.options = utv->options; - result = dest.addType(std::move(clone)); - } - else if (const IntersectionType* itv = get(ty)) - { - IntersectionType clone; - clone.parts = itv->parts; - result = dest.addType(std::move(clone)); - } - else if (const PendingExpansionType* petv = get(ty)) - { - PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; - result = dest.addType(std::move(clone)); - } - else if (const NegationType* ntv = get(ty)) - { - result = dest.addType(NegationType{ntv->ty}); - } - else - return result; - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) { - if (!FFlag::LuauClonePublicInterfaceLess2) - return DEPRECATED_shallowClone(ty, dest, log, alwaysClone); - - auto go = [ty, &dest, alwaysClone](auto&& a) { + auto go = [ty, &dest, alwaysClone](auto&& a) + { using T = std::decay_t; + // The pointer identities of free and local types is very important. + // We decline to copy them. if constexpr (std::is_same_v) return ty; else if constexpr (std::is_same_v) @@ -104,19 +35,37 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a else if constexpr (std::is_same_v) return dest.addType(a); else if constexpr (std::is_same_v) - return ty; + return dest.addType(a); else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) - return ty; + { + PendingExpansionType clone = PendingExpansionType{a.prefix, a.name, a.typeArguments, a.packArguments}; + return dest.addType(std::move(clone)); + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); return ty; + } else if constexpr (std::is_same_v) return ty; else if constexpr (std::is_same_v) @@ -131,6 +80,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a clone.dcrMagicRefinement = a.dcrMagicRefinement; clone.tags = a.tags; clone.argNames = a.argNames; + clone.isCheckedFunction = a.isCheckedFunction; return dest.addType(std::move(clone)); } else if constexpr (std::is_same_v) @@ -168,7 +118,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a { if (alwaysClone) { - ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.definitionLocation, a.indexer}; return dest.addType(std::move(clone)); } else @@ -176,6 +126,11 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a } else if constexpr (std::is_same_v) return dest.addType(NegationType{a.ty}); + else if constexpr (std::is_same_v) + { + TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments}; + return dest.addType(std::move(clone)); + } else static_assert(always_false_v, "Non-exhaustive shallowClone switch"); }; @@ -192,11 +147,22 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a return resTy; } +Tarjan::Tarjan() + : typeToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) + , packToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) +{ + nodes.reserve(FInt::LuauTarjanPreallocationSize); + stack.reserve(FInt::LuauTarjanPreallocationSize); + edgesTy.reserve(FInt::LuauTarjanPreallocationSize); + edgesTp.reserve(FInt::LuauTarjanPreallocationSize); + worklist.reserve(FInt::LuauTarjanPreallocationSize); +} + void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); - if (ignoreChildren(ty)) + if (ignoreChildrenVisit(ty)) return; if (auto pty = log->pending(ty)) @@ -204,13 +170,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionType* ftv = get(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId generic : ftv->generics) - visitChild(generic); - for (TypePackId genericPack : ftv->genericPacks) - visitChild(genericPack); - } + for (TypeId generic : ftv->generics) + visitChild(generic); + for (TypePackId genericPack : ftv->genericPacks) + visitChild(genericPack); visitChild(ftv->argTypes); visitChild(ftv->retTypes); @@ -219,7 +182,16 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type()); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + visitChild(prop.readTy); + visitChild(prop.writeTy); + } + else + visitChild(prop.type()); + } + if (ttv->indexer) { visitChild(ttv->indexer->indexType); @@ -255,7 +227,15 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId a : petv->packArguments) visitChild(a); } - else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (const TypeFunctionInstanceType* tfit = get(ty)) + { + for (TypeId a : tfit->typeArguments) + visitChild(a); + + for (TypePackId a : tfit->packArguments) + visitChild(a); + } + else if (const ClassType* ctv = get(ty)) { for (const auto& [name, prop] : ctv->props) visitChild(prop.type()); @@ -265,6 +245,12 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); + + if (ctv->indexer) + { + visitChild(ctv->indexer->indexType); + visitChild(ctv->indexer->indexResultType); + } } else if (const NegationType* ntv = get(ty)) { @@ -276,7 +262,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { LUAU_ASSERT(tp == log->follow(tp)); - if (ignoreChildren(tp)) + if (ignoreChildrenVisit(tp)) return; if (auto ptp = log->pending(tp)) @@ -299,17 +285,14 @@ std::pair Tarjan::indexify(TypeId ty) { ty = log->follow(ty); - bool fresh = !typeToIndex.contains(ty); - int& index = typeToIndex[ty]; + auto [index, fresh] = typeToIndex.try_insert(ty, false); if (fresh) { - index = int(indexToType.size()); - indexToType.push_back(ty); - indexToPack.push_back(nullptr); - onStack.push_back(false); - lowlink.push_back(index); + index = int(nodes.size()); + nodes.push_back({ty, nullptr, false, false, index}); } + return {index, fresh}; } @@ -317,17 +300,14 @@ std::pair Tarjan::indexify(TypePackId tp) { tp = log->follow(tp); - bool fresh = !packToIndex.contains(tp); - int& index = packToIndex[tp]; + auto [index, fresh] = packToIndex.try_insert(tp, false); if (fresh) { - index = int(indexToPack.size()); - indexToType.push_back(nullptr); - indexToPack.push_back(tp); - onStack.push_back(false); - lowlink.push_back(index); + index = int(nodes.size()); + nodes.push_back({nullptr, tp, false, false, index}); } + return {index, fresh}; } @@ -362,14 +342,15 @@ TarjanResult Tarjan::loop() return TarjanResult::TooManyChildren; stack.push_back(index); - onStack[index] = true; + + nodes[index].onStack = true; currEdge = int(edgesTy.size()); // Fill in edge list of this vertex - if (TypeId ty = indexToType[index]) + if (TypeId ty = nodes[index].ty) visitChildren(ty, index); - else if (TypePackId tp = indexToPack[index]) + else if (TypePackId tp = nodes[index].tp) visitChildren(tp, index); lastEdge = int(edgesTy.size()); @@ -400,9 +381,9 @@ TarjanResult Tarjan::loop() foundFresh = true; break; } - else if (onStack[childIndex]) + else if (nodes[childIndex].onStack) { - lowlink[index] = std::min(lowlink[index], childIndex); + nodes[index].lowlink = std::min(nodes[index].lowlink, childIndex); } visitEdge(childIndex, index); @@ -411,14 +392,14 @@ TarjanResult Tarjan::loop() if (foundFresh) continue; - if (lowlink[index] == index) + if (nodes[index].lowlink == index) { visitSCC(index); while (!stack.empty()) { int popped = stack.back(); stack.pop_back(); - onStack[popped] = false; + nodes[popped].onStack = false; if (popped == index) break; } @@ -435,7 +416,7 @@ TarjanResult Tarjan::loop() edgesTy.resize(parentEndEdge); edgesTp.resize(parentEndEdge); - lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); + nodes[parentIndex].lowlink = std::min(nodes[parentIndex].lowlink, nodes[index].lowlink); visitEdge(index, parentIndex); } } @@ -469,54 +450,67 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } -void FindDirty::clearTarjan() +void Tarjan::clearTarjan(const TxnLog* log) { - dirty.clear(); + if (FFlag::LuauReusableSubstitutions) + { + typeToIndex.clear(~0u); + packToIndex.clear(~0u); + } + else + { + typeToIndex.clear(); + packToIndex.clear(); + } - typeToIndex.clear(); - packToIndex.clear(); - indexToType.clear(); - indexToPack.clear(); + nodes.clear(); stack.clear(); - onStack.clear(); - lowlink.clear(); + + if (FFlag::LuauReusableSubstitutions) + { + childCount = 0; + // childLimit setting stays the same + + this->log = log; + } edgesTy.clear(); edgesTp.clear(); worklist.clear(); } -bool FindDirty::getDirty(int index) +bool Tarjan::getDirty(int index) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - return dirty[index]; + LUAU_ASSERT(size_t(index) < nodes.size()); + return nodes[index].dirty; } -void FindDirty::setDirty(int index, bool d) +void Tarjan::setDirty(int index, bool d) { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - dirty[index] = d; + LUAU_ASSERT(size_t(index) < nodes.size()); + nodes[index].dirty = d; } -void FindDirty::visitEdge(int index, int parentIndex) +void Tarjan::visitEdge(int index, int parentIndex) { if (getDirty(index)) setDirty(parentIndex, true); } -void FindDirty::visitSCC(int index) +void Tarjan::visitSCC(int index) { bool d = getDirty(index); for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) { - if (TypeId ty = indexToType[*it]) + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) d = isDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) d = isDirty(tp); + if (*it == index) break; } @@ -527,32 +521,52 @@ void FindDirty::visitSCC(int index) for (auto it = stack.rbegin(); it != stack.rend(); it++) { setDirty(*it, true); - if (TypeId ty = indexToType[*it]) + + TarjanNode& node = nodes[*it]; + + if (TypeId ty = node.ty) foundDirty(ty); - else if (TypePackId tp = indexToPack[*it]) + else if (TypePackId tp = node.tp) foundDirty(tp); + if (*it == index) return; } } -TarjanResult FindDirty::findDirty(TypeId ty) +TarjanResult Tarjan::findDirty(TypeId ty) { return visitRoot(ty); } -TarjanResult FindDirty::findDirty(TypePackId tp) +TarjanResult Tarjan::findDirty(TypePackId tp) { return visitRoot(tp); } +Substitution::Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) +{ + log = log_; + LUAU_ASSERT(log); +} + +void Substitution::dontTraverseInto(TypeId ty) +{ + noTraverseTypes.insert(ty); +} + +void Substitution::dontTraverseInto(TypePackId tp) +{ + noTraverseTypePacks.insert(tp); +} + std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(log); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -560,34 +574,20 @@ std::optional Substitution::substitute(TypeId ty) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) - { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy)) + if (!noTraverseTypes.contains(newTy)) replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else - { - if (!ignoreChildren(oldTp)) + if (!noTraverseTypePacks.contains(newTp)) replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypeId newTy = replace(ty); @@ -599,8 +599,7 @@ std::optional Substitution::substitute(TypePackId tp) tp = log->follow(tp); // clear algorithm state for reentrancy - if (FFlag::LuauSubstitutionReentrant) - clearTarjan(); + clearTarjan(log); auto result = findDirty(tp); if (result != TarjanResult::Ok) @@ -608,43 +607,46 @@ std::optional Substitution::substitute(TypePackId tp) for (auto [oldTy, newTy] : newTypes) { - if (FFlag::LuauSubstitutionReentrant) - { - if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) - { - replaceChildren(newTy); - replacedTypes.insert(newTy); - } - } - else + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - if (!ignoreChildren(oldTy)) + if (!noTraverseTypes.contains(newTy)) replaceChildren(newTy); + replacedTypes.insert(newTy); } } for (auto [oldTp, newTp] : newPacks) { - if (FFlag::LuauSubstitutionReentrant) - { - if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) - { - replaceChildren(newTp); - replacedTypePacks.insert(newTp); - } - } - else + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - if (!ignoreChildren(oldTp)) + if (!noTraverseTypePacks.contains(newTp)) replaceChildren(newTp); + replacedTypePacks.insert(newTp); } } TypePackId newTp = replace(tp); return newTp; } +void Substitution::resetState(const TxnLog* log, TypeArena* arena) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + clearTarjan(log); + + this->arena = arena; + + newTypes.clear(); + newPacks.clear(); + replacedTypes.clear(); + replacedTypePacks.clear(); + + noTraverseTypes.clear(); + noTraverseTypePacks.clear(); +} + TypeId Substitution::clone(TypeId ty) { - return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess2); + return shallowClone(ty, *arena, log, /* alwaysClone */ true); } TypePackId Substitution::clone(TypePackId tp) @@ -665,23 +667,27 @@ TypePackId Substitution::clone(TypePackId tp) { VariadicTypePack clone; clone.ty = vtp->ty; - if (FFlag::LuauSubstitutionFixMissingFields) - clone.hidden = vtp->hidden; + clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } - else if (FFlag::LuauClonePublicInterfaceLess2) + else if (const TypeFunctionInstanceTypePack* tfitp = get(tp)) { - return addTypePack(*tp); + TypeFunctionInstanceTypePack clone{ + tfitp->function, std::vector(tfitp->typeArguments.size()), std::vector(tfitp->packArguments.size()) + }; + clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end()); + clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end()); + return addTypePack(std::move(clone)); } else - return tp; + return addTypePack(*tp); } void Substitution::foundDirty(TypeId ty) { ty = log->follow(ty); - if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) + if (newTypes.contains(ty)) return; if (isDirty(ty)) @@ -694,7 +700,7 @@ void Substitution::foundDirty(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) + if (newPacks.contains(tp)) return; if (isDirty(tp)) @@ -735,13 +741,10 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionType* ftv = getMutable(ty)) { - if (FFlag::LuauSubstitutionFixMissingFields) - { - for (TypeId& generic : ftv->generics) - generic = replace(generic); - for (TypePackId& genericPack : ftv->genericPacks) - genericPack = replace(genericPack); - } + for (TypeId& generic : ftv->generics) + generic = replace(generic); + for (TypePackId& genericPack : ftv->genericPacks) + genericPack = replace(genericPack); ftv->argTypes = replace(ftv->argTypes); ftv->retTypes = replace(ftv->retTypes); @@ -750,7 +753,18 @@ void Substitution::replaceChildren(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) - prop.setType(replace(prop.type())); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (prop.readTy) + prop.readTy = replace(prop.readTy); + if (prop.writeTy) + prop.writeTy = replace(prop.writeTy); + } + else + prop.setType(replace(prop.type())); + } + if (ttv->indexer) { ttv->indexer->indexType = replace(ttv->indexer->indexType); @@ -786,7 +800,15 @@ void Substitution::replaceChildren(TypeId ty) for (TypePackId& a : petv->packArguments) a = replace(a); } - else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + else if (TypeFunctionInstanceType* tfit = getMutable(ty)) + { + for (TypeId& a : tfit->typeArguments) + a = replace(a); + + for (TypePackId& a : tfit->packArguments) + a = replace(a); + } + else if (ClassType* ctv = getMutable(ty)) { for (auto& [name, prop] : ctv->props) prop.setType(replace(prop.type())); @@ -796,6 +818,12 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); + + if (ctv->indexer) + { + ctv->indexer->indexType = replace(ctv->indexer->indexType); + ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType); + } } else if (NegationType* ntv = getMutable(ty)) { @@ -824,6 +852,14 @@ void Substitution::replaceChildren(TypePackId tp) { vtp->ty = replace(vtp->ty); } + else if (TypeFunctionInstanceTypePack* tfitp = getMutable(tp)) + { + for (TypeId& t : tfitp->typeArguments) + t = replace(t); + + for (TypePackId& t : tfitp->packArguments) + t = replace(t); + } } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Subtyping.cpp b/third_party/luau/Analysis/src/Subtyping.cpp new file mode 100644 index 00000000..dc63851f --- /dev/null +++ b/third_party/luau/Analysis/src/Subtyping.cpp @@ -0,0 +1,1679 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Subtyping.h" + +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/Substitution.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypePack.h" +#include "Luau/TypePath.h" +#include "Luau/TypeUtils.h" + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); + +namespace Luau +{ + +struct VarianceFlipper +{ + Subtyping::Variance* variance; + Subtyping::Variance oldValue; + + VarianceFlipper(Subtyping::Variance* v) + : variance(v) + , oldValue(*v) + { + switch (oldValue) + { + case Subtyping::Variance::Covariant: + *variance = Subtyping::Variance::Contravariant; + break; + case Subtyping::Variance::Contravariant: + *variance = Subtyping::Variance::Covariant; + break; + } + } + + ~VarianceFlipper() + { + *variance = oldValue; + } +}; + +bool SubtypingReasoning::operator==(const SubtypingReasoning& other) const +{ + return subPath == other.subPath && superPath == other.superPath && variance == other.variance; +} + +size_t SubtypingReasoningHash::operator()(const SubtypingReasoning& r) const +{ + return TypePath::PathHash()(r.subPath) ^ (TypePath::PathHash()(r.superPath) << 1) ^ (static_cast(r.variance) << 1); +} + +template +static void assertReasoningValid(TID subTy, TID superTy, const SubtypingResult& result, NotNull builtinTypes) +{ + if (!FFlag::DebugLuauSubtypingCheckPathValidity) + return; + + for (const SubtypingReasoning& reasoning : result.reasoning) + { + LUAU_ASSERT(traverse(subTy, reasoning.subPath, builtinTypes)); + LUAU_ASSERT(traverse(superTy, reasoning.superPath, builtinTypes)); + } +} + +template<> +void assertReasoningValid(TableIndexer subIdx, TableIndexer superIdx, const SubtypingResult& result, NotNull builtinTypes) +{ + // Empty method to satisfy the compiler. +} + +static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const SubtypingReasonings& b) +{ + SubtypingReasonings result{kEmptyReasoning}; + + for (const SubtypingReasoning& r : a) + { + if (r.variance == SubtypingVariance::Invariant) + result.insert(r); + else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) + { + SubtypingReasoning inverseReasoning = SubtypingReasoning{ + r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant + }; + if (b.contains(inverseReasoning)) + result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); + else + result.insert(r); + } + } + + for (const SubtypingReasoning& r : b) + { + if (r.variance == SubtypingVariance::Invariant) + result.insert(r); + else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) + { + SubtypingReasoning inverseReasoning = SubtypingReasoning{ + r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant + }; + if (a.contains(inverseReasoning)) + result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); + else + result.insert(r); + } + } + + return result; +} + +SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other) +{ + // If the other result is not a subtype, we want to join all of its + // reasonings to this one. If this result already has reasonings of its own, + // those need to be attributed here whenever this _also_ failed. + if (!other.isSubtype) + reasoning = isSubtype ? std::move(other.reasoning) : mergeReasonings(reasoning, other.reasoning); + + isSubtype &= other.isSubtype; + normalizationTooComplex |= other.normalizationTooComplex; + isCacheable &= other.isCacheable; + errors.insert(errors.end(), other.errors.begin(), other.errors.end()); + + return *this; +} + +SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other) +{ + // If this result is a subtype, we do not join the reasoning lists. If this + // result is not a subtype, but the other is a subtype, we want to _clear_ + // our reasoning list. If both results are not subtypes, we join the + // reasoning lists. + if (!isSubtype) + { + if (other.isSubtype) + reasoning.clear(); + else + reasoning = mergeReasonings(reasoning, other.reasoning); + } + + isSubtype |= other.isSubtype; + normalizationTooComplex |= other.normalizationTooComplex; + isCacheable &= other.isCacheable; + errors.insert(errors.end(), other.errors.begin(), other.errors.end()); + + return *this; +} + +SubtypingResult& SubtypingResult::withBothComponent(TypePath::Component component) +{ + return withSubComponent(component).withSuperComponent(component); +} + +SubtypingResult& SubtypingResult::withSubComponent(TypePath::Component component) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{Path(component), TypePath::kEmpty}); + else + { + for (auto& r : reasoning) + r.subPath = r.subPath.push_front(component); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withSuperComponent(TypePath::Component component) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{TypePath::kEmpty, Path(component)}); + else + { + for (auto& r : reasoning) + r.superPath = r.superPath.push_front(component); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withBothPath(TypePath::Path path) +{ + return withSubPath(path).withSuperPath(path); +} + +SubtypingResult& SubtypingResult::withSubPath(TypePath::Path path) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{path, TypePath::kEmpty}); + else + { + for (auto& r : reasoning) + r.subPath = path.append(r.subPath); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withSuperPath(TypePath::Path path) +{ + if (reasoning.empty()) + reasoning.insert(SubtypingReasoning{TypePath::kEmpty, path}); + else + { + for (auto& r : reasoning) + r.superPath = path.append(r.superPath); + } + + return *this; +} + +SubtypingResult& SubtypingResult::withErrors(ErrorVec& err) +{ + for (TypeError& e : err) + errors.emplace_back(e); + return *this; +} + +SubtypingResult& SubtypingResult::withError(TypeError err) +{ + errors.push_back(std::move(err)); + return *this; +} + +SubtypingResult SubtypingResult::negate(const SubtypingResult& result) +{ + return SubtypingResult{ + !result.isSubtype, + result.normalizationTooComplex, + }; +} + +SubtypingResult SubtypingResult::all(const std::vector& results) +{ + SubtypingResult acc{true}; + for (const SubtypingResult& current : results) + acc.andAlso(current); + return acc; +} + +SubtypingResult SubtypingResult::any(const std::vector& results) +{ + SubtypingResult acc{false}; + for (const SubtypingResult& current : results) + acc.orElse(current); + return acc; +} + +struct ApplyMappedGenerics : Substitution +{ + using MappedGenerics = DenseHashMap; + using MappedGenericPacks = DenseHashMap; + + NotNull builtinTypes; + NotNull arena; + + MappedGenerics& mappedGenerics; + MappedGenericPacks& mappedGenericPacks; + + + ApplyMappedGenerics( + NotNull builtinTypes, + NotNull arena, + MappedGenerics& mappedGenerics, + MappedGenericPacks& mappedGenericPacks + ) + : Substitution(TxnLog::empty(), arena) + , builtinTypes(builtinTypes) + , arena(arena) + , mappedGenerics(mappedGenerics) + , mappedGenericPacks(mappedGenericPacks) + { + } + + bool isDirty(TypeId ty) override + { + return mappedGenerics.contains(ty); + } + + bool isDirty(TypePackId tp) override + { + return mappedGenericPacks.contains(tp); + } + + TypeId clean(TypeId ty) override + { + const auto& bounds = mappedGenerics[ty]; + + if (bounds.upperBound.empty()) + return builtinTypes->unknownType; + + if (bounds.upperBound.size() == 1) + return *begin(bounds.upperBound); + + return arena->addType(IntersectionType{std::vector(begin(bounds.upperBound), end(bounds.upperBound))}); + } + + TypePackId clean(TypePackId tp) override + { + return mappedGenericPacks[tp]; + } + + bool ignoreChildren(TypeId ty) override + { + if (get(ty)) + return true; + + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +std::optional SubtypingEnvironment::applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty) +{ + ApplyMappedGenerics amg{builtinTypes, arena, mappedGenerics, mappedGenericPacks}; + return amg.substitute(ty); +} + +Subtyping::Subtyping( + NotNull builtinTypes, + NotNull typeArena, + NotNull normalizer, + NotNull iceReporter, + NotNull scope +) + : builtinTypes(builtinTypes) + , arena(typeArena) + , normalizer(normalizer) + , iceReporter(iceReporter) + , scope(scope) +{ +} + +SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) +{ + SubtypingEnvironment env; + + SubtypingResult result = isCovariantWith(env, subTy, superTy); + + for (const auto& [subTy, bounds] : env.mappedGenerics) + { + const auto& lb = bounds.lowerBound; + const auto& ub = bounds.upperBound; + + TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); + TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); + + std::shared_ptr nt = normalizer->normalize(upperBound); + // we say that the result is true if normalization failed because complex types are likely to be inhabited. + NormalizationResult res = nt ? normalizer->isInhabited(nt.get()) : NormalizationResult::True; + + if (!nt || res == NormalizationResult::HitLimits) + result.normalizationTooComplex = true; + else if (res == NormalizationResult::False) + { + /* If the normalized upper bound we're mapping to a generic is + * uninhabited, then we must consider the subtyping relation not to + * hold. + * + * This happens eg in () -> (T, T) <: () -> (string, number) + * + * T appears in covariant position and would have to be both string + * and number at once. + * + * No actual value is both a string and a number, so the test fails. + * + * TODO: We'll need to add explanitory context here. + */ + result.isSubtype = false; + } + + SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound); + boundsResult.reasoning.clear(); + + result.andAlso(boundsResult); + } + + /* TODO: We presently don't store subtype test results in the persistent + * cache if the left-side type is a generic function. + * + * The implementation would be a bit tricky and we haven't seen any material + * impact on benchmarks. + * + * What we would want to do is to remember points within the type where + * mapped generics are introduced. When all the contingent generics are + * introduced at which we're doing the test, we can mark the result as + * cacheable. + */ + + if (result.isCacheable) + resultCache[{subTy, superTy}] = result; + + return result; +} + +SubtypingResult Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) +{ + SubtypingEnvironment env; + return isCovariantWith(env, subTp, superTp); +} + +SubtypingResult Subtyping::cache(SubtypingEnvironment& env, SubtypingResult result, TypeId subTy, TypeId superTy) +{ + const std::pair p{subTy, superTy}; + if (result.isCacheable) + resultCache[p] = result; + else + env.ephemeralCache[p] = result; + + return result; +} + +namespace +{ +struct SeenSetPopper +{ + Subtyping::SeenSet* seenTypes; + std::pair pair; + + SeenSetPopper(Subtyping::SeenSet* seenTypes, std::pair pair) + : seenTypes(seenTypes) + , pair(pair) + { + } + + ~SeenSetPopper() + { + seenTypes->erase(pair); + } +}; +} // namespace + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy) +{ + subTy = follow(subTy); + superTy = follow(superTy); + + if (TypeId* subIt = env.substitutions.find(subTy); subIt && *subIt) + subTy = *subIt; + + if (TypeId* superIt = env.substitutions.find(superTy); superIt && *superIt) + superTy = *superIt; + + SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); + if (cachedResult) + return *cachedResult; + + cachedResult = env.ephemeralCache.find({subTy, superTy}); + if (cachedResult) + return *cachedResult; + + // TODO: Do we care about returning a proof that this is error-suppressing? + // e.g. given `a | error <: a | error` where both operands are pointer equal, + // then should it also carry the information that it's error-suppressing? + // If it should, then `error <: error` should also do the same. + if (subTy == superTy) + return {true}; + + std::pair typePair{subTy, superTy}; + if (!seenTypes.insert(typePair)) + { + /* TODO: Caching results for recursive types is really tricky to think + * about. + * + * We'd like to cache at the outermost level where we encounter the + * recursive type, but we do not want to cache interior results that + * involve the cycle. + * + * Presently, we stop at cycles and assume that the subtype check will + * succeed because we'll eventually get there if it won't. However, if + * that cyclic type turns out not to have the asked-for subtyping + * relation, then all the intermediate cached results that were + * contingent on that assumption need to be evicted from the cache, or + * not entered into the cache, or something. + * + * For now, we do the conservative thing and refuse to cache anything + * that touches a cycle. + */ + SubtypingResult res; + res.isSubtype = true; + res.isCacheable = false; + return res; + } + + SeenSetPopper ssp{&seenTypes, typePair}; + + // Within the scope to which a generic belongs, that generic should be + // tested as though it were its upper bounds. We do not yet support bounded + // generics, so the upper bound is always unknown. + if (auto subGeneric = get(subTy); subGeneric && subsumes(subGeneric->scope, scope)) + return isCovariantWith(env, builtinTypes->neverType, superTy); + if (auto superGeneric = get(superTy); superGeneric && subsumes(superGeneric->scope, scope)) + return isCovariantWith(env, subTy, builtinTypes->unknownType); + + SubtypingResult result; + + if (auto subUnion = get(subTy)) + result = isCovariantWith(env, subUnion, superTy); + else if (auto superUnion = get(superTy)) + { + result = isCovariantWith(env, subTy, superUnion); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy)); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto superIntersection = get(superTy)) + result = isCovariantWith(env, subTy, superIntersection); + else if (auto subIntersection = get(subTy)) + { + result = isCovariantWith(env, subIntersection, superTy); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy)); + if (semantic.isSubtype) + { + // Clear the semantic reasoning, as any reasonings within + // potentially contain invalid paths. + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (get(superTy)) + result = {true}; + + // We have added this as an exception - the set of inhabitants of any is exactly the set of inhabitants of unknown (since error has no + // inhabitants). any = err | unknown, so under semantic subtyping, {} U unknown = unknown + else if (get(subTy) && get(superTy)) + result = {true}; + else if (get(subTy)) + { + // any = unknown | error, so we rewrite this to match. + // As per TAPL: A | B <: T iff A <: T && B <: T + result = isCovariantWith(env, builtinTypes->unknownType, superTy).andAlso(isCovariantWith(env, builtinTypes->errorType, superTy)); + } + else if (get(superTy)) + { + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + + bool errorSuppressing = get(subTy); + result = {!errorSuppressing}; + } + else if (get(subTy)) + result = {true}; + else if (get(superTy)) + result = {false}; + else if (get(subTy)) + result = {false}; + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p.first->ty, p.second->ty).withBothComponent(TypePath::TypeField::Negated); + else if (auto subNegation = get(subTy)) + { + result = isCovariantWith(env, subNegation, superTy); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy)); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto superNegation = get(superTy)) + { + result = isCovariantWith(env, subTy, superNegation); + if (!result.isSubtype && !result.normalizationTooComplex) + { + SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy)); + if (semantic.isSubtype) + { + semantic.reasoning.clear(); + result = semantic; + } + } + } + else if (auto subTypeFunctionInstance = get(subTy)) + { + if (auto substSubTy = env.applyMappedGenerics(builtinTypes, arena, subTy)) + subTypeFunctionInstance = get(*substSubTy); + + result = isCovariantWith(env, subTypeFunctionInstance, superTy); + } + else if (auto superTypeFunctionInstance = get(superTy)) + { + if (auto substSuperTy = env.applyMappedGenerics(builtinTypes, arena, superTy)) + superTypeFunctionInstance = get(*substSuperTy); + + result = isCovariantWith(env, subTy, superTypeFunctionInstance); + } + else if (auto subGeneric = get(subTy); subGeneric && variance == Variance::Covariant) + { + bool ok = bindGeneric(env, subTy, superTy); + result.isSubtype = ok; + result.isCacheable = false; + } + else if (auto superGeneric = get(superTy); superGeneric && variance == Variance::Contravariant) + { + bool ok = bindGeneric(env, subTy, superTy); + result.isSubtype = ok; + result.isCacheable = false; + } + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, subTy, p.first, superTy, p.second); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + else if (auto p = get2(subTy, superTy)) + result = isCovariantWith(env, p); + + assertReasoningValid(subTy, superTy, result, builtinTypes); + + return cache(env, result, subTy, superTy); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + auto [subHead, subTail] = flatten(subTp); + auto [superHead, superTail] = flatten(superTp); + + const size_t headSize = std::min(subHead.size(), superHead.size()); + + std::vector results; + results.reserve(std::max(subHead.size(), superHead.size()) + 1); + + if (subTp == superTp) + return {true}; + + // Match head types pairwise + + for (size_t i = 0; i < headSize; ++i) + results.push_back(isCovariantWith(env, subHead[i], superHead[i]).withBothComponent(TypePath::Index{i})); + + // Handle mismatched head sizes + + if (subHead.size() < superHead.size()) + { + if (subTail) + { + if (auto vt = get(*subTail)) + { + for (size_t i = headSize; i < superHead.size(); ++i) + results.push_back(isCovariantWith(env, vt->ty, superHead[i]) + .withSubPath(TypePath::PathBuilder().tail().variadic().build()) + .withSuperComponent(TypePath::Index{i})); + } + else if (auto gt = get(*subTail)) + { + if (variance == Variance::Covariant) + { + // For any non-generic type T: + // + // (X) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(superHead), begin(superHead) + headSize); + TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); + + if (TypePackId* other = env.mappedGenericPacks.find(*subTail)) + // TODO: TypePath can't express "slice of a pack + its tail". + results.push_back(isCovariantWith(env, *other, superTailPack).withSubComponent(TypePath::PackField::Tail)); + else + env.mappedGenericPacks.try_insert(*subTail, superTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // (T) -> () (X) -> () + // + return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail); + } + } + else if (get(*subTail)) + return SubtypingResult{true}.withSubComponent(TypePath::PackField::Tail); + else + return SubtypingResult{false} + .withSubComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else + { + results.push_back({false}); + return SubtypingResult::all(results); + } + } + else if (subHead.size() > superHead.size()) + { + if (superTail) + { + if (auto vt = get(*superTail)) + { + for (size_t i = headSize; i < subHead.size(); ++i) + results.push_back(isCovariantWith(env, subHead[i], vt->ty) + .withSubComponent(TypePath::Index{i}) + .withSuperPath(TypePath::PathBuilder().tail().variadic().build())); + } + else if (auto gt = get(*superTail)) + { + if (variance == Variance::Contravariant) + { + // For any non-generic type T: + // + // (X...) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(subHead), begin(subHead) + headSize); + TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); + + if (TypePackId* other = env.mappedGenericPacks.find(*superTail)) + // TODO: TypePath can't express "slice of a pack + its tail". + results.push_back(isContravariantWith(env, subTailPack, *other).withSuperComponent(TypePath::PackField::Tail)); + else + env.mappedGenericPacks.try_insert(*superTail, subTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // () -> T () -> X... + return SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail); + } + } + else if (get(*superTail)) + return SubtypingResult{true}.withSuperComponent(TypePath::PackField::Tail); + else + return SubtypingResult{false} + .withSuperComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else + return {false}; + } + + // Handle tails + + if (subTail && superTail) + { + if (auto p = get2(*subTail, *superTail)) + { + // Variadic component is added by the isCovariantWith + // implementation; no need to add it here. + results.push_back(isCovariantWith(env, p).withBothComponent(TypePath::PackField::Tail)); + } + else if (auto p = get2(*subTail, *superTail)) + { + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + else if (auto p = get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (A...) -> number <: (...number) -> number + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + else + { + // (number) -> ...number (number) -> A... + results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail)); + } + } + else if (auto p = get2(*subTail, *superTail)) + { + if (TypeId t = follow(p.second->ty); get(t) || get(t)) + { + // Extra magic rule: + // T... <: ...any + // T... <: ...unknown + // + // See https://github.com/luau-lang/luau/issues/767 + } + else if (variance == Variance::Contravariant) + { + // (...number) -> number (A...) -> number + results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail)); + } + else + { + // () -> A... <: () -> ...number + bool ok = bindGeneric(env, *subTail, *superTail); + results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); + } + } + else if (get(*subTail) || get(*superTail)) + // error type is fine on either side + results.push_back(SubtypingResult{true}.withBothComponent(TypePath::PackField::Tail)); + else + return SubtypingResult{false} + .withBothComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}) + .withError({scope->location, UnexpectedTypePackInSubtyping{*superTail}}); + } + else if (subTail) + { + if (get(*subTail)) + { + return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail); + } + else if (get(*subTail)) + { + bool ok = bindGeneric(env, *subTail, builtinTypes->emptyTypePack); + return SubtypingResult{ok}.withSubComponent(TypePath::PackField::Tail); + } + else + return SubtypingResult{false} + .withSubComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + } + else if (superTail) + { + if (get(*superTail)) + { + /* + * A variadic type pack ...T can be thought of as an infinite union of finite type packs. + * () | (T) | (T, T) | (T, T, T) | ... + * + * And, per TAPL: + * T <: A | B iff T <: A or T <: B + * + * All variadic type packs are therefore supertypes of the empty type pack. + */ + } + else if (get(*superTail)) + { + if (variance == Variance::Contravariant) + { + bool ok = bindGeneric(env, builtinTypes->emptyTypePack, *superTail); + results.push_back(SubtypingResult{ok}.withSuperComponent(TypePath::PackField::Tail)); + } + else + results.push_back(SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail)); + } + else + return SubtypingResult{false} + .withSuperComponent(TypePath::PackField::Tail) + .withError({scope->location, UnexpectedTypePackInSubtyping{*superTail}}); + } + + SubtypingResult result = SubtypingResult::all(results); + assertReasoningValid(subTp, superTp, result, builtinTypes); + + return result; +} + +template +SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy) +{ + VarianceFlipper vf{&variance}; + + SubtypingResult result = isCovariantWith(env, superTy, subTy); + if (result.reasoning.empty()) + result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Contravariant}); + else + { + // If we don't swap the paths here, we will end up producing an invalid path + // whenever we involve contravariance. We'll end up appending path + // components that should belong to the supertype to the subtype, and vice + // versa. + for (auto& reasoning : result.reasoning) + { + std::swap(reasoning.subPath, reasoning.superPath); + + // Also swap covariant/contravariant, since those are also the other way + // around. + if (reasoning.variance == SubtypingVariance::Covariant) + reasoning.variance = SubtypingVariance::Contravariant; + else if (reasoning.variance == SubtypingVariance::Contravariant) + reasoning.variance = SubtypingVariance::Covariant; + } + } + + assertReasoningValid(subTy, superTy, result, builtinTypes); + + return result; +} + +template +SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy) +{ + SubtypingResult result = isCovariantWith(env, subTy, superTy).andAlso(isContravariantWith(env, subTy, superTy)); + + if (result.reasoning.empty()) + result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invariant}); + else + { + for (auto& reasoning : result.reasoning) + reasoning.variance = SubtypingVariance::Invariant; + } + + assertReasoningValid(subTy, superTy, result, builtinTypes); + return result; +} + +template +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TryPair& pair) +{ + return isCovariantWith(env, pair.first, pair.second); +} + +template +SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, const TryPair& pair) +{ + return isContravariantWith(env, pair.first, pair.second); +} + +template +SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, const TryPair& pair) +{ + return isInvariantWith(env, pair.first, pair.second); +} + +/* + * This is much simpler than the Unifier implementation because we don't + * actually care about potential "cross-talk" between union parts that match the + * left side. + * + * In fact, we're very limited in what we can do: If multiple choices match, but + * all of them have non-overlapping constraints, then we're stuck with an "or" + * conjunction of constraints. Solving this in the general case is quite + * difficult. + * + * For example, we cannot dispatch anything from this constraint: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: 'b, y: string} + * + * From this constraint, we can know that either string <: 'a or number <: 'b, + * but we don't know which! + * + * However: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: number, y: string} + * + * We can dispatch this constraint because there is no 'or' conjunction. One of + * the arms requires 0 matches. + * + * {x: number, y: string, z: boolean} | {x: number, y: 'a, z: 'b} | {x: number, + * y: string, z: 'b} + * + * Here, we have two matches. One asks for string ~ 'a and boolean ~ 'b. The + * other just asks for boolean ~ 'b. We can dispatch this and only commit + * boolean ~ 'b. This constraint does not teach us anything about 'a. + */ +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const UnionType* superUnion) +{ + // As per TAPL: T <: A | B iff T <: A || T <: B + + for (TypeId ty : superUnion) + { + SubtypingResult next = isCovariantWith(env, subTy, ty); + if (next.isSubtype) + return SubtypingResult{true}; + } + + /* + * TODO: Is it possible here to use the context produced by the above + * isCovariantWith() calls to produce a richer, more helpful result in the + * case that the subtyping relation does not hold? + */ + return SubtypingResult{false}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const UnionType* subUnion, TypeId superTy) +{ + // As per TAPL: A | B <: T iff A <: T && B <: T + std::vector subtypings; + size_t i = 0; + for (TypeId ty : subUnion) + subtypings.push_back(isCovariantWith(env, ty, superTy).withSubComponent(TypePath::Index{i++})); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const IntersectionType* superIntersection) +{ + // As per TAPL: T <: A & B iff T <: A && T <: B + std::vector subtypings; + size_t i = 0; + for (TypeId ty : superIntersection) + subtypings.push_back(isCovariantWith(env, subTy, ty).withSuperComponent(TypePath::Index{i++})); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const IntersectionType* subIntersection, TypeId superTy) +{ + // As per TAPL: A & B <: T iff A <: T || B <: T + std::vector subtypings; + size_t i = 0; + for (TypeId ty : subIntersection) + subtypings.push_back(isCovariantWith(env, ty, superTy).withSubComponent(TypePath::Index{i++})); + return SubtypingResult::any(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NegationType* subNegation, TypeId superTy) +{ + TypeId negatedTy = follow(subNegation->ty); + + SubtypingResult result; + + // In order to follow a consistent codepath, rather than folding the + // isCovariantWith test down to its conclusion here, we test the subtyping test + // of the result of negating the type for never, unknown, any, and error. + if (is(negatedTy)) + { + // ¬never ~ unknown + result = isCovariantWith(env, builtinTypes->unknownType, superTy).withSubComponent(TypePath::TypeField::Negated); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + result = isCovariantWith(env, builtinTypes->neverType, superTy).withSubComponent(TypePath::TypeField::Negated); + } + else if (is(negatedTy)) + { + // ¬any ~ any + result = isCovariantWith(env, negatedTy, superTy).withSubComponent(TypePath::TypeField::Negated); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy).withSubComponent(TypePath::TypeField::Negated)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy)); + } + } + + result = SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy).withSubComponent(TypePath::TypeField::Negated)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy)); + } + } + + result = SubtypingResult::any(subtypings); + } + else if (is(negatedTy)) + { + iceReporter->ice("attempting to negate a non-testable type"); + } + // negating a different subtype will get you a very wide type that's not a + // subtype of other stuff. + else + { + result = SubtypingResult{false}.withSubComponent(TypePath::TypeField::Negated); + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation) +{ + TypeId negatedTy = follow(superNegation->ty); + + SubtypingResult result; + + if (is(negatedTy)) + { + // ¬never ~ unknown + result = isCovariantWith(env, subTy, builtinTypes->unknownType); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + result = isCovariantWith(env, subTy, builtinTypes->neverType); + } + else if (is(negatedTy)) + { + // ¬any ~ any + result = isSubtype(subTy, negatedTy); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, subTy, &negatedTmp)); + } + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(env, subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(env, subTy, &negatedTmp)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (auto p = get2(subTy, negatedTy)) + { + // number <: ¬boolean + // number type != p.second->type}; + } + else if (auto p = get2(subTy, negatedTy)) + { + // "foo" (p.first) && p.second->type == PrimitiveType::String) + result = {false}; + // false (p.first) && p.second->type == PrimitiveType::Boolean) + result = {false}; + // other cases are true + else + result = {true}; + } + else if (auto p = get2(subTy, negatedTy)) + { + if (p.first->type == PrimitiveType::String && get(p.second)) + result = {false}; + else if (p.first->type == PrimitiveType::Boolean && get(p.second)) + result = {false}; + else + result = {true}; + } + // the top class type is not actually a primitive type, so the negation of + // any one of them includes the top class type. + else if (auto p = get2(subTy, negatedTy)) + result = {true}; + else if (auto p = get(negatedTy); p && is(subTy)) + result = {p->type != PrimitiveType::Table}; + else if (auto p = get2(subTy, negatedTy)) + result = {p.second->type != PrimitiveType::Function}; + else if (auto p = get2(subTy, negatedTy)) + result = {*p.first != *p.second}; + else if (auto p = get2(subTy, negatedTy)) + result = SubtypingResult::negate(isCovariantWith(env, p.first, p.second)); + else if (get2(subTy, negatedTy)) + result = {true}; + else if (is(negatedTy)) + iceReporter->ice("attempting to negate a non-testable type"); + else + result = {false}; + + return result.withSuperComponent(TypePath::TypeField::Negated); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim) +{ + return {subPrim->type == superPrim->type}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const PrimitiveType* superPrim) +{ + if (get(subSingleton) && superPrim->type == PrimitiveType::String) + return {true}; + else if (get(subSingleton) && superPrim->type == PrimitiveType::Boolean) + return {true}; + else + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const SingletonType* superSingleton) +{ + return {*subSingleton == *superSingleton}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const TableType* superTable) +{ + SubtypingResult result{true}; + + if (subTable->props.empty() && !subTable->indexer && superTable->indexer) + return {false}; + + for (const auto& [name, superProp] : superTable->props) + { + std::vector results; + if (auto subIter = subTable->props.find(name); subIter != subTable->props.end()) + results.push_back(isCovariantWith(env, subIter->second, superProp, name)); + else if (subTable->indexer) + { + if (isCovariantWith(env, builtinTypes->stringType, subTable->indexer->indexType).isSubtype) + { + if (superProp.isShared()) + results.push_back(isInvariantWith(env, subTable->indexer->indexResultType, superProp.type()) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::read(name))); + else + { + if (superProp.readTy) + results.push_back(isCovariantWith(env, subTable->indexer->indexResultType, *superProp.readTy) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::read(name))); + if (superProp.writeTy) + results.push_back(isContravariantWith(env, subTable->indexer->indexResultType, *superProp.writeTy) + .withSubComponent(TypePath::TypeField::IndexResult) + .withSuperComponent(TypePath::Property::write(name))); + } + } + } + + if (results.empty()) + return SubtypingResult{false}; + + result.andAlso(SubtypingResult::all(results)); + } + + if (superTable->indexer) + { + if (subTable->indexer) + result.andAlso(isInvariantWith(env, *subTable->indexer, *superTable->indexer)); + else + return {false}; + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt) +{ + return isCovariantWith(env, subMt->table, superMt->table) + .andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable).withBothComponent(TypePath::TypeField::Metatable)); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable) +{ + if (auto subTable = get(follow(subMt->table))) + { + // Metatables cannot erase properties from the table they're attached to, so + // the subtyping rule for this is just if the table component is a subtype + // of the supertype table. + // + // There's a flaw here in that if the __index metamethod contributes a new + // field that would satisfy the subtyping relationship, we'll erronously say + // that the metatable isn't a subtype of the table, even though they have + // compatible properties/shapes. We'll revisit this later when we have a + // better understanding of how important this is. + return isCovariantWith(env, subTable, superTable); + } + else + { + // TODO: This may be a case we actually hit? + return {false}; + } +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const ClassType* superClass) +{ + return {isSubclass(subClass, superClass)}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + TypeId subTy, + const ClassType* subClass, + TypeId superTy, + const TableType* superTable +) +{ + SubtypingResult result{true}; + + env.substitutions[superTy] = subTy; + + for (const auto& [name, prop] : superTable->props) + { + if (auto classProp = lookupClassProp(subClass, name)) + { + result.andAlso(isCovariantWith(env, *classProp, prop, name)); + } + else + { + result = {false}; + break; + } + } + + env.substitutions[superTy] = nullptr; + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const FunctionType* subFunction, const FunctionType* superFunction) +{ + SubtypingResult result; + { + result.orElse(isContravariantWith(env, subFunction->argTypes, superFunction->argTypes).withBothComponent(TypePath::PackField::Arguments)); + } + + result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes).withBothComponent(TypePath::PackField::Returns)); + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const TableType* superTable) +{ + SubtypingResult result{false}; + if (subPrim->type == PrimitiveType::String) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse( + isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()) + ); + } + } + } + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const TableType* superTable) +{ + SubtypingResult result{false}; + if (auto stringleton = get(subSingleton)) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse( + isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()) + ); + } + } + } + } + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer) +{ + return isInvariantWith(env, subIndexer.indexType, superIndexer.indexType) + .withBothComponent(TypePath::TypeField::IndexLookup) + .andAlso(isInvariantWith(env, subIndexer.indexResultType, superIndexer.indexResultType).withBothComponent(TypePath::TypeField::IndexResult)); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Property& subProp, const Property& superProp, const std::string& name) +{ + SubtypingResult res{true}; + + if (superProp.isShared() && subProp.isShared()) + res.andAlso(isInvariantWith(env, subProp.type(), superProp.type()).withBothComponent(TypePath::Property::read(name))); + else + { + if (superProp.readTy.has_value() && subProp.readTy.has_value()) + res.andAlso(isCovariantWith(env, *subProp.readTy, *superProp.readTy).withBothComponent(TypePath::Property::read(name))); + if (superProp.writeTy.has_value() && subProp.writeTy.has_value()) + res.andAlso(isContravariantWith(env, *subProp.writeTy, *superProp.writeTy).withBothComponent(TypePath::Property::write(name))); + + if (superProp.isReadWrite()) + { + if (subProp.isReadOnly()) + res.andAlso(SubtypingResult{false}.withBothComponent(TypePath::Property::read(name))); + else if (subProp.isWriteOnly()) + res.andAlso(SubtypingResult{false}.withBothComponent(TypePath::Property::write(name))); + } + } + + return res; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const std::shared_ptr& subNorm, + const std::shared_ptr& superNorm +) +{ + if (!subNorm || !superNorm) + return {false, true}; + + SubtypingResult result = isCovariantWith(env, subNorm->tops, superNorm->tops); + result.andAlso(isCovariantWith(env, subNorm->booleans, superNorm->booleans)); + result.andAlso(isCovariantWith(env, subNorm->classes, superNorm->classes).orElse(isCovariantWith(env, subNorm->classes, superNorm->tables))); + result.andAlso(isCovariantWith(env, subNorm->errors, superNorm->errors)); + result.andAlso(isCovariantWith(env, subNorm->nils, superNorm->nils)); + result.andAlso(isCovariantWith(env, subNorm->numbers, superNorm->numbers)); + result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings)); + result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables)); + result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads)); + result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers)); + result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables)); + result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions)); + // isCovariantWith(subNorm->tyvars, superNorm->tyvars); + return result; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (const auto& [superClassTy, superNegations] : superClass.classes) + { + result.orElse(isCovariantWith(env, subClassTy, superClassTy)); + if (!result.isSubtype) + continue; + + for (TypeId negation : superNegations) + { + result.andAlso(SubtypingResult::negate(isCovariantWith(env, subClassTy, negation))); + if (result.isSubtype) + break; + } + } + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (TypeId superTableTy : superTables) + result.orElse(isCovariantWith(env, subClassTy, superTableTy)); + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString) +{ + bool isSubtype = Luau::isSubtype(subString, superString); + return {isSubtype}; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables) +{ + if (subString.isNever()) + return {true}; + + if (subString.isCofinite) + { + SubtypingResult result; + for (const auto& superTable : superTables) + { + result.orElse(isCovariantWith(env, builtinTypes->stringType, superTable)); + if (result.isSubtype) + return result; + } + return result; + } + + // Finite case + // S = s1 | s2 | s3 ... sn <: t1 | t2 | ... | tn + // iff for some ti, S <: ti + // iff for all sj, sj <: ti + for (const auto& superTable : superTables) + { + SubtypingResult result{true}; + for (const auto& [_, subString] : subString.singletons) + { + result.andAlso(isCovariantWith(env, subString, superTable)); + if (!result.isSubtype) + break; + } + + if (!result.isSubtype) + continue; + else + return result; + } + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith( + SubtypingEnvironment& env, + const NormalizedFunctionType& subFunction, + const NormalizedFunctionType& superFunction +) +{ + if (subFunction.isNever()) + return {true}; + else if (superFunction.isTop) + return {true}; + else + return isCovariantWith(env, subFunction.parts, superFunction.parts); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes) +{ + std::vector results; + + for (TypeId subTy : subTypes) + { + results.emplace_back(); + for (TypeId superTy : superTypes) + results.back().orElse(isCovariantWith(env, subTy, superTy)); + } + + return SubtypingResult::all(results); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic) +{ + return isCovariantWith(env, subVariadic->ty, superVariadic->ty).withBothComponent(TypePath::TypeField::Variadic); +} + +bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId superTy) +{ + if (variance == Variance::Covariant) + { + if (!get(subTy)) + return false; + + env.mappedGenerics[subTy].upperBound.insert(superTy); + } + else + { + if (!get(superTy)) + return false; + + env.mappedGenerics[superTy].lowerBound.insert(subTy); + } + + return true; +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeFunctionInstanceType* subFunctionInstance, const TypeId superTy) +{ + // Reduce the type function instance + auto [ty, errors] = handleTypeFunctionReductionResult(subFunctionInstance); + + // If we return optional, that means the type function was irreducible - we can reduce that to never + return isCovariantWith(env, ty, superTy).withErrors(errors).withSubComponent(TypePath::Reduction{ty}); +} + +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const TypeFunctionInstanceType* superFunctionInstance) +{ + // Reduce the type function instance + auto [ty, errors] = handleTypeFunctionReductionResult(superFunctionInstance); + return isCovariantWith(env, subTy, ty).withErrors(errors).withSuperComponent(TypePath::Reduction{ty}); +} + +/* + * If, when performing a subtyping test, we encounter a generic on the left + * side, it is permissible to tentatively bind that generic to the right side + * type. + */ +bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp) +{ + if (variance == Variance::Contravariant) + std::swap(superTp, subTp); + + if (!get(subTp)) + return false; + + if (TypePackId* m = env.mappedGenericPacks.find(subTp)) + return *m == superTp; + + env.mappedGenericPacks[subTp] = superTp; + + return true; +} + +template +TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) +{ + if (container.empty()) + return orElse; + else if (container.size() == 1) + return *begin(container); + else + return arena->addType(T{std::vector(begin(container), end(container))}); +} + +std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance) +{ + TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; + TypeId function = arena->addType(*functionInstance); + FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); + ErrorVec errors; + if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0) + { + errors.push_back(TypeError{{}, UninhabitedTypeFunction{function}}); + return {builtinTypes->neverType, errors}; + } + if (result.reducedTypes.contains(function)) + return {function, errors}; + return {builtinTypes->neverType, errors}; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/Symbol.cpp b/third_party/luau/Analysis/src/Symbol.cpp index 5922bb50..4b808f19 100644 --- a/third_party/luau/Analysis/src/Symbol.cpp +++ b/third_party/luau/Analysis/src/Symbol.cpp @@ -3,9 +3,23 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { +bool Symbol::operator==(const Symbol& rhs) const +{ + if (local) + return local == rhs.local; + else if (global.value) + return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; +} + std::string toString(const Symbol& name) { if (name.local) diff --git a/third_party/luau/Analysis/src/TableLiteralInference.cpp b/third_party/luau/Analysis/src/TableLiteralInference.cpp new file mode 100644 index 00000000..630cb441 --- /dev/null +++ b/third_party/luau/Analysis/src/TableLiteralInference.cpp @@ -0,0 +1,455 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Ast.h" +#include "Luau/Normalize.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/Unifier2.h" + +namespace Luau +{ + +static bool isLiteral(const AstExpr* expr) +{ + return ( + expr->is() || expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() + ); +} + +// A fast approximation of subTy <: superTy +static bool fastIsSubtype(TypeId subTy, TypeId superTy) +{ + Relation r = relate(superTy, subTy); + return r == Relation::Coincident || r == Relation::Superset; +} + +static bool isRecord(const AstExprTable::Item& item) +{ + if (item.kind == AstExprTable::Item::Record) + return true; + else if (item.kind == AstExprTable::Item::General && item.key->is()) + return true; + else + return false; +} + +static std::optional extractMatchingTableType(std::vector& tables, TypeId exprType, NotNull builtinTypes) +{ + if (tables.empty()) + return std::nullopt; + + const TableType* exprTable = get(follow(exprType)); + if (!exprTable) + return std::nullopt; + + size_t tableCount = 0; + std::optional firstTable; + + for (TypeId ty : tables) + { + ty = follow(ty); + if (auto tt = get(ty)) + { + // If the expected table has a key whose type is a string or boolean + // singleton and the corresponding exprType property does not match, + // then skip this table. + + if (!firstTable) + firstTable = ty; + ++tableCount; + + for (const auto& [name, expectedProp] : tt->props) + { + if (!expectedProp.readTy) + continue; + + const TypeId expectedType = follow(*expectedProp.readTy); + + auto st = get(expectedType); + if (!st) + continue; + + auto it = exprTable->props.find(name); + if (it == exprTable->props.end()) + continue; + + const auto& [_name, exprProp] = *it; + + if (!exprProp.readTy) + continue; + + const TypeId propType = follow(*exprProp.readTy); + + const FreeType* ft = get(propType); + + if (ft && get(ft->lowerBound)) + { + if (fastIsSubtype(builtinTypes->booleanType, ft->upperBound) && fastIsSubtype(expectedType, builtinTypes->booleanType)) + { + return ty; + } + + if (fastIsSubtype(builtinTypes->stringType, ft->upperBound) && fastIsSubtype(expectedType, ft->lowerBound)) + { + return ty; + } + } + } + } + } + + if (tableCount == 1) + { + LUAU_ASSERT(firstTable); + return firstTable; + } + + return std::nullopt; +} + +TypeId matchLiteralType( + NotNull> astTypes, + NotNull> astExpectedTypes, + NotNull builtinTypes, + NotNull arena, + NotNull unifier, + TypeId expectedType, + TypeId exprType, + const AstExpr* expr, + std::vector& toBlock +) +{ + /* + * Table types that arise from literal table expressions have some + * properties that make this algorithm much simpler. + * + * Most importantly, the parts of the type that arise directly from the + * table expression are guaranteed to be acyclic. This means we can do all + * kinds of naive depth first traversal shenanigans and not worry about + * nasty details like aliasing or reentrancy. + * + * We are therefore completely free to mutate these portions of the + * TableType however we choose! We'll take advantage of this property to do + * things like replace explicit named properties with indexers as required + * by the expected type. + */ + if (!isLiteral(expr)) + return exprType; + + expectedType = follow(expectedType); + exprType = follow(exprType); + + if (get(expectedType) || get(expectedType)) + { + // "Narrowing" to unknown or any is not going to do anything useful. + return exprType; + } + + if (expr->is()) + { + auto ft = get(exprType); + if (ft && get(ft->lowerBound) && fastIsSubtype(builtinTypes->stringType, ft->upperBound) && + fastIsSubtype(ft->lowerBound, builtinTypes->stringType)) + { + // if the upper bound is a subtype of the expected type, we can push the expected type in + Relation upperBoundRelation = relate(ft->upperBound, expectedType); + if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + // likewise, if the lower bound is a subtype, we can force the expected type in + // if this is the case and the previous relation failed, it means that the primitive type + // constraint was going to have to select the lower bound for this type anyway. + Relation lowerBoundRelation = relate(ft->lowerBound, expectedType); + if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + } + } + else if (expr->is()) + { + auto ft = get(exprType); + if (ft && get(ft->lowerBound) && fastIsSubtype(builtinTypes->booleanType, ft->upperBound) && + fastIsSubtype(ft->lowerBound, builtinTypes->booleanType)) + { + // if the upper bound is a subtype of the expected type, we can push the expected type in + Relation upperBoundRelation = relate(ft->upperBound, expectedType); + if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + // likewise, if the lower bound is a subtype, we can force the expected type in + // if this is the case and the previous relation failed, it means that the primitive type + // constraint was going to have to select the lower bound for this type anyway. + Relation lowerBoundRelation = relate(ft->lowerBound, expectedType); + if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + } + } + + if (expr->is() || expr->is() || expr->is() || expr->is()) + { + if (auto ft = get(exprType); ft && fastIsSubtype(ft->upperBound, expectedType)) + { + emplaceType(asMutable(exprType), expectedType); + return exprType; + } + + Relation r = relate(exprType, expectedType); + if (r == Relation::Coincident || r == Relation::Subset) + return expectedType; + + return exprType; + } + + // TODO: lambdas + + if (auto exprTable = expr->as()) + { + TableType* const tableTy = getMutable(exprType); + LUAU_ASSERT(tableTy); + + const TableType* expectedTableTy = get(expectedType); + + if (!expectedTableTy) + { + if (auto utv = get(expectedType)) + { + std::vector parts{begin(utv), end(utv)}; + + std::optional tt = extractMatchingTableType(parts, exprType, builtinTypes); + + if (tt) + { + TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr, toBlock); + + parts.push_back(res); + return arena->addType(UnionType{std::move(parts)}); + } + } + + return exprType; + } + + for (const AstExprTable::Item& item : exprTable->items) + { + if (isRecord(item)) + { + const AstArray& s = item.key->as()->value; + std::string keyStr{s.data, s.data + s.size}; + auto it = tableTy->props.find(keyStr); + LUAU_ASSERT(it != tableTy->props.end()); + + Property& prop = it->second; + + // Table literals always initially result in shared read-write types + LUAU_ASSERT(prop.isShared()); + TypeId propTy = *prop.readTy; + + auto it2 = expectedTableTy->props.find(keyStr); + + if (it2 == expectedTableTy->props.end()) + { + // expectedType may instead have an indexer. This is + // kind of interesting because it means we clip the prop + // from the exprType and fold it into the indexer. + if (expectedTableTy->indexer && isString(expectedTableTy->indexer->indexType)) + { + (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; + (*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType; + + TypeId matchedType = matchLiteralType( + astTypes, + astExpectedTypes, + builtinTypes, + arena, + unifier, + expectedTableTy->indexer->indexResultType, + propTy, + item.value, + toBlock + ); + + if (tableTy->indexer) + unifier->unify(matchedType, tableTy->indexer->indexResultType); + else + tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; + + tableTy->props.erase(keyStr); + } + + // If it's just an extra property and the expected type + // has no indexer, there's no work to do here. + + continue; + } + + LUAU_ASSERT(it2 != expectedTableTy->props.end()); + + const Property& expectedProp = it2->second; + + std::optional expectedReadTy = expectedProp.readTy; + std::optional expectedWriteTy = expectedProp.writeTy; + + TypeId matchedType = nullptr; + + // Important optimization: If we traverse into the read and + // write types separately even when they are shared, we go + // quadratic in a hurry. + if (expectedProp.isShared()) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); + prop.readTy = matchedType; + prop.writeTy = matchedType; + } + else if (expectedReadTy) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); + prop.readTy = matchedType; + prop.writeTy.reset(); + } + else if (expectedWriteTy) + { + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value, toBlock); + prop.readTy.reset(); + prop.writeTy = matchedType; + } + else + { + // Also important: It is presently the case that all + // table properties are either read-only, or have the + // same read and write types. + LUAU_ASSERT(!"Should be unreachable"); + } + + LUAU_ASSERT(prop.readTy || prop.writeTy); + + LUAU_ASSERT(matchedType); + + (*astExpectedTypes)[item.value] = matchedType; + } + else if (item.kind == AstExprTable::Item::List) + { + LUAU_ASSERT(tableTy->indexer); + + if (expectedTableTy->indexer) + { + const TypeId* propTy = astTypes->find(item.value); + LUAU_ASSERT(propTy); + + unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType); + TypeId matchedType = matchLiteralType( + astTypes, + astExpectedTypes, + builtinTypes, + arena, + unifier, + expectedTableTy->indexer->indexResultType, + *propTy, + item.value, + toBlock + ); + + // if the index result type is the prop type, we can replace it with the matched type here. + if (tableTy->indexer->indexResultType == *propTy) + tableTy->indexer->indexResultType = matchedType; + } + } + else if (item.kind == AstExprTable::Item::General) + { + + // We have { ..., [blocked] : somePropExpr, ...} + // If blocked resolves to a string, we will then take care of this above + // If it resolves to some other kind of expression, we don't have a way of folding this information into indexer + // because there is no named prop to remove + // We should just block here + const TypeId* keyTy = astTypes->find(item.key); + LUAU_ASSERT(keyTy); + TypeId tKey = follow(*keyTy); + if (get(tKey)) + toBlock.push_back(tKey); + + const TypeId* propTy = astTypes->find(item.value); + LUAU_ASSERT(propTy); + TypeId tProp = follow(*propTy); + if (get(tProp)) + toBlock.push_back(tProp); + } + else + LUAU_ASSERT(!"Unexpected"); + } + + // Keys that the expectedType says we should have, but that aren't + // specified by the AST fragment. + // + // If any such keys are options, then we'll add them to the expression + // type. + // + // We use std::optional here because the empty string is a + // perfectly reasonable value to insert into the set. We'll use + // std::nullopt as our sentinel value. + Set> missingKeys{{}}; + for (const auto& [name, _] : expectedTableTy->props) + missingKeys.insert(name); + + for (const AstExprTable::Item& item : exprTable->items) + { + if (item.key) + { + if (const auto str = item.key->as()) + { + missingKeys.erase(std::string(str->value.data, str->value.size)); + } + } + } + + for (const auto& key : missingKeys) + { + LUAU_ASSERT(key.has_value()); + + auto it = expectedTableTy->props.find(*key); + LUAU_ASSERT(it != expectedTableTy->props.end()); + + const Property& expectedProp = it->second; + + Property exprProp; + + if (expectedProp.readTy && isOptional(*expectedProp.readTy)) + exprProp.readTy = *expectedProp.readTy; + if (expectedProp.writeTy && isOptional(*expectedProp.writeTy)) + exprProp.writeTy = *expectedProp.writeTy; + + // If the property isn't actually optional, do nothing. + if (exprProp.readTy || exprProp.writeTy) + tableTy->props[*key] = std::move(exprProp); + } + + // If the expected table has an indexer, then the provided table can + // have one too. + // TODO: If the expected table also has an indexer, we might want to + // push the expected indexer's types into it. + if (expectedTableTy->indexer && !tableTy->indexer) + { + tableTy->indexer = expectedTableTy->indexer; + } + } + + return exprType; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/ToDot.cpp b/third_party/luau/Analysis/src/ToDot.cpp index 8d889cb5..aa2dc1e3 100644 --- a/third_party/luau/Analysis/src/ToDot.cpp +++ b/third_party/luau/Analysis/src/ToDot.cpp @@ -4,11 +4,14 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include "Luau/StringUtils.h" #include #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -52,7 +55,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty) if (get(ty)) return false; - return get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty); } void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) @@ -76,6 +79,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"unknown\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"never\"];\n", index); } else { @@ -139,153 +146,221 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - if (const BoundType* btv = get(ty)) - { - formatAppend(result, "BoundType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionType* ftv = get(ty)) + auto go = [&](auto&& t) { - formatAppend(result, "FunctionType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retTypes, index, "ret"); - } - else if (const TableType* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableType %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableType %d", index); - finishNodeLabel(ty); - finishNode(); + using T = std::decay_t; - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); + if constexpr (std::is_same_v) + { + formatAppend(result, "BoundType %d", index); + finishNodeLabel(ty); + finishNode(); - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type(), index, name.c_str()); - if (ttv->indexer) + visitChild(t.boundTo, index); + } + else if constexpr (std::is_same_v) { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BlockedType %d", index); + finishNodeLabel(ty); + finishNode(); } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); + else if constexpr (std::is_same_v) + { + formatAppend(result, "FunctionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - else if (const MetatableType* mtv = get(ty)) - { - formatAppend(result, "MetatableType %d", index); - finishNodeLabel(ty); - finishNode(); + visitChild(t.argTypes, index, "arg"); + visitChild(t.retTypes, index, "ret"); + } + else if constexpr (std::is_same_v) + { + if (t.name) + formatAppend(result, "TableType %s", t.name->c_str()); + else if (t.syntheticName) + formatAppend(result, "TableType %s", t.syntheticName->c_str()); + else + formatAppend(result, "TableType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (t.boundTo) + return visitChild(*t.boundTo, index, "boundTo"); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : t.instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : t.instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "MetatableType %d", index); + finishNodeLabel(ty); + finishNode(); - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionType* utv = get(ty)) - { - formatAppend(result, "UnionType %d", index); - finishNodeLabel(ty); - finishNode(); + visitChild(t.table, index, "table"); + visitChild(t.metatable, index, "metatable"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionType* itv = get(ty)) - { - formatAppend(result, "IntersectionType %d", index); - finishNodeLabel(ty); - finishNode(); + for (TypeId opt : t.options) + visitChild(opt, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "IntersectionType %d", index); + finishNodeLabel(ty); + finishNode(); - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericType* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericType %s", gtv->name.c_str()); - else - formatAppend(result, "GenericType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeType* ftv = get(ty)) - { - formatAppend(result, "FreeType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassType* ctv = get(ty)) - { - formatAppend(result, "ClassType %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); + for (TypeId part : t.parts) + visitChild(part, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "LazyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PendingExpansionType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + if (t.explicitName) + formatAppend(result, "GenericType %s", t.name.c_str()); + else + formatAppend(result, "GenericType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "FreeType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (!get(t.lowerBound)) + visitChild(t.lowerBound, index, "[lowerBound]"); + + if (!get(t.upperBound)) + visitChild(t.upperBound, index, "[upperBound]"); + } + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "AnyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnknownType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NeverType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ErrorType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ClassType %s", t.name.c_str()); + finishNodeLabel(ty); + finishNode(); - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type(), index, name.c_str()); + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); + if (t.parent) + visitChild(*t.parent, index, "[parent]"); - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else if (const SingletonType* stv = get(ty)) - { - std::string res; + if (t.metatable) + visitChild(*t.metatable, index, "[metatable]"); - if (const StringSingleton* ss = get(stv)) + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + } + else if constexpr (std::is_same_v) { - // Don't put in quotes anywhere. If it's outside of the call to escape, - // then it's invalid syntax. If it's inside, then escaping is super noisy. - res = "string: " + escape(ss->value); + std::string res; + + if (const StringSingleton* ss = get(&t)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(&t)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonType %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); } - else if (const BooleanSingleton* bs = get(stv)) + else if constexpr (std::is_same_v) { - res = "boolean: "; - res += bs->value ? "true" : "false"; + formatAppend(result, "NegationType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.ty, index, "[negated]"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "TypeFunctionInstanceType %s %d", t.function->name.c_str(), index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId tyParam : t.typeArguments) + visitChild(tyParam, index); + + for (TypePackId tpParam : t.packArguments) + visitChild(tpParam, index); } else - LUAU_ASSERT(!"unknown singleton type"); + static_assert(always_false_v, "unknown type kind"); + }; - formatAppend(result, "SingletonType %s", res.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } + visit(go, ty->ty); } void StateDot::visitChildren(TypePackId tp, int index) diff --git a/third_party/luau/Analysis/src/ToString.cpp b/third_party/luau/Analysis/src/ToString.cpp index ea3ab577..b5d8a98b 100644 --- a/third_party/luau/Analysis/src/ToString.cpp +++ b/third_party/luau/Analysis/src/ToString.cpp @@ -1,26 +1,43 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/DenseHash.h" #include "Luau/Location.h" #include "Luau/Scope.h" +#include "Luau/Set.h" #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include "Luau/VisitType.h" +#include "Luau/TypeOrPack.h" #include #include +#include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) /* - * Prefix generic typenames with gen- - * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 - * Fair warning: Setting this will break a lot of Luau unit tests. + * Enables increasing levels of verbosity for Luau type names when stringifying. + * After level 2, test cases will break unpredictably because a pointer to their + * scope will be included in the stringification of generic and free types. + * + * Supported values: + * + * 0: Disabled, no changes. + * + * 1: Prefix free/generic types with free- and gen-, respectively. Also reveal + * hidden variadic tails. Display block count for local types. + * + * 2: Suffix free/generic types with their scope depth. + * + * 3: Suffix free/generic types with their scope pointer, if present. */ -LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0) LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) namespace Luau @@ -36,8 +53,8 @@ struct FindCyclicTypes final : TypeVisitor FindCyclicTypes& operator=(const FindCyclicTypes&) = delete; bool exhaustive = false; - std::unordered_set visited; - std::unordered_set visitedPacks; + Luau::Set visited{{}}; + Luau::Set visitedPacks{{}}; std::set cycles; std::set cycleTPs; @@ -53,17 +70,39 @@ struct FindCyclicTypes final : TypeVisitor bool visit(TypeId ty) override { - return visited.insert(ty).second; + return visited.insert(ty); } bool visit(TypePackId tp) override { - return visitedPacks.insert(tp).second; + return visitedPacks.insert(tp); + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (!visited.insert(ty)) + return false; + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // TODO: Replace these if statements with assert()s when we + // delete FFlag::DebugLuauDeferredConstraintResolution. + // + // When the old solver is used, these pointers are always + // unused. When the new solver is used, they are never null. + + if (ft.lowerBound) + traverse(ft.lowerBound); + if (ft.upperBound) + traverse(ft.upperBound); + } + + return false; } bool visit(TypeId ty, const TableType& ttv) override { - if (!visited.insert(ty).second) + if (!visited.insert(ty)) return false; if (ttv.name || ttv.syntheticName) @@ -126,10 +165,12 @@ struct StringifierState ToStringOptions& opts; ToStringResult& result; - std::unordered_map cycleNames; - std::unordered_map cycleTpNames; - std::unordered_set seen; - std::unordered_set usedNames; + DenseHashMap cycleNames{{}}; + DenseHashMap cycleTpNames{{}}; + Set seen{{}}; + // `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison + // reasons. + DenseHashSet usedNames{"$$$"}; size_t indentation = 0; bool exhaustive; @@ -148,7 +189,7 @@ struct StringifierState bool hasSeen(const void* tv) { void* ttv = const_cast(tv); - if (seen.find(ttv) != seen.end()) + if (seen.contains(ttv)) return true; seen.insert(ttv); @@ -158,9 +199,9 @@ struct StringifierState void unsee(const void* tv) { void* ttv = const_cast(tv); - auto iter = seen.find(ttv); - if (iter != seen.end()) - seen.erase(iter); + + if (seen.contains(ttv)) + seen.erase(ttv); } std::string getName(TypeId ty) @@ -173,7 +214,7 @@ struct StringifierState for (int count = 0; count < 256; ++count) { std::string candidate = generateName(usedNames.size() + count); - if (!usedNames.count(candidate)) + if (!usedNames.contains(candidate)) { usedNames.insert(candidate); n = candidate; @@ -196,7 +237,7 @@ struct StringifierState for (int count = 0; count < 256; ++count) { std::string candidate = generateName(previousNameIndex + count); - if (!usedNames.count(candidate)) + if (!usedNames.contains(candidate)) { previousNameIndex += count; usedNames.insert(candidate); @@ -223,11 +264,15 @@ struct StringifierState ++count; emit(count); - emit("-"); - char buffer[16]; - uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); - snprintf(buffer, sizeof(buffer), "0x%x", s); - emit(buffer); + + if (FInt::DebugLuauVerboseTypeNames >= 3) + { + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } } void emit(TypeLevel level) @@ -305,18 +350,72 @@ struct TypeStringifier return; } - auto it = state.cycleNames.find(tv); - if (it != state.cycleNames.end()) + if (auto p = state.cycleNames.find(tv)) { - state.emit(it->second); + state.emit(*p); return; } Luau::visit( - [this, tv](auto&& t) { + [this, tv](auto&& t) + { return (*this)(tv, t); }, - tv->ty); + tv->ty + ); + } + + void emitKey(const std::string& name) + { + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } + state.emit(": "); + } + + void _newStringify(const std::string& name, const Property& prop) + { + bool comma = false; + if (prop.isShared()) + { + emitKey(name); + stringify(prop.type()); + return; + } + + if (prop.readTy) + { + state.emit("read "); + emitKey(name); + stringify(*prop.readTy); + comma = true; + } + if (prop.writeTy) + { + if (comma) + { + state.emit(","); + state.newline(); + } + + state.emit("write "); + emitKey(name); + stringify(*prop.writeTy); + } + } + + void stringify(const std::string& name, const Property& prop) + { + if (FFlag::DebugLuauDeferredConstraintResolution) + return _newStringify(name, prop); + + emitKey(name); + stringify(prop.type()); } void stringify(TypePackId tp); @@ -371,11 +470,45 @@ struct TypeStringifier void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + + // TODO: ftv.lowerBound and ftv.upperBound should always be non-nil when + // the new solver is used. This can be replaced with an assert. + if (FFlag::DebugLuauDeferredConstraintResolution && ftv.lowerBound && ftv.upperBound) + { + const TypeId lowerBound = follow(ftv.lowerBound); + const TypeId upperBound = follow(ftv.upperBound); + if (get(lowerBound) && get(upperBound)) + { + state.emit("'"); + state.emit(state.getName(ty)); + } + else + { + state.emit("("); + if (!get(lowerBound)) + { + stringify(lowerBound); + state.emit(" <: "); + } + state.emit("'"); + state.emit(state.getName(ty)); + + if (!get(upperBound)) + { + state.emit(" <: "); + stringify(upperBound); + } + state.emit(")"); + } + return; + } + + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); + state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -392,6 +525,9 @@ struct TypeStringifier void operator()(TypeId ty, const GenericType& gtv) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (gtv.explicitName) { state.usedNames.insert(gtv.name); @@ -401,7 +537,7 @@ struct TypeStringifier else state.emit(state.getName(ty)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -444,6 +580,9 @@ struct TypeStringifier case PrimitiveType::Thread: state.emit("thread"); return; + case PrimitiveType::Buffer: + state.emit("buffer"); + return; case PrimitiveType::Function: state.emit("function"); return; @@ -504,6 +643,12 @@ struct TypeStringifier state.emit(">"); } + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (ftv.isCheckedFunction) + state.emit("@checked "); + } + state.emit("("); if (state.opts.functionTypeArguments) @@ -581,16 +726,33 @@ struct TypeStringifier std::string openbrace = "@@@"; std::string closedbrace = "@@@?!"; - switch (state.opts.hideTableKind ? TableState::Unsealed : ttv.state) + switch (state.opts.hideTableKind ? (FFlag::DebugLuauDeferredConstraintResolution ? TableState::Sealed : TableState::Unsealed) : ttv.state) { case TableState::Sealed: - state.result.invalid = true; - openbrace = "{|"; - closedbrace = "|}"; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + openbrace = "{"; + closedbrace = "}"; + } + else + { + state.result.invalid = true; + openbrace = "{|"; + closedbrace = "|}"; + } break; case TableState::Unsealed: - openbrace = "{"; - closedbrace = "}"; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + state.result.invalid = true; + openbrace = "{|"; + closedbrace = "|}"; + } + else + { + openbrace = "{"; + closedbrace = "}"; + } break; case TableState::Free: state.result.invalid = true; @@ -651,16 +813,8 @@ struct TypeStringifier break; } - if (isIdentifier(name)) - state.emit(name); - else - { - state.emit("[\""); - state.emit(escape(name)); - state.emit("\"]"); - } - state.emit(": "); - stringify(prop.type()); + stringify(name, prop); + comma = true; ++index; } @@ -731,7 +885,7 @@ struct TypeStringifier std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.contains(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -754,11 +908,15 @@ struct TypeStringifier state.emit("("); bool first = true; + bool shouldPlaceOnNewlines = results.size() > state.opts.compositeTypesSingleLineLimit; for (std::string& ss : results) { if (!first) { - state.newline(); + if (shouldPlaceOnNewlines) + state.newline(); + else + state.emit(" "); state.emit("| "); } state.emit(ss); @@ -778,7 +936,7 @@ struct TypeStringifier } } - void operator()(TypeId, const IntersectionType& uv) + void operator()(TypeId ty, const IntersectionType& uv) { if (state.hasSeen(&uv)) { @@ -794,7 +952,7 @@ struct TypeStringifier std::string saved = std::move(state.result.name); - bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); + bool needParens = !state.cycleNames.contains(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -814,11 +972,15 @@ struct TypeStringifier std::sort(results.begin(), results.end()); bool first = true; + bool shouldPlaceOnNewlines = results.size() > state.opts.compositeTypesSingleLineLimit || isOverloadedFunction(ty); for (std::string& ss : results) { if (!first) { - state.newline(); + if (shouldPlaceOnNewlines) + state.newline(); + else + state.emit(" "); state.emit("& "); } state.emit(ss); @@ -871,6 +1033,33 @@ struct TypeStringifier if (parens) state.emit(")"); } + + void operator()(TypeId, const TypeFunctionInstanceType& tfitv) + { + state.emit(tfitv.function->name); + state.emit("<"); + + bool comma = false; + for (TypeId ty : tfitv.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(ty); + } + + for (TypePackId tp : tfitv.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(tp); + } + + state.emit(">"); + } }; struct TypePackStringifier @@ -911,18 +1100,19 @@ struct TypePackStringifier return; } - auto it = state.cycleTpNames.find(tp); - if (it != state.cycleTpNames.end()) + if (auto p = state.cycleTpNames.find(tp)) { - state.emit(it->second); + state.emit(*p); return; } Luau::visit( - [this, tp](auto&& t) { + [this, tp](auto&& t) + { return (*this)(tp, t); }, - tp->ty); + tp->ty + ); } void operator()(TypePackId, const TypePack& tp) @@ -958,7 +1148,7 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { TypePackId tail = follow(*tp.tail); - if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + if (auto vtp = get(tail); !vtp || (FInt::DebugLuauVerboseTypeNames < 1 && !vtp->hidden)) { if (first) first = false; @@ -981,7 +1171,7 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); - if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + if (FInt::DebugLuauVerboseTypeNames >= 1 && pack.hidden) { state.emit("*hidden*"); } @@ -990,6 +1180,9 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit("gen-"); + if (pack.explicitName) { state.usedNames.insert(pack.name); @@ -1001,7 +1194,7 @@ struct TypePackStringifier state.emit(state.getName(tp)); } - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -1009,17 +1202,18 @@ struct TypePackStringifier else state.emit(pack.level); } + state.emit("..."); } void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); state.emit(state.getName(tp)); - if (FFlag::DebugLuauVerboseTypeNames) + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); if (FFlag::DebugLuauDeferredConstraintResolution) @@ -1042,6 +1236,33 @@ struct TypePackStringifier state.emit(btp.index); state.emit("*"); } + + void operator()(TypePackId, const TypeFunctionInstanceTypePack& tfitp) + { + state.emit(tfitp.function->name); + state.emit("<"); + + bool comma = false; + for (TypeId p : tfitp.typeArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + for (TypePackId p : tfitp.packArguments) + { + if (comma) + state.emit(", "); + + comma = true; + stringify(p); + } + + state.emit(">"); + } }; void TypeStringifier::stringify(TypePackId tp) @@ -1056,8 +1277,13 @@ void TypeStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::set& cycleTPs, - std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) +static void assignCycleNames( + const std::set& cycles, + const std::set& cycleTPs, + DenseHashMap& cycleNames, + DenseHashMap& cycleTpNames, + bool exhaustive +) { int nextIndex = 1; @@ -1069,9 +1295,14 @@ static void assignCycleNames(const std::set& cycles, const std::set(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) + if (std::find_if( + ttv->instantiatedTypeParams.begin(), + ttv->instantiatedTypeParams.end(), + [&](auto&& el) + { + return cycles.count(follow(el)); + } + ) != ttv->instantiatedTypeParams.end()) cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; continue; @@ -1151,9 +1382,8 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) * * t1 where t1 = the_whole_root_type */ - auto it = state.cycleNames.find(ty); - if (it != state.cycleNames.end()) - state.emit(it->second); + if (auto p = state.cycleNames.find(ty)) + state.emit(*p); else tvs.stringify(ty); @@ -1166,9 +1396,14 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleNames.begin(), + sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1179,18 +1414,25 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto&& t) + { return tvs(cycleTy, t); }, - cycleTy->ty); + cycleTy->ty + ); semi = true; } std::vector> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end()); - std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleTpNames.begin(), + sortedCycleTpNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); TypePackStringifier tps{state}; @@ -1202,10 +1444,12 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tps, cycleTy = cycleTp](auto&& t) { + [&tps, cycleTy = cycleTp](auto&& t) + { return tps(cycleTy, t); }, - cycleTp->ty); + cycleTp->ty + ); semi = true; } @@ -1245,13 +1489,12 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * * t1 where t1 = the_whole_root_type */ - auto it = state.cycleTpNames.find(tp); - if (it != state.cycleTpNames.end()) - state.emit(it->second); + if (auto p = state.cycleTpNames.find(tp)) + state.emit(*p); else tvs.stringify(tp); - if (!cycles.empty()) + if (!cycles.empty() || !cycleTPs.empty()) { result.cycle = true; state.emit(" where "); @@ -1260,9 +1503,14 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort( + sortedCycleNames.begin(), + sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1273,10 +1521,42 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto t) { + [&tvs, cycleTy = cycleTy](auto t) + { return tvs(cycleTy, t); }, - cycleTy->ty); + cycleTy->ty + ); + + semi = true; + } + + std::vector> sortedCycleTpNames{state.cycleTpNames.begin(), state.cycleTpNames.end()}; + std::sort( + sortedCycleTpNames.begin(), + sortedCycleTpNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + } + ); + + TypePackStringifier tps{tvs.state}; + + for (const auto& [cycleTp, name] : sortedCycleTpNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tps, cycleTp = cycleTp](auto t) + { + return tps(cycleTp, t); + }, + cycleTp->ty + ); semi = true; } @@ -1462,12 +1742,26 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { - auto go = [&opts](auto&& c) -> std::string { + auto go = [&opts](auto&& c) -> std::string + { using T = std::decay_t; - auto tos = [&opts](auto&& a) { + auto tos = [&opts](auto&& a) + { return toString(a, opts); }; @@ -1481,7 +1775,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { std::string subStr = tos(c.subPack); std::string superStr = tos(c.superPack); - return subStr + " <: " + superStr; + return subStr + " <...: " + superStr; } else if constexpr (std::is_same_v) { @@ -1489,33 +1783,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string superStr = tos(c.sourceType); return subStr + " ~ gen " + superStr; } - else if constexpr (std::is_same_v) - { - std::string subStr = tos(c.subType); - std::string superStr = tos(c.superType); - return subStr + " ~ inst " + superStr; - } - else if constexpr (std::is_same_v) - { - std::string resultStr = tos(c.resultType); - std::string operandStr = tos(c.operandType); - - return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; - } - else if constexpr (std::is_same_v) - { - std::string resultStr = tos(c.resultType); - std::string leftStr = tos(c.leftType); - std::string rightStr = tos(c.rightType); - - return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; - } else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); - return variableStr + " ~ Iterate<" + iteratorStr + ">"; + return variableStr + " ~ iterate " + iteratorStr; } else if constexpr (std::is_same_v) { @@ -1531,35 +1804,39 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); + return "function_check " + tos(c.fn) + " " + tos(c.argsPack); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; + if (c.expectedType) + return "prim " + tos(c.freeType) + "[expected: " + tos(*c.expectedType) + "] as " + tos(c.primitiveType); + else + return "prim " + tos(c.freeType) + " as " + tos(c.primitiveType); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; - return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context)); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); + return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) + return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + return "reduce " + tos(c.ty); + else if constexpr (std::is_same_v) { - std::string result = tos(c.resultType); - std::string discriminant = tos(c.discriminantType); - - if (c.negated) - return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; - else - return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; + return "reduce " + tos(c.tp); } - else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; @@ -1567,6 +1844,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return visit(go, constraint.c); } +std::string toString(const Constraint& constraint) +{ + return toString(constraint, ToStringOptions{}); +} + std::string dump(const Constraint& c) { ToStringOptions opts; @@ -1614,9 +1896,30 @@ std::string toString(const Position& position) return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; } -std::string toString(const Location& location) +std::string toString(const Location& location, int offset, bool useBegin) +{ + return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" + + std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")"; +} + +std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts) { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + if (const TypeId* ty = get(tyOrTp)) + return toString(*ty, opts); + else if (const TypePackId* tp = get(tyOrTp)) + return toString(*tp, opts); + else + LUAU_UNREACHABLE(); +} + +std::string dump(const TypeOrPack& tyOrTp) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(tyOrTp, opts); + printf("%s\n", s.c_str()); + return s; } } // namespace Luau diff --git a/third_party/luau/Analysis/src/Transpiler.cpp b/third_party/luau/Analysis/src/Transpiler.cpp index cdfe6549..a42882ed 100644 --- a/third_party/luau/Analysis/src/Transpiler.cpp +++ b/third_party/luau/Analysis/src/Transpiler.cpp @@ -10,6 +10,7 @@ #include #include + namespace { bool isIdentifierStartChar(char c) @@ -27,8 +28,8 @@ bool isIdentifierChar(char c) return isIdentifierStartChar(c) || isDigit(c); } -const std::vector keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", - "not", "or", "repeat", "return", "then", "true", "until", "while"}; +const std::vector keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", + "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; } // namespace @@ -467,6 +468,7 @@ struct Printer case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: case AstExprBinary::CompareLt: @@ -487,6 +489,8 @@ struct Printer writer.maybeSpace(a->right->location.begin, 4); writer.keyword(toString(a->op)); break; + default: + LUAU_ASSERT(!"Unknown Op"); } visualize(*a->right); @@ -753,6 +757,10 @@ struct Printer writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; + case AstExprBinary::FloorDiv: + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("//="); + break; case AstExprBinary::Mod: writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); @@ -836,6 +844,15 @@ struct Printer visualizeTypeAnnotation(*a->type); } } + else if (const auto& t = program.as()) + { + if (writeTypes) + { + writer.keyword("type function"); + writer.identifier(t->name.value); + visualizeFunctionBody(*t->body); + } + } else if (const auto& a = program.as()) { writer.symbol("(error-stat"); @@ -1174,11 +1191,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/third_party/luau/Analysis/src/TxnLog.cpp b/third_party/luau/Analysis/src/TxnLog.cpp index 26618313..bde7751a 100644 --- a/third_party/luau/Analysis/src/TxnLog.cpp +++ b/third_party/luau/Analysis/src/TxnLog.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" @@ -71,17 +72,26 @@ const TxnLog* TxnLog::empty() void TxnLog::concat(TxnLog rhs) { for (auto& [ty, rep] : rhs.typeVarChanges) + { + if (rep->dead) + continue; typeVarChanges[ty] = std::move(rep); + } for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) { for (auto& [ty, rightRep] : rhs.typeVarChanges) { - if (auto leftRep = typeVarChanges.find(ty)) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { TypeId leftTy = arena->addType((*leftRep)->pending); TypeId rightTy = arena->addType(rightRep->pending); @@ -93,17 +103,80 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) { + /* + * Check for cycles. + * + * We must not combine a log entry that binds 'a to 'b with a log that + * binds 'b to 'a. + * + * Of the two, identify the one with the 'bigger' scope and eliminate the + * entry that rebinds it. + */ + for (const auto& [rightTy, rightRep] : rhs.typeVarChanges) + { + if (rightRep->dead) + continue; + + // We explicitly use get_if here because we do not wish to do anything + // if the uncommitted type is already bound to something else. + const FreeType* rf = get_if(&rightTy->ty); + if (!rf) + continue; + + const BoundType* rb = Luau::get(&rightRep->pending); + if (!rb) + continue; + + const TypeId leftTy = rb->boundTo; + const FreeType* lf = get_if(&leftTy->ty); + if (!lf) + continue; + + auto leftRep = typeVarChanges.find(leftTy); + if (!leftRep) + continue; + + if ((*leftRep)->dead) + continue; + + const BoundType* lb = Luau::get(&(*leftRep)->pending); + if (!lb) + continue; + + if (lb->boundTo == rightTy) + { + // leftTy has been bound to rightTy, but rightTy has also been bound + // to leftTy. We find the one that belongs to the more deeply nested + // scope and remove it from the log. + const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level); + + if (discardLeft) + (*leftRep)->dead = true; + else + rightRep->dead = true; + } + } + for (auto& [ty, rightRep] : rhs.typeVarChanges) { - if (auto leftRep = typeVarChanges.find(ty)) + if (rightRep->dead) + continue; + + if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { TypeId leftTy = arena->addType((*leftRep)->pending); TypeId rightTy = arena->addType(rightRep->pending); - typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; + + if (follow(leftTy) == follow(rightTy)) + typeVarChanges[ty] = std::move(rightRep); + else + typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}}; } else typeVarChanges[ty] = std::move(rightRep); @@ -111,12 +184,19 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::commit() { + LUAU_ASSERT(!radioactive); + for (auto& [ty, rep] : typeVarChanges) - asMutable(ty)->reassign(rep.get()->pending); + { + if (!rep->dead) + asMutable(ty)->reassign(rep.get()->pending); + } for (auto& [tp, rep] : typePackChanges) asMutable(tp)->reassign(rep.get()->pending); @@ -135,11 +215,16 @@ TxnLog TxnLog::inverse() TxnLog inversed(sharedSeen); for (auto& [ty, _rep] : typeVarChanges) - inversed.typeVarChanges[ty] = std::make_unique(*ty); + { + if (!_rep->dead) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + } for (auto& [tp, _rep] : typePackChanges) inversed.typePackChanges[tp] = std::make_unique(*tp); + inversed.radioactive = radioactive; + return inversed; } @@ -199,12 +284,13 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(!ty->persistent); + if (ty->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; - if (!pending) + if (!pending || (*pending).dead) { pending = std::make_unique(*ty); pending->pending.owningArena = nullptr; @@ -215,7 +301,8 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(!tp->persistent); + if (tp->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. @@ -237,7 +324,7 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty)) + if (auto it = current->typeVarChanges.find(ty); it && !(*it)->dead) return it->get(); } @@ -382,32 +469,44 @@ std::optional TxnLog::getLevel(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const { - return Luau::follow(ty, [this](TypeId ty) { - PendingType* state = this->pending(ty); + return Luau::follow( + ty, + this, + [](const void* ctx, TypeId ty) -> TypeId + { + const TxnLog* self = static_cast(ctx); + PendingType* state = self->pending(ty); - if (state == nullptr) - return ty; + if (state == nullptr) + return ty; - // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants - // that normally apply. This is safe because follow will only call get<> - // on the returned pointer. - return const_cast(&state->pending); - }); + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + } + ); } TypePackId TxnLog::follow(TypePackId tp) const { - return Luau::follow(tp, [this](TypePackId tp) { - PendingTypePack* state = this->pending(tp); + return Luau::follow( + tp, + this, + [](const void* ctx, TypePackId tp) -> TypePackId + { + const TxnLog* self = static_cast(ctx); + PendingTypePack* state = self->pending(tp); - if (state == nullptr) - return tp; + if (state == nullptr) + return tp; - // Ugly: Fabricate a TypePackId that doesn't adhere to most of the - // invariants that normally apply. This is safe because follow will - // only call get<> on the returned pointer. - return const_cast(&state->pending); - }); + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + } + ); } std::pair, std::vector> TxnLog::getChanges() const diff --git a/third_party/luau/Analysis/src/Type.cpp b/third_party/luau/Analysis/src/Type.cpp index 888083b4..ffc4a97e 100644 --- a/third_party/luau/Analysis/src/Type.cpp +++ b/third_party/luau/Analysis/src/Type.cpp @@ -9,8 +9,10 @@ #include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" +#include "Luau/VecDeque.h" #include "Luau/VisitType.h" #include @@ -21,110 +23,86 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) -LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 50000) +LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes2, false) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); +// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable +static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) +{ + TypeId unwrapped = ltv->unwrapped.load(); + + if (unwrapped) + return unwrapped; -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); + ltv->unwrap(*ltv); + unwrapped = ltv->unwrapped.load(); -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); + if (!unwrapped) + throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFind(MagicFunctionCallContext context); + if (get(unwrapped)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + + return unwrapped; +} TypeId follow(TypeId t) { - return follow(t, [](TypeId t) { - return t; - }); + return follow(t, FollowOption::Normal); } -TypeId follow(TypeId t, std::function mapper) +TypeId follow(TypeId t, FollowOption followOption) { - auto advance = [&mapper](TypeId ty) -> std::optional { - if (FFlag::LuauBoundLazyTypes2) + return follow( + t, + followOption, + nullptr, + [](const void*, TypeId t) -> TypeId { - TypeId mapped = mapper(ty); - - if (auto btv = get>(mapped)) - return btv->boundTo; - - if (auto ttv = get(mapped)) - return ttv->boundTo; - - if (auto ltv = getMutable(mapped)) - { - TypeId unwrapped = ltv->unwrapped.load(); + return t; + } + ); +} - if (unwrapped) - return unwrapped; +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) +{ + return follow(t, FollowOption::Normal, context, mapper); +} - ltv->unwrap(*ltv); - unwrapped = ltv->unwrapped.load(); +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)) +{ + auto advance = [followOption, context, mapper](TypeId ty) -> std::optional + { + TypeId mapped = mapper(context, ty); - if (!unwrapped) - throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); + if (auto btv = get>(mapped)) + return btv->boundTo; - if (get(unwrapped)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + if (auto ttv = get(mapped)) + return ttv->boundTo; - return unwrapped; - } + if (auto ltv = getMutable(mapped); ltv && followOption != FollowOption::DisableLazyTypeThunks) + return unwrapLazy(ltv); - return std::nullopt; - } - else - { - if (auto btv = get>(mapper(ty))) - return btv->boundTo; - else if (auto ttv = get(mapper(ty))) - return ttv->boundTo; - else - return std::nullopt; - } - }; - - auto force = [&mapper](TypeId ty) { - if (auto ltv = get_if(&mapper(ty)->ty)) - { - TypeId res = ltv->thunk_DEPRECATED(); - if (get(res)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); - - *asMutable(ty) = BoundType(res); - } + return std::nullopt; }; - if (!FFlag::LuauBoundLazyTypes2) - force(t); - TypeId cycleTester = t; // Null once we've determined that there is no cycle if (auto a = advance(cycleTester)) cycleTester = *a; else return t; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { - if (!FFlag::LuauBoundLazyTypes2) - force(t); - auto a1 = advance(t); if (a1) t = *a1; @@ -157,7 +135,7 @@ std::vector flattenIntersection(TypeId ty) return {ty}; std::unordered_set seen; - std::deque queue{ty}; + VecDeque queue{ty}; std::vector result; @@ -243,6 +221,11 @@ bool isThread(TypeId ty) return isPrim(ty, PrimitiveType::Thread); } +bool isBuffer(TypeId ty) +{ + return isPrim(ty, PrimitiveType::Buffer); +} + bool isOptional(TypeId ty) { if (isNil(ty)) @@ -269,12 +252,22 @@ bool isTableIntersection(TypeId ty) return std::all_of(parts.begin(), parts.end(), getTableType); } +bool isTableUnion(TypeId ty) +{ + const UnionType* ut = get(follow(ty)); + if (!ut) + return false; + + return std::all_of(begin(ut), end(ut), getTableType); +} + bool isOverloadedFunction(TypeId ty) { if (!get(follow(ty))) return false; - auto isFunction = [](TypeId part) -> bool { + auto isFunction = [](TypeId part) -> bool + { return get(part); }; @@ -434,6 +427,13 @@ bool maybeSingleton(TypeId ty) for (TypeId option : utv) if (get(follow(option))) return true; + if (const IntersectionType* itv = get(ty)) + for (TypeId part : itv) + if (maybeSingleton(part)) // will i regret this? + return true; + if (const TypeFunctionInstanceType* tfit = get(ty)) + if (tfit->function->name == "keyof" || tfit->function->name == "rawkeyof") + return true; return false; } @@ -499,6 +499,14 @@ FreeType::FreeType(Scope* scope, TypeLevel level) { } +FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + GenericType::GenericType() : index(Unifiable::freshIndex()) , name("g" + std::to_string(index)) @@ -542,14 +550,36 @@ GenericType::GenericType(Scope* scope, const Name& name) } BlockedType::BlockedType() - : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) + : index(Unifiable::freshIndex()) +{ +} + +Constraint* BlockedType::getOwner() const { + return owner; +} + +void BlockedType::setOwner(Constraint* newOwner) +{ + LUAU_ASSERT(owner == nullptr); + + if (owner != nullptr) + return; + + owner = newOwner; } -int BlockedType::DEPRECATED_nextIndex = 0; +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} PendingExpansionType::PendingExpansionType( - std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) + std::optional prefix, + AstName name, + std::vector typeArguments, + std::vector packArguments +) : prefix(prefix) , name(name) , typeArguments(typeArguments) @@ -578,7 +608,13 @@ FunctionType::FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retT } FunctionType::FunctionType( - TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + TypeLevel level, + Scope* scope, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , level(level) , scope(scope) @@ -588,8 +624,14 @@ FunctionType::FunctionType( { } -FunctionType::FunctionType(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, - std::optional defn, bool hasSelf) +FunctionType::FunctionType( + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -599,8 +641,15 @@ FunctionType::FunctionType(std::vector generics, std::vector { } -FunctionType::FunctionType(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType( + TypeLevel level, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -611,8 +660,16 @@ FunctionType::FunctionType(TypeLevel level, std::vector generics, std::v { } -FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retTypes, std::optional defn, bool hasSelf) +FunctionType::FunctionType( + TypeLevel level, + Scope* scope, + std::vector generics, + std::vector genericPacks, + TypePackId argTypes, + TypePackId retTypes, + std::optional defn, + bool hasSelf +) : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) @@ -626,23 +683,28 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector ge Property::Property() {} -Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional location, const Tags& tags, - const std::optional& documentationSymbol) +Property::Property( + TypeId readTy, + bool deprecated, + const std::string& deprecatedSuggestion, + std::optional location, + const Tags& tags, + const std::optional& documentationSymbol, + std::optional typeLocation +) : deprecated(deprecated) , deprecatedSuggestion(deprecatedSuggestion) , location(location) + , typeLocation(typeLocation) , tags(tags) , documentationSymbol(documentationSymbol) , readTy(readTy) , writeTy(readTy) { - LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); } Property Property::readonly(TypeId ty) { - LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); - Property p; p.readTy = ty; return p; @@ -650,8 +712,6 @@ Property Property::readonly(TypeId ty) Property Property::writeonly(TypeId ty) { - LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); - Property p; p.writeTy = ty; return p; @@ -664,29 +724,27 @@ Property Property::rw(TypeId ty) Property Property::rw(TypeId read, TypeId write) { - LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); - Property p; p.readTy = read; p.writeTy = write; return p; } -std::optional Property::create(std::optional read, std::optional write) +Property Property::create(std::optional read, std::optional write) { if (read && !write) return Property::readonly(*read); else if (!read && write) return Property::writeonly(*write); - else if (read && write) - return Property::rw(*read, *write); else - return std::nullopt; + { + LUAU_ASSERT(read && write); + return Property::rw(*read, *write); + } } TypeId Property::type() const { - LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); LUAU_ASSERT(readTy); return *readTy; } @@ -694,20 +752,34 @@ TypeId Property::type() const void Property::setType(TypeId ty) { readTy = ty; + if (FFlag::DebugLuauDeferredConstraintResolution) + writeTy = ty; +} + +void Property::makeShared() +{ + if (writeTy) + writeTy = readTy; } -std::optional Property::readType() const +bool Property::isShared() const { - LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); - LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); - return readTy; + return readTy && writeTy && readTy == writeTy; } -std::optional Property::writeType() const +bool Property::isReadOnly() const { - LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); - LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); - return writeTy; + return readTy && !writeTy; +} + +bool Property::isWriteOnly() const +{ + return !readTy && writeTy; +} + +bool Property::isReadWrite() const +{ + return readTy && writeTy; } TableType::TableType(TableState state, TypeLevel level, Scope* scope) @@ -927,9 +999,17 @@ Type& Type::operator=(const Type& rhs) return *this; } -TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes); +TypeId makeFunction( + TypeArena& arena, + std::optional selfType, + std::initializer_list generics, + std::initializer_list genericPacks, + std::initializer_list paramTypes, + std::initializer_list paramNames, + std::initializer_list retTypes +); + +TypeId makeStringMetatable(NotNull builtinTypes); // BuiltinDefinitions.cpp BuiltinTypes::BuiltinTypes() : arena(new TypeArena) @@ -939,8 +1019,9 @@ BuiltinTypes::BuiltinTypes() , stringType(arena->addType(Type{PrimitiveType{PrimitiveType::String}, /*persistent*/ true})) , booleanType(arena->addType(Type{PrimitiveType{PrimitiveType::Boolean}, /*persistent*/ true})) , threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true})) + , bufferType(arena->addType(Type{PrimitiveType{PrimitiveType::Buffer}, /*persistent*/ true})) , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) - , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) + , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}, {}}, /*persistent*/ true})) , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) , emptyTableType(arena->addType(Type{TableType{TableState::Sealed, TypeLevel{}, nullptr}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) @@ -953,15 +1034,13 @@ BuiltinTypes::BuiltinTypes() , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true})) + , emptyTypePack(arena->addTypePack(TypePackVar{TypePack{{}}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) + , unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { - TypeId stringMetatable = makeStringMetatable(); - asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; - persist(stringMetatable); - freeze(*arena); } @@ -977,82 +1056,6 @@ BuiltinTypes::~BuiltinTypes() FFlag::DebugLuauFreezeArena.value = prevFlag; } -TypeId BuiltinTypes::makeStringMetatable() -{ - const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); - const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); - - const TypePackId oneStringPack = arena->addTypePack({stringType}); - const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); - - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); - const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - - const TypeId replArgType = - arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType( - FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); -} - TypeId BuiltinTypes::errorRecoveryType() const { return errorType; @@ -1075,7 +1078,7 @@ TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) const void persist(TypeId ty) { - std::deque queue{ty}; + VecDeque queue{ty}; while (!queue.empty()) { @@ -1130,6 +1133,14 @@ void persist(TypeId ty) else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) { } + else if (auto tfit = get(t)) + { + for (auto ty : tfit->typeArguments) + queue.push_back(ty); + + for (auto tp : tfit->packArguments) + persist(tp); + } else { LUAU_ASSERT(!"TypeId is not supported in a persist call"); @@ -1158,6 +1169,14 @@ void persist(TypePackId tp) else if (get(tp)) { } + else if (auto tfitp = get(tp)) + { + for (auto ty : tfitp->typeArguments) + persist(ty); + + for (auto tp : tfitp->packArguments) + persist(tp); + } else { LUAU_ASSERT(!"TypePackId is not supported in a persist call"); @@ -1258,434 +1277,9 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) -{ - const char* options = "cdiouxXeEfgGqs*"; - - std::vector result; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - i++; - - if (i < size && data[i] == '%') - continue; - - // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) - i++; - - if (i == size) - break; - - if (data[i] == 'q' || data[i] == 's') - result.push_back(builtinTypes->stringType); - else if (data[i] == '*') - result.push_back(builtinTypes->unknownType); - else if (strchr(options, data[i])) - result.push_back(builtinTypes->numberType); - else - result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); - } - } - - return result; -} - -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* fmt = nullptr; - if (auto index = expr.func->as(); index && expr.self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!expr.self && expr.args.size > 0) - fmt = expr.args.data[0]->as(); - - if (!fmt) - return std::nullopt; - - std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(paramPack); - - size_t paramOffset = 1; - size_t dataOffset = expr.self ? 0 : 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - - typechecker.unify(params[i + paramOffset], expected[i], scope, location); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - return WithPredicate{arena.addTypePack({typechecker.stringType})}; -} - -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) -{ - TypeArena* arena = context.solver->arena; - - AstExprConstantString* fmt = nullptr; - if (auto index = context.callSite->func->as(); index && context.callSite->self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!context.callSite->self && context.callSite->args.size > 0) - fmt = context.callSite->args.data[0]->as(); - - if (!fmt) - return false; - - std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(context.arguments); - - size_t paramOffset = 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); - asMutable(context.result)->ty.emplace(resultPack); - - return true; -} - -static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) { - std::vector result; - int depth = 0; - bool parsingSet = false; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - ++i; - if (!parsingSet && i < size && data[i] == 'b') - i += 2; - } - else if (!parsingSet && data[i] == '[') - { - parsingSet = true; - if (i + 1 < size && data[i + 1] == ']') - i += 1; - } - else if (parsingSet && data[i] == ']') - { - parsingSet = false; - } - else if (data[i] == '(') - { - if (parsingSet) - continue; - - if (i + 1 < size && data[i + 1] == ')') - { - i++; - result.push_back(builtinTypes->optionalNumberType); - continue; - } - - ++depth; - result.push_back(builtinTypes->optionalStringType); - } - else if (data[i] == ')') - { - if (parsingSet) - continue; - - --depth; - - if (depth < 0) - break; - } - } - - if (depth != 0 || parsingSet) - return std::vector(); - - if (result.empty()) - result.push_back(builtinTypes->optionalStringType); - - return result; -} - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() != 2) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t index = expr.self ? 0 : 1; - if (expr.args.size > index) - pattern = expr.args.data[index]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypePackId emptyPack = arena.addTypePack({}); - const TypePackId returnList = arena.addTypePack(returnTypes); - const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); - return WithPredicate{arena.addTypePack({iteratorType})}; -} - -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() != 2) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t index = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > index) - pattern = context.callSite->args.data[index]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId returnList = arena->addTypePack(returnTypes); - const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); - const TypePackId resTypePack = arena->addTypePack({iteratorType}); - asMutable(context.result)->ty.emplace(resTypePack); - - return true; -} - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 3) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() == 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 3) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); - - const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() == 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - - return true; -} - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 4) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - bool plain = false; - size_t plainIndex = expr.self ? 2 : 3; - if (expr.args.size > plainIndex) - { - AstExprConstantBool* p = expr.args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - } - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() >= 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - if (params.size() == 4 && expr.args.size > plainIndex) - typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 4) - return false; - - TypeArena* arena = context.solver->arena; - NotNull builtinTypes = context.solver->builtinTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - bool plain = false; - size_t plainIndex = context.callSite->self ? 2 : 3; - if (context.callSite->args.size > plainIndex) - { - AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - } - - context.solver->unify(params[0], builtinTypes->stringType, context.solver->rootScope); - - const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() >= 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); - - if (params.size() == 4 && context.callSite->args.size > plainIndex) - context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - return true; + return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); } std::vector filterMap(TypeId type, TypeIdPredicate predicate) @@ -1787,4 +1381,11 @@ bool GenericTypePackDefinition::operator==(const GenericTypePackDefinition& rhs) return tp == rhs.tp && defaultValue == rhs.defaultValue; } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceType(Type* ty, TypeId& tyArg) +{ + LUAU_ASSERT(ty != follow(tyArg)); + return &ty->ty.emplace(tyArg); +} + } // namespace Luau diff --git a/third_party/luau/Analysis/src/TypeArena.cpp b/third_party/luau/Analysis/src/TypeArena.cpp index ed51517e..6cf81471 100644 --- a/third_party/luau/Analysis/src/TypeArena.cpp +++ b/third_party/luau/Analysis/src/TypeArena.cpp @@ -94,6 +94,26 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } +TypeId TypeArena::addTypeFunction(const TypeFunction& function, std::initializer_list types) +{ + return addType(TypeFunctionInstanceType{function, std::move(types)}); +} + +TypeId TypeArena::addTypeFunction(const TypeFunction& function, std::vector typeArguments, std::vector packArguments) +{ + return addType(TypeFunctionInstanceType{function, std::move(typeArguments), std::move(packArguments)}); +} + +TypePackId TypeArena::addTypePackFunction(const TypePackFunction& function, std::initializer_list types) +{ + return addTypePack(TypeFunctionInstanceTypePack{NotNull{&function}, std::move(types)}); +} + +TypePackId TypeArena::addTypePackFunction(const TypePackFunction& function, std::vector typeArguments, std::vector packArguments) +{ + return addTypePack(TypeFunctionInstanceTypePack{NotNull{&function}, std::move(typeArguments), std::move(packArguments)}); +} + void freeze(TypeArena& arena) { if (!FFlag::DebugLuauFreezeArena) diff --git a/third_party/luau/Analysis/src/TypeAttach.cpp b/third_party/luau/Analysis/src/TypeAttach.cpp index 86f78165..a288cfbe 100644 --- a/third_party/luau/Analysis/src/TypeAttach.cpp +++ b/third_party/luau/Analysis/src/TypeAttach.cpp @@ -9,6 +9,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFunction.h" #include @@ -103,7 +104,14 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), std::nullopt, AstName("string"), std::nullopt, Location()); case PrimitiveType::Thread: return allocator->alloc(Location(), std::nullopt, AstName("thread"), std::nullopt, Location()); + case PrimitiveType::Buffer: + return allocator->alloc(Location(), std::nullopt, AstName("buffer"), std::nullopt, Location()); + case PrimitiveType::Function: + return allocator->alloc(Location(), std::nullopt, AstName("function"), std::nullopt, Location()); + case PrimitiveType::Table: + return allocator->alloc(Location(), std::nullopt, AstName("table"), std::nullopt, Location()); default: + LUAU_ASSERT(false); // this should be unreachable. return nullptr; } } @@ -158,7 +166,8 @@ class TypeRehydrationVisitor } return allocator->alloc( - Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters); + Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters + ); } if (hasSeen(&ttv)) @@ -226,7 +235,17 @@ class TypeRehydrationVisitor idx++; } - return allocator->alloc(Location(), props); + AstTableIndexer* indexer = nullptr; + if (ctv.indexer) + { + RecursionCounter counter(&count); + + indexer = allocator->alloc(); + indexer->indexType = Luau::visit(*this, ctv.indexer->indexType->ty); + indexer->resultType = Luau::visit(*this, ctv.indexer->indexResultType->ty); + } + + return allocator->alloc(Location(), props, indexer); } AstType* operator()(const FunctionType& ftv) @@ -301,7 +320,8 @@ class TypeRehydrationVisitor retTailAnnotation = rehydrate(*retTail); return allocator->alloc( - Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); + Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation} + ); } AstType* operator()(const Unifiable::Error&) { @@ -310,13 +330,14 @@ class TypeRehydrationVisitor AstType* operator()(const GenericType& gtv) { return allocator->alloc( - Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location()); + Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location() + ); } AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(const FreeType& ftv) + AstType* operator()(const FreeType& ft) { return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } @@ -362,6 +383,10 @@ class TypeRehydrationVisitor // FIXME: do the same thing we do with ErrorType throw InternalCompilerError("Cannot convert NegationType into AstNode"); } + AstType* operator()(const TypeFunctionInstanceType& tfit) + { + return allocator->alloc(Location(), std::nullopt, AstName{tfit.function->name.c_str()}, std::nullopt, Location()); + } private: Allocator* allocator; @@ -432,6 +457,11 @@ class TypePackRehydrationVisitor return allocator->alloc(Location(), AstName("Unifiable")); } + AstTypePack* operator()(const TypeFunctionInstanceTypePack& tfitp) const + { + return allocator->alloc(Location(), AstName(tfitp.function->name.c_str())); + } + private: Allocator* allocator; SyntheticNames* syntheticNames; diff --git a/third_party/luau/Analysis/src/TypeChecker2.cpp b/third_party/luau/Analysis/src/TypeChecker2.cpp index a103df14..f53c994e 100644 --- a/third_party/luau/Analysis/src/TypeChecker2.cpp +++ b/third_party/luau/Analysis/src/TypeChecker2.cpp @@ -3,26 +3,34 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" -#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/DcrLogger.h" +#include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" +#include "Luau/OverloadResolution.h" +#include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypePath.h" #include "Luau/TypeUtils.h" -#include "Luau/Unifier.h" +#include "Luau/TypeOrPack.h" +#include "Luau/VisitType.h" #include +#include +#include LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(DebugLuauDontReduceTypes) - -LUAU_FASTFLAG(LuauNegatedClassTypes) namespace Luau { @@ -32,6 +40,7 @@ namespace Luau using PrintLineProc = void (*)(const std::string&); extern PrintLineProc luauPrintLine; + /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. * TypeChecker2 uses this to maintain knowledge about which scope encloses every * given AstNode. @@ -84,29 +93,329 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } +template +bool areEquivalent(const T& a, const T& b) +{ + if (a.function != b.function) + return false; + + if (a.typeArguments.size() != b.typeArguments.size() || a.packArguments.size() != b.packArguments.size()) + return false; + + for (size_t i = 0; i < a.typeArguments.size(); ++i) + { + if (follow(a.typeArguments[i]) != follow(b.typeArguments[i])) + return false; + } + + for (size_t i = 0; i < a.packArguments.size(); ++i) + { + if (follow(a.packArguments[i]) != follow(b.packArguments[i])) + return false; + } + + return true; +} + +struct TypeFunctionFinder : TypeOnceVisitor +{ + DenseHashSet mentionedFunctions{nullptr}; + DenseHashSet mentionedFunctionPacks{nullptr}; + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override + { + mentionedFunctions.insert(ty); + return true; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + mentionedFunctionPacks.insert(tp); + return true; + } +}; + +struct InternalTypeFunctionFinder : TypeOnceVisitor +{ + DenseHashSet internalFunctions{nullptr}; + DenseHashSet internalPackFunctions{nullptr}; + DenseHashSet mentionedFunctions{nullptr}; + DenseHashSet mentionedFunctionPacks{nullptr}; + + InternalTypeFunctionFinder(std::vector& declStack) + { + TypeFunctionFinder f; + for (TypeId fn : declStack) + f.traverse(fn); + + mentionedFunctions = std::move(f.mentionedFunctions); + mentionedFunctionPacks = std::move(f.mentionedFunctionPacks); + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + bool hasGeneric = false; + + for (TypeId p : tfit.typeArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + for (TypePackId p : tfit.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + if (hasGeneric) + { + for (TypeId mentioned : mentionedFunctions) + { + const TypeFunctionInstanceType* mentionedTfit = get(mentioned); + LUAU_ASSERT(mentionedTfit); + if (areEquivalent(tfit, *mentionedTfit)) + { + return true; + } + } + + internalFunctions.insert(ty); + } + + return true; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack& tfitp) override + { + bool hasGeneric = false; + + for (TypeId p : tfitp.typeArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + for (TypePackId p : tfitp.packArguments) + { + if (get(follow(p))) + { + hasGeneric = true; + break; + } + } + + if (hasGeneric) + { + for (TypePackId mentioned : mentionedFunctionPacks) + { + const TypeFunctionInstanceTypePack* mentionedTfitp = get(mentioned); + LUAU_ASSERT(mentionedTfitp); + if (areEquivalent(tfitp, *mentionedTfitp)) + { + return true; + } + } + + internalPackFunctions.insert(tp); + } + + return true; + } +}; + struct TypeChecker2 { NotNull builtinTypes; DcrLogger* logger; - NotNull ice; + const NotNull limits; + const NotNull ice; const SourceModule* sourceModule; Module* module; - TypeArena testArena; + TypeContext typeContext = TypeContext::Default; std::vector> stack; + std::vector functionDeclStack; + + DenseHashSet seenTypeFunctionInstances{nullptr}; Normalizer normalizer; + Subtyping _subtyping; + NotNull subtyping; - TypeChecker2(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, + const SourceModule* sourceModule, Module* module) : builtinTypes(builtinTypes) , logger(logger) + , limits(limits) , ice(unifierState->iceHandler) , sourceModule(sourceModule) , module(module) - , normalizer{&testArena, builtinTypes, unifierState} + , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}, + NotNull{module->getModuleScope().get()}} + , subtyping(&_subtyping) { } + static bool allowsNoReturnValues(const TypePackId tp) + { + for (TypeId ty : tp) + { + if (!get(follow(ty))) + return false; + } + + return true; + } + + static Location getEndLocation(const AstExprFunction* function) + { + Location loc = function->location; + if (loc.begin.line != loc.end.line) + { + Position begin = loc.end; + begin.column = std::max(0u, begin.column - 3); + loc = Location(begin, 3); + } + + return loc; + } + + bool isErrorCall(const AstExprCall* call) + { + const AstExprGlobal* global = call->func->as(); + if (!global) + return false; + + if (global->name == "error") + return true; + else if (global->name == "assert") + { + // assert() will error because it is missing the first argument + if (call->args.size == 0) + return true; + + if (AstExprConstantBool* expr = call->args.data[0]->as()) + if (!expr->value) + return true; + } + + return false; + } + + bool hasBreak(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) + { + if (hasBreak(stat->body.data[i])) + return true; + } + + return false; + } + + if (node->is()) + return true; + + if (AstStatIf* stat = node->as()) + { + if (hasBreak(stat->thenbody)) + return true; + + if (stat->elsebody && hasBreak(stat->elsebody)) + return true; + + return false; + } + + return false; + } + + // returns the last statement before the block implicitly exits, or nullptr if the block does not implicitly exit + // i.e. returns nullptr if the block returns properly or never returns + const AstStat* getFallthrough(const AstStat* node) + { + if (const AstStatBlock* stat = node->as()) + { + if (stat->body.size == 0) + return stat; + + for (size_t i = 0; i < stat->body.size - 1; ++i) + { + if (getFallthrough(stat->body.data[i]) == nullptr) + return nullptr; + } + + return getFallthrough(stat->body.data[stat->body.size - 1]); + } + + if (const AstStatIf* stat = node->as()) + { + if (const AstStat* thenf = getFallthrough(stat->thenbody)) + return thenf; + + if (stat->elsebody) + { + if (const AstStat* elsef = getFallthrough(stat->elsebody)) + return elsef; + + return nullptr; + } + else + return stat; + } + + if (node->is()) + return nullptr; + + if (const AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as(); call && isErrorCall(call)) + return nullptr; + + return stat; + } + + if (const AstStatWhile* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (expr->value && !hasBreak(stat->body)) + return nullptr; + } + + return node; + } + + if (const AstStatRepeat* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (!expr->value && !hasBreak(stat->body)) + return nullptr; + } + + if (getFallthrough(stat->body) == nullptr) + return nullptr; + + return node; + } + + return node; + } + std::optional pushStack(AstNode* node) { if (Scope** scope = module->astScopes.find(node)) @@ -115,6 +424,36 @@ struct TypeChecker2 return std::nullopt; } + void checkForInternalTypeFunction(TypeId ty, Location location) + { + InternalTypeFunctionFinder finder(functionDeclStack); + finder.traverse(ty); + + for (TypeId internal : finder.internalFunctions) + reportError(WhereClauseNeeded{internal}, location); + + for (TypePackId internal : finder.internalPackFunctions) + reportError(PackWhereClauseNeeded{internal}, location); + } + + TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location) + { + if (seenTypeFunctionInstances.find(instance)) + return instance; + seenTypeFunctionInstances.insert(instance); + + ErrorVec errors = reduceTypeFunctions( + instance, + location, + TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, + true + ) + .errors; + if (!isErrorSuppressing(location, instance)) + reportErrors(std::move(errors)); + return instance; + } + TypePackId lookupPack(AstExpr* expr) { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. @@ -134,11 +473,11 @@ struct TypeChecker2 // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); if (ty) - return follow(*ty); + return checkForTypeFunctionInhabitance(follow(*ty), expr->location); TypePackId* tp = module->astTypePacks.find(expr); if (tp) - return flattenPack(*tp); + return checkForTypeFunctionInhabitance(flattenPack(*tp), expr->location); return builtinTypes->anyType; } @@ -153,7 +492,8 @@ struct TypeChecker2 { TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); luauPrintLine(format( - "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); + "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str() + )); return follow(argTy); } } @@ -161,14 +501,15 @@ struct TypeChecker2 TypeId* ty = module->astResolvedTypes.find(annotation); LUAU_ASSERT(ty); - return follow(*ty); + return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); } - TypePackId lookupPackAnnotation(AstTypePack* annotation) + std::optional lookupPackAnnotation(AstTypePack* annotation) { TypePackId* tp = module->astResolvedTypePacks.find(annotation); - LUAU_ASSERT(tp); - return follow(*tp); + if (tp != nullptr) + return {follow(*tp)}; + return {}; } TypeId lookupExpectedType(AstExpr* expr) @@ -206,20 +547,21 @@ struct TypeChecker2 Scope* findInnermostScope(Location location) { Scope* bestScope = module->getModuleScope().get(); - Location bestLocation = module->scopes[0].first; - for (size_t i = 0; i < module->scopes.size(); ++i) + bool didNarrow; + do { - auto& [scopeBounds, scope] = module->scopes[i]; - if (scopeBounds.encloses(location)) + didNarrow = false; + for (auto scope : bestScope->children) { - if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + if (scope->location.encloses(location)) { bestScope = scope.get(); - bestLocation = scopeBounds; + didNarrow = true; + break; } } - } + } while (didNarrow && bestScope->children.size() > 0); return bestScope; } @@ -260,6 +602,8 @@ struct TypeChecker2 return visit(s); else if (auto s = stat->as()) return visit(s); + else if (auto f = stat->as()) + return visit(f); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) @@ -282,7 +626,11 @@ struct TypeChecker2 void visit(AstStatIf* ifStatement) { - visit(ifStatement->condition, ValueContext::RValue); + { + InConditionalContext flipper{&typeContext}; + visit(ifStatement->condition, ValueContext::RValue); + } + visit(ifStatement->thenbody); if (ifStatement->elsebody) visit(ifStatement->elsebody); @@ -309,19 +657,10 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ret->location); TypePackId expectedRetType = scope->returnType; - TypeArena* arena = &testArena; + TypeArena* arena = &module->internalTypes; TypePackId actualRetType = reconstructPack(ret->list, *arena); - Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; - - u.tryUnify(actualRetType, expectedRetType); - const bool ok = u.errors.empty() && u.log.empty(); - - if (!ok) - { - for (const TypeError& e : u.errors) - reportError(e); - } + testIsSubtype(actualRetType, expectedRetType, ret->location); for (AstExpr* expr : ret->list) visit(expr, ValueContext::RValue); @@ -352,11 +691,7 @@ struct TypeChecker2 TypeId annotationType = lookupAnnotation(var->annotation); TypeId valueType = value ? lookupType(value) : nullptr; if (valueType) - { - ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); - if (!errors.empty()) - reportErrors(std::move(errors)); - } + testIsSubtype(valueType, annotationType, value->location); visit(var->annotation); } @@ -381,9 +716,7 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); - if (!errors.empty()) - reportErrors(std::move(errors)); + testIsSubtype(valueTypes.head[j - i], varType, value->location); visit(var->annotation); } @@ -402,7 +735,8 @@ struct TypeChecker2 local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult : CountMismatch::ExprListResult, }, - errorLocation); + errorLocation + ); } } } @@ -410,20 +744,21 @@ struct TypeChecker2 void visit(AstStatFor* forStatement) { - NotNull scope = stack.back(); - if (forStatement->var->annotation) { visit(forStatement->var->annotation); - reportErrors(tryUnify(scope, forStatement->var->location, builtinTypes->numberType, lookupAnnotation(forStatement->var->annotation))); + + TypeId annotatedType = lookupAnnotation(forStatement->var->annotation); + testIsSubtype(builtinTypes->numberType, annotatedType, forStatement->var->location); } - auto checkNumber = [this, scope](AstExpr* expr) { + auto checkNumber = [this](AstExpr* expr) + { if (!expr) return; visit(expr, ValueContext::RValue); - reportErrors(tryUnify(scope, expr->location, lookupType(expr), builtinTypes->numberType)); + testIsSubtype(lookupType(expr), builtinTypes->numberType, expr->location); }; checkNumber(forStatement->from); @@ -451,7 +786,7 @@ struct TypeChecker2 return; NotNull scope = stack.back(); - TypeArena& arena = testArena; + TypeArena& arena = module->internalTypes; std::vector variableTypes; for (AstLocal* var : forInStatement->vars) @@ -461,12 +796,47 @@ struct TypeChecker2 variableTypes.emplace_back(*ty); } - // ugh. There's nothing in the AST to hang a whole type pack on for the - // set of iteratees, so we have to piece it back together by hand. + AstExpr* firstValue = forInStatement->values.data[0]; + + // we need to build up a typepack for the iterators/values portion of the for-in statement. std::vector valueTypes; - for (size_t i = 0; i < forInStatement->values.size - 1; ++i) + std::optional iteratorTail; + + // since the first value may be the only iterator (e.g. if it is a call), we want to + // look to see if it has a resulting typepack as our iterators. + TypePackId* retPack = module->astTypePacks.find(firstValue); + if (retPack) + { + auto [head, tail] = flatten(*retPack); + valueTypes = head; + iteratorTail = tail; + } + else + { + valueTypes.emplace_back(lookupType(firstValue)); + } + + // if the initial and expected types from the iterator unified during constraint solving, + // we'll have a resolved type to use here, but we'll only use it if either the iterator is + // directly present in the for-in statement or if we have an iterator state constraining us + TypeId* resolvedTy = module->astForInNextTypes.find(firstValue); + if (resolvedTy && (!retPack || valueTypes.size() > 1)) + valueTypes[0] = *resolvedTy; + + for (size_t i = 1; i < forInStatement->values.size - 1; ++i) + { valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); - TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); + } + + // if we had more than one value, the tail from the first value is no longer appropriate to use. + if (forInStatement->values.size > 1) + { + auto [head, tail] = flatten(lookupPack(forInStatement->values.data[forInStatement->values.size - 1])); + valueTypes.insert(valueTypes.end(), head.begin(), head.end()); + iteratorTail = tail; + } + + // and now we can put everything together to get the actual typepack of the iterators. TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) @@ -478,8 +848,8 @@ struct TypeChecker2 } TypeId iteratorTy = follow(iteratorTypes.head[0]); - auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( - const FunctionType* iterFtv, std::vector iterTys, bool isMm) { + auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector iterTys, bool isMm) + { if (iterTys.size() < 1 || iterTys.size() > 3) { if (isMm) @@ -496,13 +866,14 @@ struct TypeChecker2 { if (isMm) reportError( - GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); + GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values) + ); else reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); } for (size_t i = 0; i < std::min(expectedVariableTypes.head.size(), variableTypes.size()); ++i) - reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes.head[i])); + testIsSubtype(variableTypes[i], expectedVariableTypes.head[i], forInStatement->vars.data[i]->location); // nextFn is going to be invoked with (arrayTy, startIndexTy) @@ -513,38 +884,54 @@ struct TypeChecker2 // This depends on the types in iterateePack and therefore // iteratorTypes. + // If the iteratee is an error type, then we can't really say anything else about iteration over it. + // After all, it _could've_ been a table. + if (get(follow(flattenPack(iterFtv->argTypes)))) + return; + // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); - if (minCount > 2) - reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.head.size(); - if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); + } + else if (actualArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->values.data[0]->location); + } + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) { size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes.head[0])); + testIsSubtype(iterTys[1], flattenedArgTypes.head[0], forInStatement->values.data[valueIndex]->location); } if (iterTys.size() == 3 && flattenedArgTypes.head.size() > 1) { size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; - reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes.head[1])); + testIsSubtype(iterTys[2], flattenedArgTypes.head[1], forInStatement->values.data[valueIndex]->location); } }; + std::shared_ptr iteratorNorm = normalizer.normalize(iteratorTy); + + if (!iteratorNorm) + reportError(NormalizationTooComplex{}, firstValue->location); + /* * If the first iterator argument is a function * * There must be 1 to 3 iterator arguments. Name them (nextTy, @@ -568,9 +955,9 @@ struct TypeChecker2 { if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) { - reportErrors(tryUnify(scope, forInStatement->vars.data[0]->location, variableTypes[0], ttv->indexer->indexType)); + testIsSubtype(variableTypes[0], ttv->indexer->indexType, forInStatement->vars.data[0]->location); if (variableTypes.size() == 2) - reportErrors(tryUnify(scope, forInStatement->vars.data[1]->location, variableTypes[1], ttv->indexer->indexResultType)); + testIsSubtype(variableTypes[1], ttv->indexer->indexResultType, forInStatement->vars.data[1]->location); } else reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); @@ -579,21 +966,21 @@ struct TypeChecker2 { // nothing } - else if (isOptional(iteratorTy)) + else if (isOptional(iteratorTy) && !(iteratorNorm && iteratorNorm->shouldSuppressErrors())) { reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); } else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { - Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; + Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; - if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) + if (std::optional instantiatedIterMmTy = instantiate(builtinTypes, NotNull{&arena}, limits, scope, *iterMmTy)) { if (const FunctionType* iterMmFtv = get(*instantiatedIterMmTy)) { TypePackId argPack = arena.addTypePack({iteratorTy}); - reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); + testIsSubtype(argPack, iterMmFtv->argTypes, forInStatement->values.data[0]->location); TypePack mmIteratorTypes = extendTypePack(arena, builtinTypes, iterMmFtv->retTypes, 3); @@ -614,7 +1001,7 @@ struct TypeChecker2 { checkFunction(nextFtv, instantiatedIteratorTypes, true); } - else + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *instantiatedNextFn)) { reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); } @@ -624,7 +1011,7 @@ struct TypeChecker2 reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); } } - else + else if (!isErrorSuppressing(forInStatement->values.data[0]->location, *iterMmTy)) { // TODO: This will not tell the user that this is because the // metamethod isn't callable. This is not ideal, and we should @@ -640,12 +1027,63 @@ struct TypeChecker2 reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); } } - else + else if (iteratorNorm && iteratorNorm->hasTables()) + { + // Ok. All tables can be iterated. + } + else if (!iteratorNorm || !iteratorNorm->shouldSuppressErrors()) { reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); } } + std::optional getBindingType(AstExpr* expr) + { + if (auto localExpr = expr->as()) + { + Scope* s = stack.back(); + return s->lookup(localExpr->local); + } + else if (auto globalExpr = expr->as()) + { + Scope* s = stack.back(); + return s->lookup(globalExpr->name); + } + else + return std::nullopt; + } + + // this should only be called if the type of `lhs` is `never`. + void reportErrorsFromAssigningToNever(AstExpr* lhs, TypeId rhsType) + { + + if (auto indexName = lhs->as()) + { + TypeId indexedType = lookupType(indexName->expr); + + // if it's already never, I don't think we have anything to do here. + if (get(indexedType)) + return; + + std::string prop = indexName->index.value; + + std::shared_ptr norm = normalizer.normalize(indexedType); + if (!norm) + { + reportError(NormalizationTooComplex{}, lhs->location); + return; + } + + // if the type is error suppressing, we don't actually have any work left to do. + if (norm->shouldSuppressErrors()) + return; + + const auto propTypes = lookupProp(norm.get(), prop, ValueContext::LValue, lhs->location, builtinTypes->stringType, module->errors); + + reportError(CannotAssignToNever{rhsType, propTypes.typesOfProp, CannotAssignToNever::Reason::PropertyNarrowed}, lhs->location); + } + } + void visit(AstStatAssign* assign) { size_t count = std::min(assign->vars.size, assign->values.size); @@ -661,11 +1099,19 @@ struct TypeChecker2 TypeId rhsType = lookupType(rhs); if (get(lhsType)) + { + reportErrorsFromAssigningToNever(lhs, rhsType); continue; + } + + bool ok = testIsSubtype(rhsType, lhsType, rhs->location); - if (!isSubtype(rhsType, lhsType, stack.back())) + // If rhsType location); + std::optional bindingType = getBindingType(lhs); + if (bindingType) + testIsSubtype(rhsType, *bindingType, rhs->location); } } } @@ -673,10 +1119,13 @@ struct TypeChecker2 void visit(AstStatCompoundAssign* stat) { AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; - TypeId resultTy = visit(&fake, stat); + visit(&fake, stat); + + TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat); + LUAU_ASSERT(resultTy); TypeId varTy = lookupType(stat->var); - reportErrors(tryUnify(stack.back(), stat->location, resultTy, varTy)); + testIsSubtype(*resultTy, varTy, stat->location); } void visit(AstStatFunction* stat) @@ -705,6 +1154,13 @@ struct TypeChecker2 visit(stat->type); } + void visit(AstStatTypeFunction* stat) + { + // TODO: add type checking for user-defined type functions + + reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); + } + void visit(AstTypeList types) { for (AstType* type : types.types) @@ -796,34 +1252,48 @@ struct TypeChecker2 void visit(AstExprConstantNil* expr) { - NotNull scope = stack.back(); +#if defined(LUAU_ENABLE_ASSERT) TypeId actualType = lookupType(expr); TypeId expectedType = builtinTypes->nilType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + + SubtypingResult r = subtyping->isSubtype(actualType, expectedType); + LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType)); +#endif } void visit(AstExprConstantBool* expr) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->booleanType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + // booleans use specialized inference logic for singleton types, which can lead to real type errors here. + + const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType; + const TypeId inferredType = lookupType(expr); + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprConstantNumber* expr) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->numberType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); +#if defined(LUAU_ENABLE_ASSERT) + const TypeId bestType = builtinTypes->numberType; + const TypeId inferredType = lookupType(expr); + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); + LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); +#endif } void visit(AstExprConstantString* expr) { - NotNull scope = stack.back(); - TypeId actualType = lookupType(expr); - TypeId expectedType = builtinTypes->stringType; - LUAU_ASSERT(isSubtype(actualType, expectedType, scope)); + // strings use specialized inference logic for singleton types, which can lead to real type errors here. + + const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}}); + const TypeId inferredType = lookupType(expr); + + const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprLocal* expr) @@ -833,7 +1303,9 @@ struct TypeChecker2 void visit(AstExprGlobal* expr) { - // TODO! + NotNull scope = stack.back(); + if (!scope->lookup(expr->name)) + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); } void visit(AstExprVarargs* expr) @@ -841,189 +1313,50 @@ struct TypeChecker2 // TODO! } - ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, - TypePackId expectedArgTypes, TypePackId expectedRetType) + // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. + void visitCall(AstExprCall* call) { - ErrorVec overloadErrors = - tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); - - size_t argIndex = 0; - auto inferredArgIt = begin(overloadFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) - { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); - for (TypeError e : argErrors) - overloadErrors.emplace_back(e); - - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; - } - - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); - for (TypeError e : argumentErrors) - if (get(e) != nullptr) - overloadErrors.emplace_back(std::move(e)); - - return overloadErrors; - } + TypePack args; + std::vector argExprs; + argExprs.reserve(call->args.size + 1); - void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) - { - if (overloads.size() == 1) - { - reportErrors(std::get<0>(overloadsErrors.front())); + TypeId* originalCallTy = module->astOriginalCallTypes.find(call); + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(call); + if (!originalCallTy) return; - } - - std::vector overloadTypes = overloadsThatMatchArgCount; - if (overloadsThatMatchArgCount.size() == 0) - { - reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); - // If no overloads match argument count, just list all overloads. - overloadTypes = overloads; - } - else - { - // Report errors of the first argument-count-matching, but failing overload - TypeId overload = overloadsThatMatchArgCount[0]; - - // Remove the overload we are reporting errors about from the list of alternatives - overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); - - const FunctionType* ftv = get(overload); - LUAU_ASSERT(ftv); // overload must be a function type here - - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair& e) { - return overload == e.second; - }); - - LUAU_ASSERT(error != overloadsErrors.end()); - reportErrors(std::get<0>(*error)); - - // If only one overload matched, we don't need this error because we provided the previous errors. - if (overloadsThatMatchArgCount.size() == 1) - return; - } - std::string s; - for (size_t i = 0; i < overloadTypes.size(); ++i) - { - TypeId overload = follow(overloadTypes[i]); - - if (i > 0) - s += "; "; - - if (i > 0 && i == overloadTypes.size() - 1) - s += "and "; - - s += toString(overload); - } - - if (overloadsThatMatchArgCount.size() == 0) - reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); - else - reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); - } + TypeId fnTy = follow(*originalCallTy); - // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. - void visitCall(AstExprCall* call) - { - TypeArena* arena = &testArena; - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; - TypePackId expectedRetType = lookupExpectedPack(call, *arena); - TypeId functionType = lookupType(call->func); - TypeId testFunctionType = functionType; - TypePack args; - std::vector argLocs; - argLocs.reserve(call->args.size + 1); - - if (get(functionType) || get(functionType) || get(functionType)) + if (get(fnTy) || get(fnTy) || get(fnTy)) return; - else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location)) - { - if (get(follow(*callMm))) - { - if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) - { - args.head.push_back(functionType); - argLocs.push_back(call->func->location); - testFunctionType = follow(*instantiatedCallMm); - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } - } - else - { - // TODO: This doesn't flag the __call metamethod as the problem - // very clearly. - reportError(CannotCallNonFunction{*callMm}, call->func->location); - return; - } - } - else if (get(functionType)) + else if (isOptional(fnTy)) { - if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) - { - testFunctionType = *instantiatedFunctionType; - } - else + switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy)) { - reportError(UnificationTooComplex{}, call->func->location); - return; + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, call->func->location); + // fallthrough intentional + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{fnTy}, call->func->location); } + return; } - else if (auto itv = get(functionType)) - { - // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. - } - else if (auto utv = get(functionType)) - { - // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. - // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error - if (isOptional(functionType)) - { - reportError(OptionalValueAccess{functionType}, call->location); - return; - } - std::optional fst; - for (TypeId ty : utv) - { - if (!fst) - fst = follow(ty); - else if (fst != follow(ty)) - { - reportError(CannotCallNonFunction{functionType}, call->func->location); - return; - } - } - if (!fst) - ice->ice("UnionType had no elements, so fst is nullopt?"); + if (selectedOverloadTy) + { + SubtypingResult result = subtyping->isSubtype(*originalCallTy, *selectedOverloadTy); + if (result.isSubtype) + fnTy = follow(*selectedOverloadTy); - if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) - { - testFunctionType = *instantiatedFunctionType; - } - else + if (result.normalizationTooComplex) { - reportError(UnificationTooComplex{}, call->func->location); + reportError(NormalizationTooComplex{}, call->func->location); return; } } - else - { - reportError(CannotCallNonFunction{functionType}, call->func->location); - return; - } if (call->self) { @@ -1032,21 +1365,24 @@ struct TypeChecker2 ice->ice("method call expression has no 'self'"); args.head.push_back(lookupType(indexExpr->expr)); - argLocs.push_back(indexExpr->expr->location); + argExprs.push_back(indexExpr->expr); } for (size_t i = 0; i < call->args.size; ++i) { AstExpr* arg = call->args.data[i]; - argLocs.push_back(arg->location); + argExprs.push_back(arg); TypeId* argTy = module->astTypes.find(arg); if (argTy) args.head.push_back(*argTy); else if (i == call->args.size - 1) { - TypePackId* argTail = module->astTypePacks.find(arg); - if (argTail) - args.tail = *argTail; + if (auto argTail = module->astTypePacks.find(arg)) + { + auto [head, tail] = flatten(*argTail); + args.head.insert(args.head.end(), head.begin(), head.end()); + args.tail = tail; + } else args.tail = builtinTypes->anyTypePack; } @@ -1054,63 +1390,92 @@ struct TypeChecker2 args.head.push_back(builtinTypes->anyType); } - TypePackId expectedArgTypes = arena->addTypePack(args); - std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; - overloadsErrors.reserve(overloads.size()); - std::vector overloadsThatMatchArgCount; - - for (TypeId overload : overloads) - { - overload = follow(overload); + OverloadResolver resolver{ + builtinTypes, + NotNull{&module->internalTypes}, + NotNull{&normalizer}, + NotNull{stack.back()}, + ice, + limits, + call->location, + }; + resolver.resolve(fnTy, &args, call->func, &argExprs); - const FunctionType* overloadFn = get(overload); - if (!overloadFn) + auto norm = normalizer.normalize(fnTy); + if (!norm) + reportError(NormalizationTooComplex{}, call->func->location); + auto isInhabited = normalizer.isInhabited(norm.get()); + if (isInhabited == NormalizationResult::HitLimits) + reportError(NormalizationTooComplex{}, call->func->location); + + if (norm && norm->shouldSuppressErrors()) + return; // error suppressing function type! + else if (!resolver.ok.empty()) + return; // We found a call that works, so this is ok. + else if (!norm || isInhabited == NormalizationResult::False) + return; // Ok. Calling an uninhabited type is no-op. + else if (!resolver.nonviableOverloads.empty()) + { + if (resolver.nonviableOverloads.size() == 1 && !isErrorSuppressing(call->func->location, resolver.nonviableOverloads.front().first)) + reportErrors(resolver.nonviableOverloads.front().second); + else { - reportError(CannotCallNonFunction{overload}, call->func->location); - return; + std::string s = "None of the overloads for function that accept "; + s += std::to_string(args.head.size()); + s += " arguments are compatible."; + reportError(GenericError{std::move(s)}, call->location); } + } + else if (!resolver.arityMismatches.empty()) + { + if (resolver.arityMismatches.size() == 1) + reportErrors(resolver.arityMismatches.front().second); else { - // We may have to instantiate the overload in order for it to typecheck. - if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) - { - overloadFn = get(*instantiatedFunctionType); - } - else - { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); - return; - } + std::string s = "No overload for function accepts "; + s += std::to_string(args.head.size()); + s += " arguments."; + reportError(GenericError{std::move(s)}, call->location); } + } + else if (!resolver.nonFunctions.empty()) + reportError(CannotCallNonFunction{fnTy}, call->func->location); + else + LUAU_ASSERT(!"Generating the best possible error from this function call resolution was inexhaustive?"); - ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); - if (overloadErrors.empty()) - return; + if (resolver.arityMismatches.size() > 1 || resolver.nonviableOverloads.size() > 1) + { + std::string s = "Available overloads: "; - bool argMismatch = false; - for (auto error : overloadErrors) + std::vector overloads; + if (resolver.nonviableOverloads.empty()) { - CountMismatch* cm = get(error); - if (!cm) - continue; - - if (cm->context == CountMismatch::Arg) + for (const auto& [ty, p] : resolver.resolution) { - argMismatch = true; - break; + if (p.first == OverloadResolver::TypeIsNotAFunction) + continue; + + overloads.push_back(ty); } } + else + { + for (const auto& [ty, _] : resolver.nonviableOverloads) + overloads.push_back(ty); + } - if (!argMismatch) - overloadsThatMatchArgCount.push_back(overload); + for (size_t i = 0; i < overloads.size(); ++i) + { + if (i > 0) + s += (i == overloads.size() - 1) ? "; and " : "; "; - overloadsErrors.emplace_back(std::move(overloadErrors), overload); - } + s += toString(overloads[i]); + } - reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + reportError(ExtraInformation{std::move(s)}, call->func->location); + } } void visit(AstExprCall* call) @@ -1159,91 +1524,248 @@ struct TypeChecker2 if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { - reportError(OptionalValueAccess{ty}, location); + switch (shouldSuppressErrors(NotNull{&normalizer}, ty)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + // fallthrough intentional + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{ty}, location); + } + return follow(*strippedUnion); } return ty; } - void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy) { visit(expr, ValueContext::RValue); - TypeId leftType = stripFromNilAndReport(lookupType(expr), location); - checkIndexTypeFromType(leftType, propName, location, context); + checkIndexTypeFromType(leftType, propName, context, location, astIndexExprTy); } void visit(AstExprIndexName* indexName, ValueContext context) { - visitExprName(indexName->expr, indexName->location, indexName->index.value, context); + // If we're indexing like _.foo - foo could either be a prop or a string. + visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); + } + + void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType) + { + if (auto tt = get(follow(metaTable->table)); tt && tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else if (auto mt = get(follow(metaTable->table))) + indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + else if (auto tmt = get(follow(metaTable->metatable)); tmt && tmt->indexer) + testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location); + else if (auto mtmt = get(follow(metaTable->metatable))) + indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); + else + { + LUAU_ASSERT(tt || get(follow(metaTable->table))); + + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { if (auto str = indexExpr->index->as()) { + TypeId astIndexExprType = lookupType(indexExpr->index); const std::string stringValue(str->value.data, str->value.size); - visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context, astIndexExprType); return; } - // TODO! - visit(indexExpr->expr, ValueContext::LValue); + visit(indexExpr->expr, ValueContext::RValue); visit(indexExpr->index, ValueContext::RValue); - NotNull scope = stack.back(); + TypeId exprType = follow(lookupType(indexExpr->expr)); + TypeId indexType = follow(lookupType(indexExpr->index)); + + if (auto tt = get(exprType)) + { + if (tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } + else if (auto mt = get(exprType)) + { + return indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + } + else if (auto cls = get(exprType)) + { + if (cls->indexer) + testIsSubtype(indexType, cls->indexer->indexType, indexExpr->index->location); + else + reportError(DynamicPropertyLookupOnClassesUnsafe{exprType}, indexExpr->location); + } + else if (get(exprType) && isOptional(exprType)) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, exprType)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, indexExpr->location); + // fallthrough intentional + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{exprType}, indexExpr->location); + } + } + else if (auto exprIntersection = get(exprType)) + { + for (TypeId part : exprIntersection) + { + (void)part; + } + } + else if (get(exprType) || isErrorSuppressing(indexExpr->location, exprType)) + { + // Nothing + } + else + reportError(NotATable{exprType}, indexExpr->location); + } + + void visit(AstExprFunction* fn) + { + auto StackPusher = pushStack(fn); + + visitGenerics(fn->generics, fn->genericPacks); + + TypeId inferredFnTy = lookupType(fn); + functionDeclStack.push_back(inferredFnTy); + + std::shared_ptr normalizedFnTy = normalizer.normalize(inferredFnTy); + if (!normalizedFnTy) + { + reportError(CodeTooComplex{}, fn->location); + } + else if (get(normalizedFnTy->errors)) + { + // Nothing + } + else if (!normalizedFnTy->hasFunctions()) + { + ice->ice("Internal error: Lambda has non-function type " + toString(inferredFnTy), fn->location); + } + else + { + if (1 != normalizedFnTy->functions.parts.size()) + ice->ice("Unexpected: Lambda has unexpected type " + toString(inferredFnTy), fn->location); + + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + // There is no way to write an annotation for the self argument, so we + // cannot do anything to check it. + auto argIt = begin(inferredFtv->argTypes); + if (fn->self) + ++argIt; + + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + TypeId inferredArgTy = *argIt; + + if (arg->annotation) + { + // we need to typecheck any argument annotations themselves. + visit(arg->annotation); + + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + testIsSubtype(inferredArgTy, annotatedArgTy, arg->location); + } + + // Some Luau constructs can result in an argument type being + // reduced to never by inference. In this case, we want to + // report an error at the function, instead of reporting an + // error at every callsite. + if (is(follow(inferredArgTy))) + { + // If the annotation simplified to never, we don't want to + // even look at contributors. + bool explicitlyNever = false; + if (arg->annotation) + { + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + explicitlyNever = is(annotatedArgTy); + } + + // Not following here is deliberate: the contribution map is + // keyed by type pointer, but that type pointer has, at some + // point, been transmuted to a bound type pointing to never. + if (const auto contributors = module->upperBoundContributors.find(inferredArgTy); contributors && !explicitlyNever) + { + // It's unfortunate that we can't link error messages + // together. For now, this will work. + reportError( + GenericError{format( + "Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value + )}, + arg->location + ); + for (const auto& [site, component] : *contributors) + reportError( + ExtraInformation{ + format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value, toString(component).c_str()) + }, + site + ); + } + } + + ++argIt; + } - TypeId exprType = lookupType(indexExpr->expr); - TypeId indexType = lookupType(indexExpr->index); + // we need to typecheck the vararg annotation, if it exists. + if (fn->vararg && fn->varargAnnotation) + visit(fn->varargAnnotation); - if (auto tt = get(exprType)) - { - if (tt->indexer) - reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType)); - else - reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + bool reachesImplicitReturn = getFallthrough(fn->body) != nullptr; + if (reachesImplicitReturn && !allowsNoReturnValues(follow(inferredFtv->retTypes))) + reportError(FunctionExitsWithoutReturning{inferredFtv->retTypes}, getEndLocation(fn)); } - else if (get(exprType) && isOptional(exprType)) - reportError(OptionalValueAccess{exprType}, indexExpr->location); - } - - void visit(AstExprFunction* fn) - { - auto StackPusher = pushStack(fn); - visitGenerics(fn->generics, fn->genericPacks); + visit(fn->body); - TypeId inferredFnTy = lookupType(fn); - const FunctionType* inferredFtv = get(inferredFnTy); - LUAU_ASSERT(inferredFtv); + // we need to typecheck the return annotation itself, if it exists. + if (fn->returnAnnotation) + visit(*fn->returnAnnotation); - // There is no way to write an annotation for the self argument, so we - // cannot do anything to check it. - auto argIt = begin(inferredFtv->argTypes); - if (fn->self) - ++argIt; - for (const auto& arg : fn->args) + // If the function type has a function annotation, we need to see if we can suggest an annotation + if (normalizedFnTy) { - if (argIt == end(inferredFtv->argTypes)) - break; + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); - if (arg->annotation) + TypeFunctionReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; + for (TypeId retTy : inferredFtv->retTypes) { - TypeId inferredArgTy = *argIt; - TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - - if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + if (get(follow(retTy))) { - reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); + TypeFunctionReductionGuessResult result = guesser.guessTypeFunctionReductionForFunctionExpr(*fn, inferredFtv, retTy); + if (result.shouldRecommendAnnotation) + reportError( + ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, + fn->location + ); } } - - ++argIt; } - visit(fn->body); + functionDeclStack.pop_back(); } void visit(AstExprTable* expr) @@ -1261,11 +1783,10 @@ struct TypeChecker2 { visit(expr->expr, ValueContext::RValue); - NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); TypeId resultType = lookupType(expr); - if (get(operandType) || get(operandType) || get(operandType)) + if (isErrorSuppressing(expr->expr->location, operandType)) return; if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) @@ -1279,7 +1800,7 @@ struct TypeChecker2 { if (expr->op == AstExprUnary::Op::Len) { - reportErrors(tryUnify(scope, expr->location, follow(*ret), builtinTypes->numberType)); + testIsSubtype(follow(*ret), builtinTypes->numberType, expr->location); } } else @@ -1294,17 +1815,14 @@ struct TypeChecker2 return; } - TypePackId expectedArgs = testArena.addTypePack({operandType}); - TypePackId expectedRet = testArena.addTypePack({resultType}); + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + TypePackId expectedRet = module->internalTypes.addTypePack({resultType}); - TypeId expectedFunction = testArena.addType(FunctionType{expectedArgs, expectedRet}); + TypeId expectedFunction = module->internalTypes.addType(FunctionType{expectedArgs, expectedRet}); - ErrorVec errors = tryUnify(scope, expr->location, *mm, expectedFunction); - if (!errors.empty()) - { - reportError(TypeMismatch{*firstArg, operandType}, expr->location); + bool success = testIsSubtype(*mm, expectedFunction, expr->location); + if (!success) return; - } } return; @@ -1315,7 +1833,10 @@ struct TypeChecker2 { DenseHashSet seen{nullptr}; int recursionCount = 0; + std::shared_ptr nty = normalizer.normalize(operandType); + if (nty && nty->shouldSuppressErrors()) + return; if (!hasLength(operandType, seen, &recursionCount)) { @@ -1327,7 +1848,7 @@ struct TypeChecker2 } else if (expr->op == AstExprUnary::Op::Minus) { - reportErrors(tryUnify(scope, expr->location, operandType, builtinTypes->numberType)); + testIsSubtype(operandType, builtinTypes->numberType, expr->location); } else if (expr->op == AstExprUnary::Op::Not) { @@ -1340,8 +1861,8 @@ struct TypeChecker2 TypeId visit(AstExprBinary* expr, AstNode* overrideKey = nullptr) { - visit(expr->left, ValueContext::LValue); - visit(expr->right, ValueContext::LValue); + visit(expr->left, ValueContext::RValue); + visit(expr->right, ValueContext::RValue); NotNull scope = stack.back(); @@ -1349,31 +1870,47 @@ struct TypeChecker2 bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; - TypeId leftType = lookupType(expr->left); - TypeId rightType = lookupType(expr->right); + TypeId leftType = follow(lookupType(expr->left)); + TypeId rightType = follow(lookupType(expr->right)); + TypeId expectedResult = follow(lookupType(expr)); + + if (get(expectedResult)) + { + checkForInternalTypeFunction(expectedResult, expr->location); + return expectedResult; + } if (expr->op == AstExprBinary::Op::Or) { - leftType = stripNil(builtinTypes, testArena, leftType); + leftType = stripNil(builtinTypes, module->internalTypes, leftType); } - bool isStringOperation = isString(leftType) && isString(rightType); + std::shared_ptr normLeft = normalizer.normalize(leftType); + std::shared_ptr normRight = normalizer.normalize(rightType); - if (get(leftType) || get(leftType)) + bool isStringOperation = + (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); + leftType = follow(leftType); + if (get(leftType) || get(leftType) || get(leftType)) return leftType; - else if (get(rightType) || get(rightType)) + else if (get(rightType) || get(rightType) || get(rightType)) return rightType; + else if ((normLeft && normLeft->shouldSuppressErrors()) || (normRight && normRight->shouldSuppressErrors())) + return builtinTypes->anyType; // we can't say anything better if it's error suppressing but not any or error alone. - if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + if ((get(leftType) || get(leftType) || get(leftType)) && !isEquality && !isLogical) { auto name = getIdentifierOfBaseVar(expr->left); - reportError(CannotInferBinaryOperation{expr->op, name, - isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, - expr->location); + reportError( + CannotInferBinaryOperation{ + expr->op, name, isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation + }, + expr->location + ); return leftType; } - bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); + NormalizationResult typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { std::optional leftMt = getMetatable(leftType, builtinTypes); @@ -1383,7 +1920,8 @@ struct TypeChecker2 if (isEquality && !matches) { - auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) { + auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) + { for (TypeId option : utv) { if (getMetatable(follow(option), builtinTypes) == otherMt) @@ -1403,20 +1941,27 @@ struct TypeChecker2 { testUnion(utv, leftMt); } + } - // If either left or right has no metatable (or both), we need to consider if - // there are values in common that could possibly inhabit the type (and thus equality could be considered) + // If we're working with things that are not tables, the metatable comparisons above are a little excessive + // It's ok for one type to have a meta table and the other to not. In that case, we should fall back on + // checking if the intersection of the types is inhabited. If `typesHaveIntersection` failed due to limits, + // TODO: Maybe add more checks here (e.g. for functions, classes, etc) + if (!(get(leftType) || get(rightType))) if (!leftMt.has_value() || !rightMt.has_value()) - { - matches = matches || typesHaveIntersection; - } - } + matches = matches || typesHaveIntersection != NormalizationResult::False; if (!matches && isComparison) { - reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); + reportError( + GenericError{format( + "Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); return builtinTypes->errorRecoveryType(); } @@ -1436,38 +1981,42 @@ struct TypeChecker2 if (overrideKey != nullptr) key = overrideKey; - TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; - if (!instantiatedMm) - reportError(CodeTooComplex{}, expr->location); + TypeId* selectedOverloadTy = module->astOverloadResolvedTypes.find(key); + if (!selectedOverloadTy) + { + // reportError(CodeTooComplex{}, expr->location); + // was handled by a type function + return expectedResult; + } - else if (const FunctionType* ftv = get(follow(instantiatedMm))) + else if (const FunctionType* ftv = get(follow(*selectedOverloadTy))) { TypePackId expectedArgs; // For >= and > we invoke __lt and __le respectively with // swapped argument ordering. if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) { - expectedArgs = testArena.addTypePack({rightType, leftType}); + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); } else { - expectedArgs = testArena.addTypePack({leftType, rightType}); + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); } TypePackId expectedRets; if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { - expectedRets = testArena.addTypePack({builtinTypes->booleanType}); + expectedRets = module->internalTypes.addTypePack({builtinTypes->booleanType}); } else { - expectedRets = testArena.addTypePack({testArena.freshType(scope, TypeLevel{})}); + expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); } - TypeId expectedTy = testArena.addType(FunctionType(expectedArgs, expectedRets)); + TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); - reportErrors(tryUnify(scope, expr->location, follow(*mm), expectedTy)); + testIsSubtype(follow(*mm), expectedTy, expr->location); std::optional ret = first(ftv->retTypes); if (ret) @@ -1515,17 +2064,29 @@ struct TypeChecker2 { if (isComparison) { - reportError(GenericError{format( - "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, - expr->location); + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str(), + it->second + )}, + expr->location + ); } else { - reportError(GenericError{format( - "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", - toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, - expr->location); + reportError( + GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), + toString(leftType).c_str(), + toString(rightType).c_str(), + it->second + )}, + expr->location + ); } return builtinTypes->errorRecoveryType(); @@ -1534,15 +2095,27 @@ struct TypeChecker2 { if (isComparison) { - reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", - toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); } else { - reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", - toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, - expr->location); + reportError( + GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), + toString(leftType).c_str(), + toString(rightType).c_str() + )}, + expr->location + ); } return builtinTypes->errorRecoveryType(); @@ -1556,38 +2129,54 @@ struct TypeChecker2 case AstExprBinary::Op::Sub: case AstExprBinary::Op::Mul: case AstExprBinary::Op::Div: + case AstExprBinary::Op::FloorDiv: case AstExprBinary::Op::Pow: case AstExprBinary::Op::Mod: - reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); + testIsSubtype(leftType, builtinTypes->numberType, expr->left->location); + testIsSubtype(rightType, builtinTypes->numberType, expr->right->location); return builtinTypes->numberType; case AstExprBinary::Op::Concat: - reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->stringType)); - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); + testIsSubtype(leftType, builtinTypes->stringType, expr->left->location); + testIsSubtype(rightType, builtinTypes->stringType, expr->right->location); return builtinTypes->stringType; case AstExprBinary::Op::CompareGe: case AstExprBinary::Op::CompareGt: case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLt: - if (isNumber(leftType)) - { - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); - return builtinTypes->numberType; - } - else if (isString(leftType)) + { + if (normLeft && normLeft->shouldSuppressErrors()) + return builtinTypes->booleanType; + + // if we're comparing against an uninhabited type, it's unobservable that the comparison did not run + if (normLeft && normalizer.isInhabited(normLeft.get()) == NormalizationResult::False) + return builtinTypes->booleanType; + + if (normLeft && normLeft->isExactlyNumber()) { - reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->stringType)); - return builtinTypes->stringType; + testIsSubtype(rightType, builtinTypes->numberType, expr->right->location); + return builtinTypes->booleanType; } - else + + if (normLeft && normLeft->isSubtypeOfString()) { - reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), - toString(rightType).c_str(), toString(expr->op).c_str())}, - expr->location); - return builtinTypes->errorRecoveryType(); + testIsSubtype(rightType, builtinTypes->stringType, expr->right->location); + return builtinTypes->booleanType; } + + reportError( + GenericError{format( + "Types '%s' and '%s' cannot be compared with relational operator %s", + toString(leftType).c_str(), + toString(rightType).c_str(), + toString(expr->op).c_str() + )}, + expr->location + ); + return builtinTypes->errorRecoveryType(); + } + case AstExprBinary::Op::And: case AstExprBinary::Op::Or: case AstExprBinary::Op::CompareEq: @@ -1610,14 +2199,39 @@ struct TypeChecker2 TypeId annotationType = lookupAnnotation(expr->annotation); TypeId computedType = lookupType(expr->expr); - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back())) + switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType))) + { + case ErrorSuppression::Suppress: return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, expr->location); + return; + case ErrorSuppression::DoNotSuppress: + break; + } - if (isSubtype(computedType, annotationType, stack.back())) + switch (normalizer.isInhabited(computedType)) + { + case NormalizationResult::True: + break; + case NormalizationResult::False: + return; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); return; + } - reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + switch (normalizer.isIntersectionInhabited(computedType, annotationType)) + { + case NormalizationResult::True: + return; + case NormalizationResult::False: + reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + break; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + break; + } } void visit(AstExprIfElse* expr) @@ -1653,12 +2267,12 @@ struct TypeChecker2 return *fst; else if (auto ftp = get(pack)) { - TypeId result = testArena.addType(FreeType{ftp->scope}); - TypePackId freeTail = testArena.addTypePack(FreeTypePack{ftp->scope}); + TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); - TypePack& resultPack = asMutable(pack)->ty.emplace(); - resultPack.head.assign(1, result); - resultPack.tail = freeTail; + TypePack* resultPack = emplaceTypePack(asMutable(pack)); + resultPack->head.assign(1, result); + resultPack->tail = freeTail; return result; } @@ -1699,6 +2313,10 @@ struct TypeChecker2 void visit(AstType* ty) { + TypeId* resolvedTy = module->astResolvedTypes.find(ty); + if (resolvedTy) + checkForTypeFunctionInhabitance(follow(*resolvedTy), ty->location); + if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) @@ -1717,7 +2335,7 @@ struct TypeChecker2 { // No further validation is necessary in this case. The main logic for // _luau_print is contained in lookupAnnotation. - if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print" && ty->parameters.size > 0) + if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print") return; for (const AstTypeOrPack& param : ty->parameters) @@ -1739,13 +2357,23 @@ struct TypeChecker2 size_t typesRequired = alias->typeParams.size(); size_t packsRequired = alias->typePackParams.size(); - bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + bool hasDefaultTypes = std::any_of( + alias->typeParams.begin(), + alias->typeParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); - bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + bool hasDefaultPacks = std::any_of( + alias->typePackParams.begin(), + alias->typePackParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); if (!ty->hasParameterList) { @@ -1766,6 +2394,7 @@ struct TypeChecker2 if (packsProvided != 0) { reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + continue; } if (typesProvided < typesRequired) @@ -1779,9 +2408,11 @@ struct TypeChecker2 } else if (p.typePack) { - TypePackId tp = lookupPackAnnotation(p.typePack); + std::optional tp = lookupPackAnnotation(p.typePack); + if (!tp.has_value()) + continue; - if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp)) + if (typesProvided < typesRequired && size(*tp) == 1 && finite(*tp) && first(*tp)) { typesProvided += 1; } @@ -1794,7 +2425,11 @@ struct TypeChecker2 if (extraTypes != 0 && packsProvided == 0) { - packsProvided += 1; + // Extra types are only collected into a pack if a pack is expected + if (packsRequired != 0) + packsProvided += 1; + else + typesProvided += extraTypes; } for (size_t i = typesProvided; i < typesRequired; ++i) @@ -1820,13 +2455,15 @@ struct TypeChecker2 if (typesProvided != typesRequired || packsProvided != packsRequired) { - reportError(IncorrectGenericParameterCount{ - /* name */ ty->name.value, - /* typeFun */ *alias, - /* actualParameters */ typesProvided, - /* actualPackParameters */ packsProvided, - }, - ty->location); + reportError( + IncorrectGenericParameterCount{ + /* name */ ty->name.value, + /* typeFun */ *alias, + /* actualParameters */ typesProvided, + /* actualPackParameters */ packsProvided, + }, + ty->location + ); } } else @@ -1838,7 +2475,8 @@ struct TypeChecker2 ty->name.value, SwappedGenericTypeParameter::Kind::Type, }, - ty->location); + ty->location + ); } else { @@ -1936,7 +2574,8 @@ struct TypeChecker2 tp->genericName.value, SwappedGenericTypeParameter::Kind::Pack, }, - tp->location); + tp->location + ); } else { @@ -1945,94 +2584,163 @@ struct TypeChecker2 } } - void reduceTypes() + struct Reasonings { - if (FFlag::DebugLuauDontReduceTypes) - return; + // the list of reasons + std::vector reasons; + + // this should be true if _all_ of the reasons have an error suppressing type, and false otherwise. + bool suppressed; - for (auto [_, scope] : module->scopes) + std::string toString() { - for (auto& [_, b] : scope->bindings) + // DenseHashSet ordering is entirely undefined, so we want to + // sort the reasons here to achieve a stable error + // stringification. + std::sort(reasons.begin(), reasons.end()); + std::string allReasons; + bool first = true; + for (const std::string& reason : reasons) { - if (auto reduced = module->reduction->reduce(b.typeId)) - b.typeId = *reduced; + if (first) + first = false; + else + allReasons += "\n\t"; + + allReasons += reason; } - if (auto reduced = module->reduction->reduce(scope->returnType)) - scope->returnType = *reduced; + return allReasons; + } + }; - if (scope->varargPack) - { - if (auto reduced = module->reduction->reduce(*scope->varargPack)) - scope->varargPack = *reduced; - } + template + Reasonings explainReasonings(TID subTy, TID superTy, Location location, const SubtypingResult& r) + { + if (r.reasoning.empty()) + return {}; - auto reduceMap = [this](auto& map) { - for (auto& [_, tf] : map) - { - if (auto reduced = module->reduction->reduce(tf)) - tf = *reduced; - } - }; + std::vector reasons; + bool suppressed = true; + for (const SubtypingReasoning& reasoning : r.reasoning) + { + if (reasoning.subPath.empty() && reasoning.superPath.empty()) + continue; - reduceMap(scope->exportedTypeBindings); - reduceMap(scope->privateTypeBindings); - reduceMap(scope->privateTypePackBindings); - for (auto& [_, space] : scope->importedTypeBindings) - reduceMap(space); - } + std::optional optSubLeaf = traverse(subTy, reasoning.subPath, builtinTypes); + std::optional optSuperLeaf = traverse(superTy, reasoning.superPath, builtinTypes); + + if (!optSubLeaf || !optSuperLeaf) + ice->ice("Subtyping test returned a reasoning with an invalid path", location); + + const TypeOrPack& subLeaf = *optSubLeaf; + const TypeOrPack& superLeaf = *optSuperLeaf; + + auto subLeafTy = get(subLeaf); + auto superLeafTy = get(superLeaf); + + auto subLeafTp = get(subLeaf); + auto superLeafTp = get(superLeaf); + + if (!subLeafTy && !superLeafTy && !subLeafTp && !superLeafTp) + ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location); + + std::string relation = "a subtype of"; + if (reasoning.variance == SubtypingVariance::Invariant) + relation = "exactly"; + else if (reasoning.variance == SubtypingVariance::Contravariant) + relation = "a supertype of"; - auto reduceOrError = [this](auto& map) { - for (auto [ast, t] : map) + std::string reason; + if (reasoning.subPath == reasoning.superPath) + reason = "at " + toString(reasoning.subPath) + ", " + toString(subLeaf) + " is not " + relation + " " + toString(superLeaf); + else + reason = "type " + toString(subTy) + toString(reasoning.subPath, /* prefixDot */ true) + " (" + toString(subLeaf) + ") is not " + + relation + " " + toString(superTy) + toString(reasoning.superPath, /* prefixDot */ true) + " (" + toString(superLeaf) + ")"; + + reasons.push_back(reason); + + // if we haven't already proved this isn't suppressing, we have to keep checking. + if (suppressed) { - if (!t) - continue; // Reminder: this implies that the recursion limit was exceeded. - else if (auto reduced = module->reduction->reduce(t)) - map[ast] = *reduced; + if (subLeafTy && superLeafTy) + suppressed &= isErrorSuppressing(location, *subLeafTy) || isErrorSuppressing(location, *superLeafTy); else - reportError(NormalizationTooComplex{}, ast->location); + suppressed &= isErrorSuppressing(location, *subLeafTp) || isErrorSuppressing(location, *superLeafTp); } - }; + } + + return {std::move(reasons), suppressed}; + } + - module->astOriginalResolvedTypes = module->astResolvedTypes; + void explainError(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& result) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, subTy).orElse(shouldSuppressErrors(NotNull{&normalizer}, superTy))) + { + case ErrorSuppression::Suppress: + return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + case ErrorSuppression::DoNotSuppress: + break; + } + + Reasonings reasonings = explainReasonings(subTy, superTy, location, result); - // Both [`Module::returnType`] and [`Module::exportedTypeBindings`] are empty here, and - // is populated by [`Module::clonePublicInterface`] in the future, so by that point these - // two aforementioned fields will only contain types that are irreducible. - reduceOrError(module->astTypes); - reduceOrError(module->astTypePacks); - reduceOrError(module->astExpectedTypes); - reduceOrError(module->astOriginalCallTypes); - reduceOrError(module->astOverloadResolvedTypes); - reduceOrError(module->astResolvedTypes); - reduceOrError(module->astResolvedTypePacks); + if (!reasonings.suppressed) + reportError(TypeMismatch{superTy, subTy, reasonings.toString()}, location); } - template - bool isSubtype(TID subTy, TID superTy, NotNull scope) + void explainError(TypePackId subTy, TypePackId superTy, Location location, const SubtypingResult& result) { - TypeArena arena; - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.useScopes = true; + switch (shouldSuppressErrors(NotNull{&normalizer}, subTy).orElse(shouldSuppressErrors(NotNull{&normalizer}, superTy))) + { + case ErrorSuppression::Suppress: + return; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, location); + case ErrorSuppression::DoNotSuppress: + break; + } + + Reasonings reasonings = explainReasonings(subTy, superTy, location, result); - u.tryUnify(subTy, superTy); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; + if (!reasonings.suppressed) + reportError(TypePackMismatch{superTy, subTy, reasonings.toString()}, location); } - template - ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy, CountMismatch::Context context = CountMismatch::Arg) + bool testIsSubtype(TypeId subTy, TypeId superTy, Location location) + { + SubtypingResult r = subtyping->isSubtype(subTy, superTy); + + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, location); + + if (!r.isSubtype) + explainError(subTy, superTy, location, r); + + return r.isSubtype; + } + + bool testIsSubtype(TypePackId subTy, TypePackId superTy, Location location) { - Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.ctx = context; - u.useScopes = true; - u.tryUnify(subTy, superTy); + SubtypingResult r = subtyping->isSubtype(subTy, superTy); - return std::move(u.errors); + if (r.normalizationTooComplex) + reportError(NormalizationTooComplex{}, location); + + if (!r.isSubtype) + explainError(subTy, superTy, location, r); + + return r.isSubtype; } void reportError(TypeErrorData data, const Location& location) { + if (auto utk = get_if(&data)) + diagnoseMissingTableKey(utk, data); + module->errors.emplace_back(location, module->name, std::move(data)); if (logger) @@ -2050,56 +2758,119 @@ struct TypeChecker2 reportError(std::move(e)); } - // If the provided type does not have the named property, report an error. - void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context) + struct PropertyTypes { - const NormalizedType* norm = normalizer.normalize(tableTy); - if (!norm) + // a vector of all the types assigned to the given property. + std::vector typesOfProp; + + // a vector of all the types that are missing the given property. + std::vector missingProp; + + bool foundOneProp() const { - reportError(NormalizationTooComplex{}, location); - return; + return !typesOfProp.empty(); + } + + bool noneMissingProp() const + { + return missingProp.empty(); + } + + bool foundMissingProp() const + { + return !missingProp.empty(); } + }; - bool foundOneProp = false; + /* A helper for checkIndexTypeFromType. + * + * Returns a pair: + * * A boolean indicating that at least one of the constituent types + * contains the prop, and + * * A vector of types that do not contain the prop. + */ + PropertyTypes lookupProp( + const NormalizedType* norm, + const std::string& prop, + ValueContext context, + const Location& location, + TypeId astIndexExprType, + std::vector& errors + ) + { + std::vector typesOfProp; std::vector typesMissingTheProp; - auto fetch = [&](TypeId ty) { - if (!normalizer.isInhabited(ty)) + // this is `false` if we ever hit the resource limits during any of our uses of `fetch`. + bool normValid = true; + + auto fetch = [&](TypeId ty) + { + NormalizationResult result = normalizer.isInhabited(ty); + if (result == NormalizationResult::HitLimits) + normValid = false; + if (result != NormalizationResult::True) + return; + + DenseHashSet seen{nullptr}; + PropertyType res = hasIndexTypeFromType(ty, prop, context, location, seen, astIndexExprType, errors); + + if (res.present == NormalizationResult::HitLimits) + { + normValid = false; return; + } - std::unordered_set seen; - bool found = hasIndexTypeFromType(ty, prop, location, seen); - foundOneProp |= found; - if (!found) + if (res.present == NormalizationResult::True && res.result) + typesOfProp.emplace_back(*res.result); + + if (res.present == NormalizationResult::False) typesMissingTheProp.push_back(ty); }; - fetch(norm->tops); - fetch(norm->booleans); + if (normValid) + fetch(norm->tops); + if (normValid) + fetch(norm->booleans); - if (FFlag::LuauNegatedClassTypes) + if (normValid) { for (const auto& [ty, _negations] : norm->classes.classes) { fetch(ty); + + if (!normValid) + break; } } - else + + if (normValid) + fetch(norm->errors); + if (normValid) + fetch(norm->nils); + if (normValid) + fetch(norm->numbers); + if (normValid && !norm->strings.isNever()) + fetch(builtinTypes->stringType); + if (normValid) + fetch(norm->threads); + if (normValid) + fetch(norm->buffers); + + if (normValid) { - for (TypeId ty : norm->DEPRECATED_classes) + for (TypeId ty : norm->tables) + { fetch(ty); + + if (!normValid) + break; + } } - fetch(norm->errors); - fetch(norm->nils); - fetch(norm->numbers); - if (!norm->strings.isNever()) - fetch(builtinTypes->stringType); - fetch(norm->threads); - for (TypeId ty : norm->tables) - fetch(ty); - if (norm->functions.isTop) + + if (normValid && norm->functions.isTop) fetch(builtinTypes->functionType); - else if (!norm->functions.isNever()) + else if (normValid && !norm->functions.isNever()) { if (norm->functions.parts.size() == 1) fetch(norm->functions.parts.front()); @@ -2107,88 +2878,295 @@ struct TypeChecker2 { std::vector parts; parts.insert(parts.end(), norm->functions.parts.begin(), norm->functions.parts.end()); - fetch(testArena.addType(IntersectionType{std::move(parts)})); + fetch(module->internalTypes.addType(IntersectionType{std::move(parts)})); } } - for (const auto& [tyvar, intersect] : norm->tyvars) + + if (normValid) { - if (get(intersect->tops)) + for (const auto& [tyvar, intersect] : norm->tyvars) { - TypeId ty = normalizer.typeFromNormal(*intersect); - fetch(testArena.addType(IntersectionType{{tyvar, ty}})); + if (get(intersect->tops)) + { + TypeId ty = normalizer.typeFromNormal(*intersect); + fetch(module->internalTypes.addType(IntersectionType{{tyvar, ty}})); + } + else + fetch(follow(tyvar)); + + if (!normValid) + break; } - else - fetch(tyvar); } - if (!typesMissingTheProp.empty()) + return {typesOfProp, typesMissingTheProp}; + } + + // If the provided type does not have the named property, report an error. + void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, ValueContext context, const Location& location, TypeId astIndexExprType) + { + std::shared_ptr norm = normalizer.normalize(tableTy); + if (!norm) + { + reportError(NormalizationTooComplex{}, location); + return; + } + + // if the type is error suppressing, we don't actually have any work left to do. + if (norm->shouldSuppressErrors()) + return; + + std::vector dummy; + const auto propTypes = lookupProp(norm.get(), prop, context, location, astIndexExprType, module->errors); + + if (propTypes.foundMissingProp()) { - if (foundOneProp) - reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); + if (propTypes.foundOneProp()) + reportError(MissingUnionProperty{tableTy, propTypes.missingProp, prop}, location); // For class LValues, we don't want to report an extension error, // because classes come into being with full knowledge of their // shape. We instead want to report the unknown property error of // the `else` branch. else if (context == ValueContext::LValue && !get(tableTy)) - reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); + { + const auto lvPropTypes = lookupProp(norm.get(), prop, ValueContext::RValue, location, astIndexExprType, dummy); + if (lvPropTypes.foundOneProp() && lvPropTypes.noneMissingProp()) + reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotWrite}, location); + else if (get(tableTy) || get(tableTy)) + reportError(NotATable{tableTy}, location); + else + reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); + } + else if (context == ValueContext::RValue && !get(tableTy)) + { + const auto rvPropTypes = lookupProp(norm.get(), prop, ValueContext::LValue, location, astIndexExprType, dummy); + if (rvPropTypes.foundOneProp() && rvPropTypes.noneMissingProp()) + reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotRead}, location); + else + reportError(UnknownProperty{tableTy, prop}, location); + } else reportError(UnknownProperty{tableTy, prop}, location); } } - bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen) + struct PropertyType + { + NormalizationResult present; + std::optional result; + }; + + PropertyType hasIndexTypeFromType( + TypeId ty, + const std::string& prop, + ValueContext context, + const Location& location, + DenseHashSet& seen, + TypeId astIndexExprType, + std::vector& errors + ) { // If we have already encountered this type, we must assume that some // other codepath will do the right thing and signal false if the // property is not present. - const bool isUnseen = seen.insert(ty).second; - if (!isUnseen) - return true; + if (seen.contains(ty)) + return {NormalizationResult::True, {}}; + seen.insert(ty); if (get(ty) || get(ty) || get(ty)) - return true; + return {NormalizationResult::True, {ty}}; if (isString(ty)) { - std::optional mtIndex = Luau::findMetatableEntry(builtinTypes, module->errors, builtinTypes->stringType, "__index", location); + std::optional mtIndex = Luau::findMetatableEntry(builtinTypes, errors, builtinTypes->stringType, "__index", location); LUAU_ASSERT(mtIndex); ty = *mtIndex; } if (auto tt = getTableType(ty)) { - if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) - return true; + if (auto resTy = findTablePropertyRespectingMeta(builtinTypes, errors, ty, prop, context, location)) + return {NormalizationResult::True, resTy}; - else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) - return true; + if (tt->indexer) + { + TypeId indexType = follow(tt->indexer->indexType); + if (isPrim(indexType, PrimitiveType::String)) + return {NormalizationResult::True, {tt->indexer->indexResultType}}; + // If the indexer looks like { [any] : _} - the prop lookup should be allowed! + else if (get(indexType) || get(indexType)) + return {NormalizationResult::True, {tt->indexer->indexResultType}}; + } - else - return false; + + // if we are in a conditional context, we treat the property as present and `unknown` because + // we may be _refining_ `tableTy` to include that property. we will want to revisit this a bit + // in the future once luau has support for exact tables since this only applies when inexact. + return {inConditional(typeContext) ? NormalizationResult::True : NormalizationResult::False, {builtinTypes->unknownType}}; } else if (const ClassType* cls = get(ty)) - return bool(lookupClassProp(cls, prop)); + { + // If the property doesn't exist on the class, we consult the indexer + // We need to check if the type of the index expression foo (x[foo]) + // is compatible with the indexer's indexType + // Construct the intersection and test inhabitedness! + if (auto property = lookupClassProp(cls, prop)) + return {NormalizationResult::True, context == ValueContext::LValue ? property->writeTy : property->readTy}; + if (cls->indexer) + { + TypeId inhabitatedTestType = module->internalTypes.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}}); + return {normalizer.isInhabited(inhabitatedTestType), {cls->indexer->indexResultType}}; + } + return {NormalizationResult::False, {}}; + } else if (const UnionType* utv = get(ty)) - return std::all_of(begin(utv), end(utv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); - }); + { + std::vector parts; + parts.reserve(utv->options.size()); + + for (TypeId part : utv) + { + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); + + if (result.present != NormalizationResult::True) + return {result.present, {}}; + if (result.result) + parts.emplace_back(*result.result); + } + + if (parts.size() == 0) + return {NormalizationResult::False, {}}; + + if (parts.size() == 1) + return {NormalizationResult::True, {parts[0]}}; + + TypeId propTy; + if (context == ValueContext::LValue) + propTy = module->internalTypes.addType(IntersectionType{parts}); + else + propTy = module->internalTypes.addType(UnionType{parts}); + + return {NormalizationResult::True, propTy}; + } else if (const IntersectionType* itv = get(ty)) - return std::any_of(begin(itv), end(itv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); - }); + { + for (TypeId part : itv) + { + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); + if (result.present != NormalizationResult::False) + return result; + } + + return {NormalizationResult::False, {}}; + } + else if (const PrimitiveType* pt = get(ty)) + return {(inConditional(typeContext) && pt->type == PrimitiveType::Table) ? NormalizationResult::True : NormalizationResult::False, {ty}}; else + return {NormalizationResult::False, {}}; + } + + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const + { + std::string_view sv(utk->key); + std::set candidates; + + auto accumulate = [&](const TableType::Props& props) + { + for (const auto& [name, ty] : props) + { + if (sv != name && equalsLower(sv, name)) + candidates.insert(name); + } + }; + + if (auto ttv = getTableType(utk->table)) + accumulate(ttv->props); + else if (auto ctv = get(follow(utk->table))) + { + while (ctv) + { + accumulate(ctv->props); + + if (!ctv->parent) + break; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + + if (!candidates.empty()) + data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); + } + + bool isErrorSuppressing(Location loc, TypeId ty) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, ty)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); + return false; + }; + + LUAU_ASSERT(false); + return false; // UNREACHABLE + } + + bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2) + { + return isErrorSuppressing(loc1, ty1) || isErrorSuppressing(loc2, ty2); + } + + bool isErrorSuppressing(Location loc, TypePackId tp) + { + switch (shouldSuppressErrors(NotNull{&normalizer}, tp)) + { + case ErrorSuppression::DoNotSuppress: + return false; + case ErrorSuppression::Suppress: + return true; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, loc); return false; + }; + + LUAU_ASSERT(false); + return false; // UNREACHABLE + } + + bool isErrorSuppressing(Location loc1, TypePackId tp1, Location loc2, TypePackId tp2) + { + return isErrorSuppressing(loc1, tp1) || isErrorSuppressing(loc2, tp2); } }; -void check(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check( + NotNull builtinTypes, + NotNull unifierState, + NotNull limits, + DcrLogger* logger, + const SourceModule& sourceModule, + Module* module +) { - TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; - typeChecker.reduceTypes(); + LUAU_TIMETRACE_SCOPE("check", "Typechecking"); + + TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; + typeChecker.visit(sourceModule.root); + // if the only error we're producing is one about constraint solving being incomplete, we can silence it. + // this means we won't give this warning if types seem totally nonsensical, but there are no other errors. + // this is probably, on the whole, a good decision to not annoy users though. + if (module->errors.size() == 1 && get(module->errors[0])) + module->errors.clear(); + unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); freeze(module->interfaceTypes); } diff --git a/third_party/luau/Analysis/src/TypeFunction.cpp b/third_party/luau/Analysis/src/TypeFunction.cpp new file mode 100644 index 00000000..76fa18f6 --- /dev/null +++ b/third_party/luau/Analysis/src/TypeFunction.cpp @@ -0,0 +1,2335 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunction.h" + +#include "Luau/Common.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DenseHash.h" +#include "Luau/Instantiation.h" +#include "Luau/Normalize.h" +#include "Luau/NotNull.h" +#include "Luau/OverloadResolution.h" +#include "Luau/Set.h" +#include "Luau/Simplify.h" +#include "Luau/Subtyping.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier2.h" +#include "Luau/VecDeque.h" +#include "Luau/VisitType.h" + +#include + +// used to control emitting CodeTooComplex warnings on type function reduction +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); + +// used to control the limits of type function application over union type arguments +// e.g. `mul` blows up into `mul | mul | mul | mul` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'000); + +// used to control falling back to a more conservative reduction based on guessing +// when this value is set to a negative value, guessing will be totally disabled. +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); + +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); + +namespace Luau +{ + +using TypeOrTypePackIdSet = DenseHashSet; + +struct InstanceCollector : TypeOnceVisitor +{ + VecDeque tys; + VecDeque tps; + TypeOrTypePackIdSet shouldGuess{nullptr}; + std::vector cyclicInstance; + + bool visit(TypeId ty, const TypeFunctionInstanceType&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + + if (DFInt::LuauTypeFamilyUseGuesserDepth >= 0 && typeFunctionDepth > DFInt::LuauTypeFamilyUseGuesserDepth) + shouldGuess.insert(ty); + + tys.push_front(ty); + + return true; + } + + void cycle(TypeId ty) override + { + /// Detected cyclic type pack + TypeId t = follow(ty); + if (get(t)) + cyclicInstance.push_back(t); + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + + if (DFInt::LuauTypeFamilyUseGuesserDepth >= 0 && typeFunctionDepth > DFInt::LuauTypeFamilyUseGuesserDepth) + shouldGuess.insert(tp); + + tps.push_front(tp); + + return true; + } +}; + +struct TypeFunctionReducer +{ + TypeFunctionContext ctx; + + VecDeque queuedTys; + VecDeque queuedTps; + TypeOrTypePackIdSet shouldGuess; + std::vector cyclicTypeFunctions; + TypeOrTypePackIdSet irreducible{nullptr}; + FunctionGraphReductionResult result; + bool force = false; + + // Local to the constraint being reduced. + Location location; + + TypeFunctionReducer( + VecDeque queuedTys, + VecDeque queuedTps, + TypeOrTypePackIdSet shouldGuess, + std::vector cyclicTypes, + Location location, + TypeFunctionContext ctx, + bool force = false + ) + : ctx(ctx) + , queuedTys(std::move(queuedTys)) + , queuedTps(std::move(queuedTps)) + , shouldGuess(std::move(shouldGuess)) + , cyclicTypeFunctions(std::move(cyclicTypes)) + , force(force) + , location(location) + { + } + + enum class SkipTestResult + { + CyclicTypeFunction, + Irreducible, + Defer, + Okay, + }; + + SkipTestResult testForSkippability(TypeId ty) + { + ty = follow(ty); + + if (is(ty)) + { + for (auto t : cyclicTypeFunctions) + { + if (ty == t) + return SkipTestResult::CyclicTypeFunction; + } + + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + + return SkipTestResult::Irreducible; + } + else if (is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + SkipTestResult testForSkippability(TypePackId ty) + { + ty = follow(ty); + + if (is(ty)) + { + if (!irreducible.contains(ty)) + return SkipTestResult::Defer; + else + return SkipTestResult::Irreducible; + } + else if (is(ty)) + { + return SkipTestResult::Irreducible; + } + + return SkipTestResult::Okay; + } + + template + void replace(T subject, T replacement) + { + if (subject->owningArena != ctx.arena.get()) + { + result.errors.emplace_back(location, InternalError{"Attempting to modify a type function instance from another arena"}); + return; + } + + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str()); + + asMutable(subject)->ty.template emplace>(replacement); + + if constexpr (std::is_same_v) + result.reducedTypes.insert(subject); + else if constexpr (std::is_same_v) + result.reducedPacks.insert(subject); + } + + template + void handleTypeFunctionReduction(T subject, TypeFunctionReductionResult reduction) + { + if (reduction.result) + replace(subject, *reduction.result); + else + { + irreducible.insert(subject); + + if (reduction.uninhabited || force) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is uninhabited\n", toString(subject, {true}).c_str()); + + if constexpr (std::is_same_v) + result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}}); + else if constexpr (std::is_same_v) + result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}}); + } + else if (!reduction.uninhabited && !force) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf( + "%s is irreducible; blocked on %zu types, %zu packs\n", + toString(subject, {true}).c_str(), + reduction.blockedTypes.size(), + reduction.blockedPacks.size() + ); + + for (TypeId b : reduction.blockedTypes) + result.blockedTypes.insert(b); + + for (TypePackId b : reduction.blockedPacks) + result.blockedPacks.insert(b); + } + } + } + + bool done() + { + return queuedTys.empty() && queuedTps.empty(); + } + + template + bool testParameters(T subject, const I* tfit) + { + for (TypeId p : tfit->typeArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + for (TypePackId p : tfit->packArguments) + { + SkipTestResult skip = testForSkippability(p); + + if (skip == SkipTestResult::Irreducible) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + irreducible.insert(subject); + return false; + } + else if (skip == SkipTestResult::Defer) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); + + if constexpr (std::is_same_v) + queuedTys.push_back(subject); + else if constexpr (std::is_same_v) + queuedTps.push_back(subject); + + return false; + } + } + + return true; + } + + template + inline bool tryGuessing(TID subject) + { + if (shouldGuess.contains(subject)) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Flagged %s for reduction with guesser.\n", toString(subject, {true}).c_str()); + + TypeFunctionReductionGuesser guesser{ctx.arena, ctx.builtins, ctx.normalizer}; + auto guessed = guesser.guess(subject); + + if (guessed) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Selected %s as the guessed result type.\n", toString(*guessed, {true}).c_str()); + + replace(subject, *guessed); + return true; + } + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Failed to produce a guess for the result of %s.\n", toString(subject, {true}).c_str()); + } + + return false; + } + + + void stepType() + { + TypeId subject = follow(queuedTys.front()); + queuedTys.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); + + if (const TypeFunctionInstanceType* tfit = get(subject)) + { + SkipTestResult testCyclic = testForSkippability(subject); + + if (!testParameters(subject, tfit) && testCyclic != SkipTestResult::CyclicTypeFunction) + { + if (FFlag::DebugLuauLogTypeFamilies) + printf("Irreducible due to irreducible/pending and a non-cyclic function\n"); + + return; + } + + if (tryGuessing(subject)) + return; + + TypeFunctionReductionResult result = tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + handleTypeFunctionReduction(subject, result); + } + } + + void stepPack() + { + TypePackId subject = follow(queuedTps.front()); + queuedTps.pop_front(); + + if (irreducible.contains(subject)) + return; + + if (FFlag::DebugLuauLogTypeFamilies) + printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); + + if (const TypeFunctionInstanceTypePack* tfit = get(subject)) + { + if (!testParameters(subject, tfit)) + return; + + if (tryGuessing(subject)) + return; + + TypeFunctionReductionResult result = + tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + handleTypeFunctionReduction(subject, result); + } + } + + void step() + { + if (!queuedTys.empty()) + stepType(); + else if (!queuedTps.empty()) + stepPack(); + } +}; + +static FunctionGraphReductionResult reduceFunctionsInternal( + VecDeque queuedTys, + VecDeque queuedTps, + TypeOrTypePackIdSet shouldGuess, + std::vector cyclics, + Location location, + TypeFunctionContext ctx, + bool force +) +{ + TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force}; + int iterationCount = 0; + + while (!reducer.done()) + { + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); + break; + } + } + + return std::move(reducer.result); +} + +FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext ctx, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FunctionGraphReductionResult{}; + } + + if (collector.tys.empty() && collector.tps.empty()) + return {}; + + return reduceFunctionsInternal( + std::move(collector.tys), + std::move(collector.tps), + std::move(collector.shouldGuess), + std::move(collector.cyclicInstance), + location, + ctx, + force + ); +} + +FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext ctx, bool force) +{ + InstanceCollector collector; + + try + { + collector.traverse(entrypoint); + } + catch (RecursionLimitException&) + { + return FunctionGraphReductionResult{}; + } + + if (collector.tys.empty() && collector.tps.empty()) + return {}; + + return reduceFunctionsInternal( + std::move(collector.tys), + std::move(collector.tps), + std::move(collector.shouldGuess), + std::move(collector.cyclicInstance), + location, + ctx, + force + ); +} + +bool isPending(TypeId ty, ConstraintSolver* solver) +{ + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); +} + +template +static std::optional> tryDistributeTypeFunctionApp( + F f, + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + Args&&... args +) +{ + // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) + bool uninhabited = false; + std::vector blockedTypes; + std::vector results; + size_t cartesianProductSize = 1; + + const UnionType* firstUnion = nullptr; + size_t unionIndex; + + std::vector arguments = typeParams; + for (size_t i = 0; i < arguments.size(); ++i) + { + const UnionType* ut = get(follow(arguments[i])); + if (!ut) + continue; + + // We want to find the first union type in the set of arguments to distribute that one and only that one union. + // The function `f` we have is recursive, so `arguments[unionIndex]` will be updated in-place for each option in + // the union we've found in this context, so that index will no longer be a union type. Any other arguments at + // index + 1 or after will instead be distributed, if those are a union, which will be subjected to the same rules. + if (!firstUnion && ut) + { + firstUnion = ut; + unionIndex = i; + } + + cartesianProductSize *= std::distance(begin(ut), end(ut)); + + // TODO: We'd like to report that the type function application is too complex here. + if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) + return {{std::nullopt, true, {}, {}}}; + } + + if (!firstUnion) + { + // If we couldn't find any union type argument, we're not distributing. + return std::nullopt; + } + + for (TypeId option : firstUnion) + { + arguments[unionIndex] = option; + + TypeFunctionReductionResult result = f(instance, arguments, packParams, ctx, args...); + blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + uninhabited |= result.uninhabited; + + if (result.uninhabited || !result.result) + break; + else + results.push_back(*result.result); + } + + if (uninhabited || !blockedTypes.empty()) + return {{std::nullopt, uninhabited, blockedTypes, {}}}; + + if (!results.empty()) + { + if (results.size() == 1) + return {{results[0], false, {}, {}}}; + + TypeId resultTy = ctx->arena->addType(TypeFunctionInstanceType{ + NotNull{&builtinTypeFunctions().unionFunc}, + std::move(results), + {}, + }); + + return {{resultTy, false, {}, {}}}; + } + + return std::nullopt; +} + +TypeFunctionReductionResult notTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("not type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId ty = follow(typeParams.at(0)); + + if (ty == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // `not` operates on anything and returns a `boolean` always. + return {ctx->builtins->booleanType, false, {}, {}}; +} + +TypeFunctionReductionResult lenTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("len type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + if (operandTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if the operand type is resolved enough, and wait to reduce if not + // the use of `typeFromNormal` later necessitates blocking on local types. + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + NormalizationResult inhabited = ctx->normalizer->isInhabited(normTy.get()); + + // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy || inhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if the operand type is error suppressing, we can immediately reduce to `number`. + if (normTy->shouldSuppressErrors()) + return {ctx->builtins->numberType, false, {}, {}}; + + // if we have an uninhabited type (like `never`), we can never observe that the operator didn't work. + if (inhabited == NormalizationResult::False) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we're checking the length of a string, that works! + if (normTy->isSubtypeOfString()) + return {ctx->builtins->numberType, false, {}, {}}; + + // we use the normalized operand here in case there was an intersection or union. + TypeId normalizedOperand = ctx->normalizer->typeFromNormal(*normTy); + if (normTy->hasTopTable() || get(normalizedOperand)) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{}); + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + // `len` must return a `number`. + return {ctx->builtins->numberType, false, {}, {}}; +} + +TypeFunctionReductionResult unmTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("unm type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + if (operandTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if the operand type is resolved enough, and wait to reduce if not + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy) + return {std::nullopt, false, {}, {}}; + + // if the operand is error suppressing, we can just go ahead and reduce. + if (normTy->shouldSuppressErrors()) + return {operandTy, false, {}, {}}; + + // if we have a `never`, we can never observe that the operation didn't work. + if (is(operandTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + // If the type is exactly `number`, we can reduce now. + if (normTy->isExactlyNumber()) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{}); + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + if (std::optional ret = first(instantiatedMmFtv->retTypes)) + return {*ret, false, {}, {}}; + else + return {std::nullopt, true, {}, {}}; +} + +NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) +{ + LUAU_ASSERT(solver); + NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); + + // Every constraint that is blocked on the current constraint must also be + // blocked on this new one. + if (constraint) + solver->inheritBlocks(NotNull{constraint}, newConstraint); + + return newConstraint; +} + +TypeFunctionReductionResult numericBinopTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + const std::string metamethod +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we have a `never`, we can never observe that the math operator is unreachable. + if (is(lhsTy) || is(rhsTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // TODO: Normalization needs to remove cyclic type functions from a `NormalizedType`. + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->anyType, false, {}, {}}; + + // if we're adding two `number` types, the result is `number`. + if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) + return {ctx->builtins->numberType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(numericBinopTypeFunction, instance, typeParams, packParams, ctx, metamethod)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, location); + bool reversed = false; + if (!mmType) + { + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, location); + reversed = true; + } + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + SolveResult solveResult; + + if (!reversed) + solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + else + { + TypePack* p = getMutable(argPack); + std::swap(p->head.front(), p->head.back()); + solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + } + + if (!solveResult.typePackId.has_value()) + return {std::nullopt, true, {}, {}}; + + TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); + if (extracted.head.empty()) + return {std::nullopt, true, {}, {}}; + + return {extracted.head.front(), false, {}, {}}; +} + +TypeFunctionReductionResult addTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("add type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__add"); +} + +TypeFunctionReductionResult subTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("sub type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__sub"); +} + +TypeFunctionReductionResult mulTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("mul type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__mul"); +} + +TypeFunctionReductionResult divTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("div type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__div"); +} + +TypeFunctionReductionResult idivTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("integer div type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__idiv"); +} + +TypeFunctionReductionResult powTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("pow type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__pow"); +} + +TypeFunctionReductionResult modTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("modulo type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return numericBinopTypeFunction(instance, typeParams, packParams, ctx, "__mod"); +} + +TypeFunctionReductionResult concatTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("concat type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->anyType, false, {}, {}}; + + // if we have a `never`, we can never observe that the numeric operator didn't work. + if (is(lhsTy) || is(rhsTy)) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we're concatenating two elements that are either strings or numbers, the result is `string`. + if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) + return {ctx->builtins->stringType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(concatTypeFunction, instance, typeParams, packParams, ctx)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, "__concat", Location{}); + bool reversed = false; + if (!mmType) + { + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, "__concat", Location{}); + reversed = true; + } + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + std::vector inferredArgs; + if (!reversed) + inferredArgs = {lhsTy, rhsTy}; + else + inferredArgs = {rhsTy, lhsTy}; + + TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->stringType, false, {}, {}}; +} + +TypeFunctionReductionResult andTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("and type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // t1 = and ~> lhs + if (follow(rhsTy) == instance && lhsTy != rhsTy) + return {lhsTy, false, {}, {}}; + // t1 = and ~> rhs + if (follow(lhsTy) == instance && lhsTy != rhsTy) + return {rhsTy, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is truthy. + SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType); + SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); + std::vector blockedTypes{}; + for (auto ty : filteredLhs.blockedTypes) + blockedTypes.push_back(ty); + for (auto ty : overallResult.blockedTypes) + blockedTypes.push_back(ty); + return {overallResult.result, false, std::move(blockedTypes), {}}; +} + +TypeFunctionReductionResult orTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("or type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // t1 = or ~> lhs + if (follow(rhsTy) == instance && lhsTy != rhsTy) + return {lhsTy, false, {}, {}}; + // t1 = or ~> rhs + if (follow(lhsTy) == instance && lhsTy != rhsTy) + return {rhsTy, false, {}, {}}; + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // Or evalutes to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy. + SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType); + SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); + std::vector blockedTypes{}; + for (auto ty : filteredLhs.blockedTypes) + blockedTypes.push_back(ty); + for (auto ty : overallResult.blockedTypes) + blockedTypes.push_back(ty); + return {overallResult.result, false, std::move(blockedTypes), {}}; +} + +static TypeFunctionReductionResult comparisonTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + const std::string metamethod +) +{ + + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // Algebra Reduction Rules for comparison type functions + // Note that comparing to never tells you nothing about the other operand + // lt< 'a , never> -> continue + // lt< never, 'a> -> continue + // lt< 'a, t> -> 'a is t - we'll solve the constraint, return and solve lt -> bool + // lt< t, 'a> -> same as above + bool canSubmitConstraint = ctx->solver && ctx->constraint; + bool lhsFree = get(lhsTy) != nullptr; + bool rhsFree = get(rhsTy) != nullptr; + if (canSubmitConstraint) + { + // Implement injective type functions for comparison type functions + // lt implies t is number + // lt implies t is number + if (lhsFree && isNumber(rhsTy)) + emplaceType(asMutable(lhsTy), ctx->builtins->numberType); + else if (rhsFree && isNumber(lhsTy)) + emplaceType(asMutable(rhsTy), ctx->builtins->numberType); + else if (lhsFree && ctx->normalizer->isInhabited(rhsTy) != NormalizationResult::False) + { + auto c1 = ctx->pushConstraint(EqualityConstraint{lhsTy, rhsTy}); + const_cast(ctx->constraint)->dependencies.emplace_back(c1); + } + else if (rhsFree && ctx->normalizer->isInhabited(lhsTy) != NormalizationResult::False) + { + auto c1 = ctx->pushConstraint(EqualityConstraint{rhsTy, lhsTy}); + const_cast(ctx->constraint)->dependencies.emplace_back(c1); + } + } + + // The above might have caused the operand types to be rebound, we need to follow them again + lhsTy = follow(lhsTy); + rhsTy = follow(rhsTy); + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + // check to see if both operand types are resolved enough, and wait to reduce if not + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); + NormalizationResult rhsInhabited = ctx->normalizer->isInhabited(normRhsTy.get()); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can just go ahead and reduce. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // if we have an uninhabited type (e.g. `never`), we can never observe that the comparison didn't work. + if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) + return {ctx->builtins->booleanType, false, {}, {}}; + + // If both types are some strict subset of `string`, we can reduce now. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // If both types are exactly `number`, we can reduce now. + if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) + return {ctx->builtins->booleanType, false, {}, {}}; + + if (auto result = tryDistributeTypeFunctionApp(comparisonTypeFunction, instance, typeParams, packParams, ctx, metamethod)) + return *result; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, Location{}); + if (!mmType) + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); + + if (!mmType) + return {std::nullopt, true, {}, {}}; + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->booleanType, false, {}, {}}; +} + +TypeFunctionReductionResult ltTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("lt type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return comparisonTypeFunction(instance, typeParams, packParams, ctx, "__lt"); +} + +TypeFunctionReductionResult leTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("le type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return comparisonTypeFunction(instance, typeParams, packParams, ctx, "__le"); +} + +TypeFunctionReductionResult eqTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("eq type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); + NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); + NormalizationResult rhsInhabited = ctx->normalizer->isInhabited(normRhsTy.get()); + + // if either failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) + return {std::nullopt, false, {}, {}}; + + // if one of the types is error suppressing, we can just go ahead and reduce. + if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) + return {ctx->builtins->booleanType, false, {}, {}}; + + // if we have a `never`, we can never observe that the comparison didn't work. + if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) + return {ctx->builtins->booleanType, false, {}, {}}; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, "__eq", Location{}); + if (!mmType) + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, "__eq", Location{}); + + // if neither type has a metatable entry for `__eq`, then we'll check for inhabitance of the intersection! + NormalizationResult intersectInhabited = ctx->normalizer->isIntersectionInhabited(lhsTy, rhsTy); + if (!mmType) + { + if (intersectInhabited == NormalizationResult::True) + return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + + // we might be in a case where we still want to accept the comparison... + if (intersectInhabited == NormalizationResult::False) + { + // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->falseType, false, {}, {}}; + + // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) + return {ctx->builtins->falseType, false, {}, {}}; + } + + return {std::nullopt, true, {}, {}}; // if it's not, then this type function is irreducible! + } + + mmType = follow(*mmType); + if (isPending(*mmType, ctx->solver)) + return {std::nullopt, false, {*mmType}, {}}; + + const FunctionType* mmFtv = get(*mmType); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); + if (!instantiatedMmType) + return {std::nullopt, true, {}, {}}; + + const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); + if (!instantiatedMmFtv) + return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + + TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; + if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) + return {std::nullopt, true, {}, {}}; // occurs check failed + + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope}; + if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance? + return {std::nullopt, true, {}, {}}; + + return {ctx->builtins->booleanType, false, {}, {}}; +} + +// Collect types that prevent us from reducing a particular refinement. +struct FindRefinementBlockers : TypeOnceVisitor +{ + DenseHashSet found{nullptr}; + bool visit(TypeId ty, const BlockedType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } +}; + + +TypeFunctionReductionResult refineTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("refine type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId targetTy = follow(typeParams.at(0)); + TypeId discriminantTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, false, {targetTy}, {}}; + else if (isPending(discriminantTy, ctx->solver)) + return {std::nullopt, false, {discriminantTy}, {}}; + + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, targetTy); + std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminantTy); + + if (!targetMaybeGeneralized) + return {std::nullopt, false, {targetTy}, {}}; + else if (!discriminantMaybeGeneralized) + return {std::nullopt, false, {discriminantTy}, {}}; + + targetTy = *targetMaybeGeneralized; + discriminantTy = *discriminantMaybeGeneralized; + } + + // we need a more complex check for blocking on the discriminant in particular + FindRefinementBlockers frb; + frb.traverse(discriminantTy); + + if (!frb.found.empty()) + return {std::nullopt, false, {frb.found.begin(), frb.found.end()}, {}}; + + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + + if (auto nt = get(discriminantTy)) + if (get(follow(nt->ty))) + return {targetTy, false, {}, {}}; + + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(targetTy)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, targetTy, discriminantTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; + } + + // In the general case, we'll still use normalization though. + TypeId intersection = ctx->arena->addType(IntersectionType{{targetTy, discriminantTy}}); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(targetTy); + + // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normIntersection || !normType) + return {std::nullopt, false, {}, {}}; + + TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); + + // include the error type if the target type is error-suppressing and the intersection we computed is not + if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) + resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + + return {resultTy, false, {}, {}}; +} + +TypeFunctionReductionResult singletonTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("singleton type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId type = follow(typeParams.at(0)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(type, ctx->solver)) + return {std::nullopt, false, {type}, {}}; + + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); + if (!maybeGeneralized) + return {std::nullopt, false, {type}, {}}; + type = *maybeGeneralized; + } + + TypeId followed = type; + // we want to follow through a negation here as well. + if (auto negation = get(followed)) + followed = follow(negation->ty); + + // if we have a singleton type or `nil`, which is its own singleton type... + if (get(followed) || isNil(followed)) + return {type, false, {}, {}}; + + // otherwise, we'll return the top type, `unknown`. + return {ctx->builtins->unknownType, false, {}, {}}; +} + +TypeFunctionReductionResult unionTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (!packParams.empty()) + { + ctx->ice->ice("union type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + // if we only have one parameter, there's nothing to do. + if (typeParams.size() == 1) + return {follow(typeParams[0]), false, {}, {}}; + + // we need to follow all of the type parameters. + std::vector types; + types.reserve(typeParams.size()); + for (auto ty : typeParams) + types.emplace_back(follow(ty)); + + // unfortunately, we need this short-circuit: if all but one type is `never`, we will return that one type. + // this also will early return if _everything_ is `never`, since we already have to check that. + std::optional lastType = std::nullopt; + for (auto ty : types) + { + // if we have a previous type and it's not `never` and the current type isn't `never`... + if (lastType && !get(lastType) && !get(ty)) + { + // we know we are not taking the short-circuited path. + lastType = std::nullopt; + break; + } + + if (get(ty)) + continue; + lastType = ty; + } + + // if we still have a `lastType` at the end, we're taking the short-circuit and reducing early. + if (lastType) + return {lastType, false, {}, {}}; + + // check to see if the operand types are resolved enough, and wait to reduce if not + for (auto ty : types) + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + + // fold over the types with `simplifyUnion` + TypeId resultTy = ctx->builtins->neverType; + for (auto ty : types) + { + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + return {resultTy, false, {}, {}}; +} + + +TypeFunctionReductionResult intersectTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (!packParams.empty()) + { + ctx->ice->ice("intersect type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + // if we only have one parameter, there's nothing to do. + if (typeParams.size() == 1) + return {follow(typeParams[0]), false, {}, {}}; + + // we need to follow all of the type parameters. + std::vector types; + types.reserve(typeParams.size()); + for (auto ty : typeParams) + types.emplace_back(follow(ty)); + + // check to see if the operand types are resolved enough, and wait to reduce if not + // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. + for (auto ty : types) + { + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + else if (get(ty)) + return {ctx->builtins->neverType, false, {}, {}}; + } + + // fold over the types with `simplifyIntersection` + TypeId resultTy = ctx->builtins->unknownType; + for (auto ty : types) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + // if the intersection simplifies to `never`, this gives us bad autocomplete. + // we'll just produce the intersection plainly instead, but this might be revisitable + // if we ever give `never` some kind of "explanation" trail. + if (get(resultTy)) + { + TypeId intersection = ctx->arena->addType(IntersectionType{typeParams}); + return {intersection, false, {}, {}}; + } + + return {resultTy, false, {}, {}}; +} + +// computes the keys of `ty` into `result` +// `isRaw` parameter indicates whether or not we should follow __index metamethods +// returns `false` if `result` should be ignored because the answer is "all strings" +bool computeKeysOf(TypeId ty, Set& result, DenseHashSet& seen, bool isRaw, NotNull ctx) +{ + // if the type is the top table type, the answer is just "all strings" + if (get(ty)) + return false; + + // if we've already seen this type, we can do nothing + if (seen.contains(ty)) + return true; + seen.insert(ty); + + // if we have a particular table type, we can insert the keys + if (auto tableTy = get(ty)) + { + if (tableTy->indexer) + { + // if we have a string indexer, the answer is, again, "all strings" + if (isString(tableTy->indexer->indexType)) + return false; + } + + for (auto [key, _] : tableTy->props) + result.insert(key); + return true; + } + + // otherwise, we have a metatable to deal with + if (auto metatableTy = get(ty)) + { + bool res = true; + + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, ty, "__index", Location{}); + if (mmType) + res = res && computeKeysOf(*mmType, result, seen, isRaw, ctx); + } + + res = res && computeKeysOf(metatableTy->table, result, seen, isRaw, ctx); + + return res; + } + + // this should not be reachable since the type should be a valid tables part from normalization. + LUAU_ASSERT(false); + return false; +} + +TypeFunctionReductionResult keyofFunctionImpl( + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + bool isRaw +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("keyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId operandTy = follow(typeParams.at(0)); + + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to get keys of (at least until a future version perhaps adds classes + // as well) + if (normTy->hasTables() == normTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // this is sort of atrocious, but we're trying to reject any type that has not normalized to a table or a union of tables. + if (normTy->hasTops() || normTy->hasBooleans() || normTy->hasErrors() || normTy->hasNils() || normTy->hasNumbers() || normTy->hasStrings() || + normTy->hasThreads() || normTy->hasBuffers() || normTy->hasFunctions() || normTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + // we're going to collect the keys in here + Set keys{{}}; + + // computing the keys for classes + if (normTy->hasClasses()) + { + LUAU_ASSERT(!normTy->hasTables()); + + auto classesIter = normTy->classes.ordering.begin(); + auto classesIterEnd = normTy->classes.ordering.end(); + LUAU_ASSERT(classesIter != classesIterEnd); // should be guaranteed by the `hasClasses` check + + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (auto [key, _] : classTy->props) + keys.insert(key); + + // we need to look at each class to remove any keys that are not common amongst them all + while (++classesIter != classesIterEnd) + { + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (auto key : keys) + { + // remove any keys that are not present in each class + if (classTy->props.find(key) == classTy->props.end()) + keys.erase(key); + } + } + } + + // computing the keys for tables + if (normTy->hasTables()) + { + LUAU_ASSERT(!normTy->hasClasses()); + + // seen set for key computation for tables + DenseHashSet seen{{}}; + + auto tablesIter = normTy->tables.begin(); + LUAU_ASSERT(tablesIter != normTy->tables.end()); // should be guaranteed by the `hasTables` check earlier + + // collect all the properties from the first table type + if (!computeKeysOf(*tablesIter, keys, seen, isRaw, ctx)) + return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have the top table type! + + // we need to look at each tables to remove any keys that are not common amongst them all + while (++tablesIter != normTy->tables.end()) + { + seen.clear(); // we'll reuse the same seen set + + Set localKeys{{}}; + + // we can skip to the next table if this one is the top table type + if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) + continue; + + for (auto key : keys) + { + // remove any keys that are not present in each table + if (!localKeys.contains(key)) + keys.erase(key); + } + } + } + + // if the set of keys is empty, `keyof` is `never` + if (keys.empty()) + return {ctx->builtins->neverType, false, {}, {}}; + + // everything is validated, we need only construct our big union of singletons now! + std::vector singletons; + singletons.reserve(keys.size()); + + for (std::string key : keys) + singletons.push_back(ctx->arena->addType(SingletonType{StringSingleton{key}})); + + return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; +} + +TypeFunctionReductionResult keyofTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("keyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return keyofFunctionImpl(typeParams, packParams, ctx, /* isRaw */ false); +} + +TypeFunctionReductionResult rawkeyofTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("rawkeyof type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return keyofFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); +} + +/* Searches through table's or class's props/indexer to find the property of `ty` + If found, appends that property to `result` and returns true + Else, returns false */ +bool searchPropsAndIndexer( + TypeId ty, + TableType::Props tblProps, + std::optional tblIndexer, + DenseHashSet& result, + NotNull ctx +) +{ + ty = follow(ty); + + // index into tbl's properties + if (auto stringSingleton = get(get(ty))) + { + if (tblProps.find(stringSingleton->value) != tblProps.end()) + { + TypeId propTy = follow(tblProps.at(stringSingleton->value).type()); + + // property is a union type -> we need to extend our reduction type + if (auto propUnionTy = get(propTy)) + { + for (TypeId option : propUnionTy->options) + result.insert(option); + } + else // property is a singular type or intersection type -> we can simply append + result.insert(propTy); + + return true; + } + } + + // index into tbl's indexer + if (tblIndexer) + { + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + { + TypeId idxResultTy = follow(tblIndexer->indexResultType); + + // indexResultType is a union type -> we need to extend our reduction type + if (auto idxResUnionTy = get(idxResultTy)) + { + for (TypeId option : idxResUnionTy->options) + result.insert(option); + } + else // indexResultType is a singular type or intersection type -> we can simply append + result.insert(idxResultTy); + + return true; + } + } + + return false; +} + +/* Handles recursion / metamethods of tables/classes + `isRaw` parameter indicates whether or not we should follow __index metamethods + returns false if property of `ty` could not be found */ +bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, NotNull ctx, bool isRaw) +{ + indexer = follow(indexer); + indexee = follow(indexee); + + // we have a table type to try indexing + if (auto tableTy = get(indexee)) + { + return searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx); + } + + // we have a metatable type to try indexing + if (auto metatableTy = get(indexee)) + { + if (auto tableTy = get(metatableTy->table)) + { + + // try finding all properties within the current scope of the table + if (searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx)) + return true; + } + + // if the code reached here, it means we weren't able to find all properties -> look into __index metamethod + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, indexee, "__index", Location{}); + if (mmType) + return tblIndexInto(indexer, *mmType, result, ctx, isRaw); + } + } + + return false; +} + +/* Vocabulary note: indexee refers to the type that contains the properties, + indexer refers to the type that is used to access indexee + Example: index => `Person` is the indexee and `"name"` is the indexer */ +TypeFunctionReductionResult indexFunctionImpl( + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx, + bool isRaw +) +{ + TypeId indexeeTy = follow(typeParams.at(0)); + std::shared_ptr indexeeNormTy = ctx->normalizer->normalize(indexeeTy); + + // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexeeNormTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to index into + if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. + if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || + indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || + indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + TypeId indexerTy = follow(typeParams.at(1)); + std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); + + // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexerNormTy) + return {std::nullopt, false, {}, {}}; + + // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) + if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) + return {std::nullopt, true, {}, {}}; + + // indexer can be a union —> break them down into a vector + const std::vector* typesToFind; + const std::vector singleType{indexerTy}; + if (auto unionTy = get(indexerTy)) + typesToFind = &unionTy->options; + else + typesToFind = &singleType; + + DenseHashSet properties{{}}; // vector of types that will be returned + + if (indexeeNormTy->hasClasses()) + { + LUAU_ASSERT(!indexeeNormTy->hasTables()); + + if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) + return {std::nullopt, true, {}, {}}; + + // at least one class is guaranteed to be in the iterator by .hasClasses() + for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) + { + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (TypeId ty : *typesToFind) + { + // Search for all instances of indexer in class->props and class->indexer + if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx)) + continue; // Indexer was found in this class, so we can move on to the next + + // If code reaches here,that means the property not found -> check in the metatable's __index + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); + if (!mmType) // if a metatable does not exist, there is no where else to look + return {std::nullopt, true, {}, {}}; + + if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce + return {std::nullopt, true, {}, {}}; + } + } + } + + if (indexeeNormTy->hasTables()) + { + LUAU_ASSERT(!indexeeNormTy->hasClasses()); + + // at least one table is guaranteed to be in the iterator by .hasTables() + for (auto tablesIter = indexeeNormTy->tables.begin(); tablesIter != indexeeNormTy->tables.end(); ++tablesIter) + { + for (TypeId ty : *typesToFind) + if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) + return {std::nullopt, true, {}, {}}; + } + } + + // Call `follow()` on each element to resolve all Bound types before returning + std::transform( + properties.begin(), + properties.end(), + properties.begin(), + [](TypeId ty) + { + return follow(ty); + } + ); + + // If the type being reduced to is a single type, no need to union + if (properties.size() == 1) + return {*properties.begin(), false, {}, {}}; + + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; +} + +TypeFunctionReductionResult indexTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("index type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ false); +} + +TypeFunctionReductionResult rawgetTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("rawget type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); +} + +BuiltinTypeFunctions::BuiltinTypeFunctions() + : notFunc{"not", notTypeFunction} + , lenFunc{"len", lenTypeFunction} + , unmFunc{"unm", unmTypeFunction} + , addFunc{"add", addTypeFunction} + , subFunc{"sub", subTypeFunction} + , mulFunc{"mul", mulTypeFunction} + , divFunc{"div", divTypeFunction} + , idivFunc{"idiv", idivTypeFunction} + , powFunc{"pow", powTypeFunction} + , modFunc{"mod", modTypeFunction} + , concatFunc{"concat", concatTypeFunction} + , andFunc{"and", andTypeFunction} + , orFunc{"or", orTypeFunction} + , ltFunc{"lt", ltTypeFunction} + , leFunc{"le", leTypeFunction} + , eqFunc{"eq", eqTypeFunction} + , refineFunc{"refine", refineTypeFunction} + , singletonFunc{"singleton", singletonTypeFunction} + , unionFunc{"union", unionTypeFunction} + , intersectFunc{"intersect", intersectTypeFunction} + , keyofFunc{"keyof", keyofTypeFunction} + , rawkeyofFunc{"rawkeyof", rawkeyofTypeFunction} + , indexFunc{"index", indexTypeFunction} + , rawgetFunc{"rawget", rawgetTypeFunction} +{ +} + +void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull scope) const +{ + // make a type function for a one-argument type function + auto mkUnaryTypeFunction = [&](const TypeFunction* tf) + { + TypeId t = arena->addType(GenericType{"T"}); + GenericTypeDefinition genericT{t}; + + return TypeFun{{genericT}, arena->addType(TypeFunctionInstanceType{NotNull{tf}, {t}, {}})}; + }; + + // make a type function for a two-argument type function + auto mkBinaryTypeFunction = [&](const TypeFunction* tf) + { + TypeId t = arena->addType(GenericType{"T"}); + TypeId u = arena->addType(GenericType{"U"}); + GenericTypeDefinition genericT{t}; + GenericTypeDefinition genericU{u, {t}}; + + return TypeFun{{genericT, genericU}, arena->addType(TypeFunctionInstanceType{NotNull{tf}, {t, u}, {}})}; + }; + + scope->exportedTypeBindings[lenFunc.name] = mkUnaryTypeFunction(&lenFunc); + scope->exportedTypeBindings[unmFunc.name] = mkUnaryTypeFunction(&unmFunc); + + scope->exportedTypeBindings[addFunc.name] = mkBinaryTypeFunction(&addFunc); + scope->exportedTypeBindings[subFunc.name] = mkBinaryTypeFunction(&subFunc); + scope->exportedTypeBindings[mulFunc.name] = mkBinaryTypeFunction(&mulFunc); + scope->exportedTypeBindings[divFunc.name] = mkBinaryTypeFunction(&divFunc); + scope->exportedTypeBindings[idivFunc.name] = mkBinaryTypeFunction(&idivFunc); + scope->exportedTypeBindings[powFunc.name] = mkBinaryTypeFunction(&powFunc); + scope->exportedTypeBindings[modFunc.name] = mkBinaryTypeFunction(&modFunc); + scope->exportedTypeBindings[concatFunc.name] = mkBinaryTypeFunction(&concatFunc); + + scope->exportedTypeBindings[ltFunc.name] = mkBinaryTypeFunction(<Func); + scope->exportedTypeBindings[leFunc.name] = mkBinaryTypeFunction(&leFunc); + scope->exportedTypeBindings[eqFunc.name] = mkBinaryTypeFunction(&eqFunc); + + scope->exportedTypeBindings[keyofFunc.name] = mkUnaryTypeFunction(&keyofFunc); + scope->exportedTypeBindings[rawkeyofFunc.name] = mkUnaryTypeFunction(&rawkeyofFunc); + + scope->exportedTypeBindings[indexFunc.name] = mkBinaryTypeFunction(&indexFunc); + scope->exportedTypeBindings[rawgetFunc.name] = mkBinaryTypeFunction(&rawgetFunc); +} + +const BuiltinTypeFunctions& builtinTypeFunctions() +{ + static std::unique_ptr result = std::make_unique(); + + return *result; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/TypeFunctionReductionGuesser.cpp b/third_party/luau/Analysis/src/TypeFunctionReductionGuesser.cpp new file mode 100644 index 00000000..d4a7c7c0 --- /dev/null +++ b/third_party/luau/Analysis/src/TypeFunctionReductionGuesser.cpp @@ -0,0 +1,454 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeFunctionReductionGuesser.h" + +#include "Luau/DenseHash.h" +#include "Luau/Normalize.h" +#include "Luau/TypeFunction.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/VecDeque.h" +#include "Luau/VisitType.h" + +#include +#include +#include + +namespace Luau +{ +struct InstanceCollector2 : TypeOnceVisitor +{ + VecDeque tys; + VecDeque tps; + DenseHashSet cyclicInstance{nullptr}; + DenseHashSet instanceArguments{nullptr}; + + bool visit(TypeId ty, const TypeFunctionInstanceType& it) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tys.push_front(ty); + for (auto t : it.typeArguments) + instanceArguments.insert(follow(t)); + return true; + } + + void cycle(TypeId ty) override + { + /// Detected cyclic type pack + TypeId t = follow(ty); + if (get(t)) + cyclicInstance.insert(t); + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tps.push_front(tp); + return true; + } +}; + + + +TypeFunctionReductionGuesser::TypeFunctionReductionGuesser(NotNull arena, NotNull builtins, NotNull normalizer) + : arena(arena) + , builtins(builtins) + , normalizer(normalizer) +{ +} + +bool TypeFunctionReductionGuesser::isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& argsUsed) +{ + bool sameSize = ftv.generics.size() == argsUsed.size(); + bool allGenericsAppear = true; + for (auto gt : ftv.generics) + allGenericsAppear = allGenericsAppear || argsUsed.contains(gt); + return sameSize && allGenericsAppear; +} + +void TypeFunctionReductionGuesser::dumpGuesses() +{ + for (auto [tf, t] : functionReducesTo) + printf("Type family %s ~~> %s\n", toString(tf).c_str(), toString(t).c_str()); + for (auto [t, t_] : substitutable) + printf("Substitute %s for %s\n", toString(t).c_str(), toString(t_).c_str()); +} + +std::optional TypeFunctionReductionGuesser::guess(TypeId typ) +{ + std::optional guessedType = guessType(typ); + + if (!guessedType.has_value()) + return {}; + + TypeId guess = follow(*guessedType); + if (get(guess)) + return {}; + + return guess; +} + +std::optional TypeFunctionReductionGuesser::guess(TypePackId tp) +{ + auto [head, tail] = flatten(tp); + + std::vector guessedHead; + guessedHead.reserve(head.size()); + + for (auto typ : head) + { + std::optional guessedType = guessType(typ); + + if (!guessedType.has_value()) + return {}; + + TypeId guess = follow(*guessedType); + if (get(guess)) + return {}; + + guessedHead.push_back(*guessedType); + } + + return arena->addTypePack(TypePack{guessedHead, tail}); +} + +TypeFunctionReductionGuessResult TypeFunctionReductionGuesser::guessTypeFunctionReductionForFunctionExpr( + const AstExprFunction& expr, + const FunctionType* ftv, + TypeId retTy +) +{ + InstanceCollector2 collector; + collector.traverse(retTy); + toInfer = std::move(collector.tys); + cyclicInstances = std::move(collector.cyclicInstance); + + if (isFunctionGenericsSaturated(*ftv, collector.instanceArguments)) + return TypeFunctionReductionGuessResult{{}, nullptr, false}; + infer(); + + std::vector> results; + std::vector args; + for (TypeId t : ftv->argTypes) + args.push_back(t); + + // Submit a guess for arg types + for (size_t i = 0; i < expr.args.size; i++) + { + TypeId argTy; + AstLocal* local = expr.args.data[i]; + if (i >= args.size()) + continue; + + argTy = args[i]; + std::optional guessedType = guessType(argTy); + if (!guessedType.has_value()) + continue; + TypeId guess = follow(*guessedType); + if (get(guess)) + continue; + + results.push_back({local->name.value, guess}); + } + + // Submit a guess for return types + TypeId recommendedAnnotation; + std::optional guessedReturnType = guessType(retTy); + if (!guessedReturnType.has_value()) + recommendedAnnotation = builtins->unknownType; + else + recommendedAnnotation = follow(*guessedReturnType); + if (auto t = get(recommendedAnnotation)) + recommendedAnnotation = builtins->unknownType; + + toInfer.clear(); + cyclicInstances.clear(); + functionReducesTo.clear(); + substitutable.clear(); + + return TypeFunctionReductionGuessResult{results, recommendedAnnotation}; +} + +std::optional TypeFunctionReductionGuesser::guessType(TypeId arg) +{ + TypeId t = follow(arg); + if (substitutable.contains(t)) + { + TypeId subst = follow(substitutable[t]); + if (subst == t || substitutable.contains(subst)) + return subst; + else if (!get(subst)) + return subst; + else + return guessType(subst); + } + if (get(t)) + { + if (functionReducesTo.contains(t)) + return functionReducesTo[t]; + } + return {}; +} + +bool TypeFunctionReductionGuesser::isNumericBinopFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" || + instance.function->name == "div" || instance.function->name == "idiv" || instance.function->name == "pow" || + instance.function->name == "mod"; +} + +bool TypeFunctionReductionGuesser::isComparisonFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "lt" || instance.function->name == "le" || instance.function->name == "eq"; +} + +bool TypeFunctionReductionGuesser::isOrAndFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "or" || instance.function->name == "and"; +} + +bool TypeFunctionReductionGuesser::isNotFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "not"; +} + +bool TypeFunctionReductionGuesser::isLenFunction(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "len"; +} + +bool TypeFunctionReductionGuesser::isUnaryMinus(const TypeFunctionInstanceType& instance) +{ + return instance.function->name == "unm"; +} + +// Operand is assignable if it looks like a cyclic function instance, or a generic type +bool TypeFunctionReductionGuesser::operandIsAssignable(TypeId ty) +{ + if (get(ty)) + return true; + if (get(ty)) + return true; + if (cyclicInstances.contains(ty)) + return true; + return false; +} + +std::shared_ptr TypeFunctionReductionGuesser::normalize(TypeId ty) +{ + return normalizer->normalize(ty); +} + + +std::optional TypeFunctionReductionGuesser::tryAssignOperandType(TypeId ty) +{ + // Because we collect innermost instances first, if we see a type function instance as an operand, + // We try to check if we guessed a type for it + if (auto tfit = get(ty)) + { + if (functionReducesTo.contains(ty)) + return {functionReducesTo[ty]}; + } + + // If ty is a generic, we need to check if we inferred a substitution + if (auto gt = get(ty)) + { + if (substitutable.contains(ty)) + return {substitutable[ty]}; + } + + // If we cannot substitute a type for this value, we return an empty optional + return {}; +} + +void TypeFunctionReductionGuesser::step() +{ + TypeId t = toInfer.front(); + toInfer.pop_front(); + t = follow(t); + if (auto tf = get(t)) + inferTypeFunctionSubstitutions(t, tf); +} + +void TypeFunctionReductionGuesser::infer() +{ + while (!done()) + step(); +} + +bool TypeFunctionReductionGuesser::done() +{ + return toInfer.empty(); +} + +void TypeFunctionReductionGuesser::inferTypeFunctionSubstitutions(TypeId ty, const TypeFunctionInstanceType* instance) +{ + + TypeFunctionInferenceResult result; + LUAU_ASSERT(instance); + // TODO: Make an inexhaustive version of this warn in the compiler? + if (isNumericBinopFunction(*instance)) + result = inferNumericBinopFunction(instance); + else if (isComparisonFunction(*instance)) + result = inferComparisonFunction(instance); + else if (isOrAndFunction(*instance)) + result = inferOrAndFunction(instance); + else if (isNotFunction(*instance)) + result = inferNotFunction(instance); + else if (isLenFunction(*instance)) + result = inferLenFunction(instance); + else if (isUnaryMinus(*instance)) + result = inferUnaryMinusFunction(instance); + else + result = {{}, builtins->unknownType}; + + TypeId resultInference = follow(result.functionResultInference); + if (!functionReducesTo.contains(resultInference)) + functionReducesTo[ty] = resultInference; + + for (size_t i = 0; i < instance->typeArguments.size(); i++) + { + if (i < result.operandInference.size()) + { + TypeId arg = follow(instance->typeArguments[i]); + TypeId inference = follow(result.operandInference[i]); + if (auto tfit = get(arg)) + { + if (!functionReducesTo.contains(arg)) + functionReducesTo.try_insert(arg, inference); + } + else if (auto gt = get(arg)) + substitutable[arg] = inference; + } + } +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferNumericBinopFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + TypeFunctionInferenceResult defaultNumericBinopInference{{builtins->numberType, builtins->numberType}, builtins->numberType}; + return defaultNumericBinopInference; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferComparisonFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + // Comparison functions are lt/le/eq. + // Heuristic: these are type functions from t -> t -> bool + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult + { + return TypeFunctionInferenceResult{{op, op}, builtins->booleanType}; + }; + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + if (operandIsAssignable(lhsTy) && !operandIsAssignable(rhsTy)) + return comparisonInference(rhsTy); + if (operandIsAssignable(rhsTy) && !operandIsAssignable(lhsTy)) + return comparisonInference(lhsTy); + return comparisonInference(builtins->numberType); +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferOrAndFunction(const TypeFunctionInstanceType* instance) +{ + + LUAU_ASSERT(instance->typeArguments.size() == 2); + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + TypeFunctionInferenceResult defaultAndOrInference{{builtins->unknownType, builtins->unknownType}, builtins->booleanType}; + + std::shared_ptr lty = normalize(lhsTy); + std::shared_ptr rty = normalize(lhsTy); + bool lhsTruthy = lty ? lty->isTruthy() : false; + bool rhsTruthy = rty ? rty->isTruthy() : false; + // If at the end, we still don't have good substitutions, return the default type + if (instance->function->name == "or") + { + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFunctionInferenceResult{{builtins->unknownType, rhsTy}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFunctionInferenceResult{{lhsTy, builtins->unknownType}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, lhsTy}; + if (rhsTruthy) + return {{builtins->unknownType, rhsTy}, rhsTy}; + } + + if (instance->function->name == "and") + { + + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFunctionInferenceResult{{}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFunctionInferenceResult{{}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, rhsTy}; + else + return {{lhsTy, rhsTy}, lhsTy}; + } + + return defaultAndOrInference; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferNotFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->booleanType}; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferLenFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->numberType}; +} + +TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferUnaryMinusFunction(const TypeFunctionInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + if (isNumber(opTy)) + return {{builtins->numberType}, builtins->numberType}; + return {{builtins->unknownType}, builtins->numberType}; +} + + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/TypeInfer.cpp b/third_party/luau/Analysis/src/TypeInfer.cpp index 8f9e1851..7b7d6fae 100644 --- a/third_party/luau/Analysis/src/TypeInfer.cpp +++ b/third_party/luau/Analysis/src/TypeInfer.cpp @@ -2,12 +2,11 @@ #include "Luau/TypeInfer.h" #include "Luau/ApplyTypeFunction.h" -#include "Luau/Clone.h" +#include "Luau/Cancellation.h" #include "Luau/Common.h" #include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" -#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -18,7 +17,6 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/VisitType.h" @@ -33,15 +31,11 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) -LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) -LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) -LUAU_FASTFLAG(LuauRequirePathTrueModuleName) +LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) +LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) +LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false) +LUAU_FASTFLAG(LuauDeclarationExtraPropData) namespace Luau { @@ -202,7 +196,8 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + name == "__idiv"; } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -210,21 +205,6 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -GlobalTypes::GlobalTypes(NotNull builtinTypes) - : builtinTypes(builtinTypes) -{ - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); -} - TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) : globalScope(globalScope) , resolver(resolver) @@ -232,11 +212,13 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, , iceHandler(iceHandler) , unifierState(iceHandler) , normalizer(nullptr, builtinTypes, NotNull{&unifierState}) + , reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr) , nilType(builtinTypes->nilType) , numberType(builtinTypes->numberType) , stringType(builtinTypes->stringType) , booleanType(builtinTypes->booleanType) , threadType(builtinTypes->threadType) + , bufferType(builtinTypes->bufferType) , anyType(builtinTypes->anyType) , unknownType(builtinTypes->unknownType) , neverType(builtinTypes->neverType) @@ -269,7 +251,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module); currentModule->name = module.name; currentModule->humanReadableName = module.humanReadableName; - currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); + currentModule->internalTypes.owningModule = currentModule.get(); + currentModule->interfaceTypes.owningModule = currentModule.get(); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -304,12 +287,9 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } - - if (FFlag::DebugLuauSharedSelf) + catch (const UserCancelError&) { - for (auto& [ty, scope] : deferredQuantification) - Luau::quantify(ty, scope->level); - deferredQuantification.clear(); + currentModule->cancelled = true; } if (get(follow(moduleScope->returnType))) @@ -347,7 +327,9 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) { if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(iceHandler->moduleName); + throwTimeLimitError(); + if (cancellationToken && cancellationToken->requested()) + throwUserCancelError(); if (auto block = program.as()) return check(scope, *block); @@ -357,22 +339,18 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) return check(scope, *while_); else if (auto repeat = program.as()) return check(scope, *repeat); - else if (program.is() || program.is()) - { - // Nothing to do - return ControlFlow::None; - } + else if (program.is()) + return ControlFlow::Breaks; + else if (program.is()) + return ControlFlow::Continues; else if (auto return_ = program.as()) return check(scope, *return_); else if (auto expr = program.as()) { checkExprPack(scope, *expr->expr); - if (FFlag::LuauTinyControlFlowAnalysis) - { - if (auto call = expr->expr->as(); call && doesCallError(call)) - return ControlFlow::Throws; - } + if (auto call = expr->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; return ControlFlow::None; } @@ -392,6 +370,8 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) ice("Should not be calling two-argument check() on a function statement", program.location); else if (auto typealias = program.as()) return check(scope, *typealias); + else if (auto typefunction = program.as()) + return check(scope, *typefunction); else if (auto global = program.as()) { TypeId globalType = resolveType(scope, *global->type); @@ -530,7 +510,8 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, std::unordered_map> functionDecls; - auto checkBody = [&](AstStat* stat) { + auto checkBody = [&](AstStat* stat) + { if (auto fun = stat->as()) { LUAU_ASSERT(functionDecls.count(stat)); @@ -594,39 +575,15 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, } else if (auto fun = (*protoIter)->as()) { - std::optional selfType; + std::optional selfType; // TODO clip std::optional expectedType; - if (FFlag::DebugLuauSharedSelf) + if (!fun->func->self) { if (auto name = fun->name->as()) { - TypeId baseTy = checkExpr(scope, *name->expr).type; - tablify(baseTy); - - if (!fun->func->self) - expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); - else if (auto ttv = getMutableTableType(baseTy)) - { - if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) - { - ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); - deferredQuantification.push_back({baseTy, scope}); - } - - selfType = ttv->selfTy; - } - } - } - else - { - if (!fun->func->self) - { - if (auto name = fun->name->as()) - { - TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); - } + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); } } @@ -677,7 +634,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (typealias->name == kParseNameError) + if (typealias->name == kParseNameError || typealias->name == "typeof") continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -749,39 +706,25 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement ScopePtr thenScope = childScope(scope, statement.thenbody->location); resolve(result.predicates, thenScope, true); - if (FFlag::LuauTinyControlFlowAnalysis) - { - ScopePtr elseScope = childScope(scope, statement.elsebody ? statement.elsebody->location : statement.location); - resolve(result.predicates, elseScope, false); + ScopePtr elseScope = childScope(scope, statement.elsebody ? statement.elsebody->location : statement.location); + resolve(result.predicates, elseScope, false); - ControlFlow thencf = check(thenScope, *statement.thenbody); - ControlFlow elsecf = ControlFlow::None; - if (statement.elsebody) - elsecf = check(elseScope, *statement.elsebody); + ControlFlow thencf = check(thenScope, *statement.thenbody); + ControlFlow elsecf = ControlFlow::None; + if (statement.elsebody) + elsecf = check(elseScope, *statement.elsebody); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) - scope->inheritRefinements(elseScope); - else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) - scope->inheritRefinements(thenScope); + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) + scope->inheritRefinements(thenScope); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) - return ControlFlow::Returns; - else - return ControlFlow::None; - } + if (thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; else - { - check(thenScope, *statement.thenbody); - - if (statement.elsebody) - { - ScopePtr elseScope = childScope(scope, statement.elsebody->location); - resolve(result.predicates, elseScope, false); - check(elseScope, *statement.elsebody); - } - return ControlFlow::None; - } } template @@ -842,7 +785,7 @@ struct Demoter : Substitution bool ignoreChildren(TypeId ty) override { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return false; @@ -913,12 +856,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); - return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; + return ControlFlow::Returns; } unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); - return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; + return ControlFlow::Returns; } template @@ -1188,6 +1131,16 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; + + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == moduleInfo->name) + { + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, anyType}; + } + } } // In non-strict mode we force the module type on the variable, in strict mode it is already unified @@ -1332,10 +1285,10 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode()) + else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) { for (TypeId var : varTypes) - unify(anyType, var, scope, forin.location); + unify(unknownType, var, scope, forin.location); } else { @@ -1533,6 +1486,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty if (name == kParseNameError) return ControlFlow::None; + if (name == "typeof") + { + reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"}); + return ControlFlow::None; + } + std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) binding = it->second; @@ -1574,14 +1533,26 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty // Additionally, we can't modify types that come from other modules if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { - bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), - binding->typeParams.end(), [](auto&& itp, auto&& tp) { + bool sameTys = std::equal( + ttv->instantiatedTypeParams.begin(), + ttv->instantiatedTypeParams.end(), + binding->typeParams.begin(), + binding->typeParams.end(), + [](auto&& itp, auto&& tp) + { return itp == tp.ty; - }); - bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(), - binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + } + ); + bool sameTps = std::equal( + ttv->instantiatedTypePackParams.begin(), + ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), + binding->typePackParams.end(), + [](auto&& itpp, auto&& tpp) + { return itpp == tpp.tp; - }); + } + ); // Copy can be skipped if this is an identical alias if (!ttv->name || ttv->name != name || !sameTys || !sameTps) @@ -1623,13 +1594,6 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty TypeId& bindingType = bindingsMap[name].type; - if (!FFlag::LuauOccursIsntAlwaysFailure) - { - if (unify(ty, bindingType, aliasScope, typealias.location)) - bindingType = ty; - return ControlFlow::None; - } - unify(ty, bindingType, aliasScope, typealias.location); // It is possible for this unification to succeed but for @@ -1648,12 +1612,21 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty return ControlFlow::None; } +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeFunction& typefunction) +{ + reportError(TypeError{typefunction.location, GenericError{"This syntax is not supported"}}); + + return ControlFlow::None; +} + void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) { Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (name == kParseNameError) + // Also, typeof is not a valid type alias name. We will report an error for + // this in check() + if (name == kParseNameError || name == "typeof") return; std::optional binding; @@ -1701,7 +1674,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -1720,8 +1693,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de if (!get(follow(*superTy))) { - reportError(declaredClass.location, - GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); + reportError( + declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)} + ); incorrectClassDefinitions.insert(&declaredClass); return; } @@ -1729,7 +1704,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de Name className(declaredClass.name.value); - TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name)); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name, declaredClass.location)); ClassType* ctv = getMutable(classTy); TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); @@ -1759,6 +1734,9 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& if (!ctv->metatable) ice("No metatable for declared class"); + if (const auto& indexer = declaredClass.indexer) + ctv->indexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + TableType* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) { @@ -1777,12 +1755,55 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } } } if (assignTo.count(propName) == 0) { - assignTo[propName] = {propTy}; + if (FFlag::LuauDeclarationExtraPropData) + assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + else + assignTo[propName] = {propTy}; + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Luau::Property& prop = assignTo[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } else { @@ -1822,19 +1843,42 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti std::vector genericTys; genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + std::transform( + generics.begin(), + generics.end(), + std::back_inserter(genericTys), + [](auto&& el) + { + return el.ty; + } + ); std::vector genericTps; genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + std::transform( + genericPacks.begin(), + genericPacks.end(), + std::back_inserter(genericTps), + [](auto&& el) + { + return el.tp; + } + ); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = global.location; + defn.varargLocation = global.vararg ? std::make_optional(global.varargLocation) : std::nullopt; + defn.originalNameLocation = global.nameLocation; + } + + TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -2044,7 +2088,12 @@ std::optional TypeChecker::findMetatableEntry(TypeId type, std::string e } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, + TypeId type, + const Name& name, + const Location& location, + bool addErrors +) { size_t errorCount = currentModule->errors.size(); @@ -2057,7 +2106,12 @@ std::optional TypeChecker::getIndexTypeFromType( } std::optional TypeChecker::getIndexTypeFromTypeImpl( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, + TypeId type, + const Name& name, + const Location& location, + bool addErrors +) { type = follow(type); @@ -2105,6 +2159,20 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( const Property* prop = lookupClassProp(cls, name); if (prop) return prop->type(); + + if (auto indexer = cls->indexer) + { + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); + + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } } else if (const UnionType* utv = get(type)) { @@ -2242,7 +2310,11 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } TypeId TypeChecker::checkExprTable( - const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType) + const ScopePtr& scope, + const AstExprTable& expr, + const std::vector>& fieldTypes, + std::optional expectedType +) { TableType::Props props; std::optional indexer; @@ -2471,8 +2543,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{retType}; } - reportError(expr.location, - GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); + reportError( + expr.location, + GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())} + ); return WithPredicate{errorRecoveryType(scope)}; } @@ -2539,6 +2613,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) return "__mul"; case AstExprBinary::Div: return "__div"; + case AstExprBinary::FloorDiv: + return "__idiv"; case AstExprBinary::Mod: return "__mod"; case AstExprBinary::Pow: @@ -2617,28 +2693,47 @@ static std::optional areEqComparable(NotNull arena, NotNull(t); }; if (isExempt(a) || isExempt(b)) return true; + NormalizationResult nr; + TypeId c = arena->addType(IntersectionType{{a, b}}); - const NormalizedType* n = normalizer->normalize(c); + std::shared_ptr n = normalizer->normalize(c); if (!n) return std::nullopt; - if (FFlag::LuauUninhabitedSubAnything2) - return normalizer->isInhabited(n); - else - return isInhabited_DEPRECATED(*n); + nr = normalizer->isInhabited(n.get()); + + switch (nr) + { + case NormalizationResult::HitLimits: + return std::nullopt; + case NormalizationResult::False: + return false; + case NormalizationResult::True: + return true; + } + + // n.b. msvc can never figure this stuff out. + LUAU_UNREACHABLE(); } TypeId TypeChecker::checkRelationalOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates +) { - auto stripNil = [this](TypeId ty, bool isOrOp = false) { + auto stripNil = [this](TypeId ty, bool isOrOp = false) + { ty = follow(ty); if (!isNonstrictMode() && !isOrOp) return ty; @@ -2719,7 +2814,8 @@ TypeId TypeChecker::checkRelationalOperation( if (!*eqTestResult) { reportError( - expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}); + expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())} + ); return errorRecoveryType(booleanType); } } @@ -2743,10 +2839,34 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + if (FFlag::LuauRemoveBadRelationalOperatorWarning) + { + // The original version of this check also produced this error when we had a union type. + // However, the old solver does not readily have the ability to discern if the union is comparable. + // This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type. + // The new solver has much more powerful logic for resolving relational operators, but for now, + // we need to be conservative in the old solver to deliver a reasonable developer experience. + if (!isEquality && state.errors.empty() && isBoolean(leftType)) + { + reportError( + expr.location, + GenericError{ + format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) + } + ); + } + } + else { - reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), - toString(expr.op).c_str())}); + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + { + reportError( + expr.location, + GenericError{ + format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) + } + ); + } } return booleanType; @@ -2754,8 +2874,9 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), builtinTypes); - std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), builtinTypes); + std::optional stringNoMT = std::nullopt; // works around gcc false positive "maybe uninitialized" warnings + std::optional leftMetatable = isString(lhsType) ? stringNoMT : getMetatable(follow(lhsType), builtinTypes); + std::optional rightMetatable = isString(rhsType) ? stringNoMT : getMetatable(follow(rhsType), builtinTypes); if (leftMetatable != rightMetatable) { @@ -2793,8 +2914,14 @@ TypeId TypeChecker::checkRelationalOperation( if (!matches) { reportError( - expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + expr.location, + GenericError{format( + "Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), + toString(rhsType).c_str(), + toString(expr.op).c_str() + )} + ); return errorRecoveryType(booleanType); } } @@ -2825,7 +2952,8 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true + ); state.log.commit(); @@ -2835,7 +2963,8 @@ TypeId TypeChecker::checkRelationalOperation( else if (needsMetamethod) { reportError( - expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); + expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())} + ); return errorRecoveryType(booleanType); } } @@ -2849,8 +2978,12 @@ TypeId TypeChecker::checkRelationalOperation( if (needsMetamethod) { - reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", - toString(lhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, + GenericError{ + format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str()) + } + ); return errorRecoveryType(booleanType); } @@ -2920,7 +3053,12 @@ TypeId TypeChecker::checkRelationalOperation( } TypeId TypeChecker::checkBinaryOperation( - const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) + const ScopePtr& scope, + const AstExprBinary& expr, + TypeId lhsType, + TypeId rhsType, + const PredicateVec& predicates +) { switch (expr.op) { @@ -2971,7 +3109,8 @@ TypeId TypeChecker::checkBinaryOperation( if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { - auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { + auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId + { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); TypePackId retTypePack = freshTypePack(scope); @@ -3018,8 +3157,15 @@ TypeId TypeChecker::checkBinaryOperation( return checkMetatableCall(*fnt, rhsType, lhsType); } - reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), - toString(lhsType).c_str(), toString(rhsType).c_str())}); + reportError( + expr.location, + GenericError{format( + "Binary operator '%s' not supported by types '%s' and '%s'", + toString(expr.op).c_str(), + toString(lhsType).c_str(), + toString(rhsType).c_str() + )} + ); return errorRecoveryType(scope); } @@ -3034,6 +3180,7 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location)); @@ -3074,22 +3221,13 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { - if (!FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - // For these, passing expectedType is worse than simply forcing them, because their implementation // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); - if (FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; PredicateVec predicates; @@ -3296,14 +3434,24 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (const ClassType* lhsClass = get(lhs)) { - const Property* prop = lookupClassProp(lhsClass, name); - if (!prop) + if (const Property* prop = lookupClassProp(lhsClass, name)) { - reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return errorRecoveryType(scope); + return prop->type(); } - return prop->type(); + if (auto indexer = lhsClass->indexer) + { + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(stringType, indexer->indexType); + if (state.errors.empty()) + { + state.log.commit(); + return indexer->indexResultType; + } + } + + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); + return errorRecoveryType(scope); } else if (get(lhs)) { @@ -3345,17 +3493,46 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { if (const ClassType* exprClass = get(exprType)) { - const Property* prop = lookupClassProp(exprClass, value->value.data); - if (!prop) + if (const Property* prop = lookupClassProp(exprClass, value->value.data)) + { + return prop->type(); + } + + if (auto indexer = exprClass->indexer) + { + unify(stringType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return errorRecoveryType(scope); + } + else if (get(exprType)) + { + Name name = std::string(value->value.data, value->value.size); + + if (std::optional ty = getIndexTypeFromType(scope, exprType, name, expr.location, /* addErrors= */ false)) + return *ty; + + // If intersection has a table part, report that it cannot be extended just as a sealed table + if (isTableIntersection(exprType)) { - reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + reportError(TypeError{expr.location, CannotExtendTable{exprType, CannotExtendTable::Property, name}}); return errorRecoveryType(scope); } - return prop->type(); } } - else if (FFlag::LuauAllowIndexClassParameters) + else { + if (const ClassType* exprClass = get(exprType)) + { + if (auto indexer = exprClass->indexer) + { + unify(indexType, indexer->indexType, scope, expr.index->location); + return indexer->indexResultType; + } + } + if (const ClassType* exprClass = get(exprType)) { if (isNonstrictMode()) @@ -3420,7 +3597,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex // Primarily about detecting duplicates. TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { - auto freshTy = [&]() { + auto freshTy = [&]() + { return freshType(level); }; @@ -3493,8 +3671,14 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalName, std::optional selfType, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature( + const ScopePtr& scope, + int subLevel, + const AstExprFunction& expr, + std::optional originalName, + std::optional selfType, + std::optional expectedType +) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -3587,25 +3771,11 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& funScope->returnType = retPack; - if (FFlag::DebugLuauSharedSelf) + if (expr.self) { - if (expr.self) - { - // TODO: generic self types: CLI-39906 - TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); - funScope->bindings[expr.self] = {selfTy, expr.self->location}; - argTypes.push_back(selfTy); - } - } - else - { - if (expr.self) - { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); - } + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); } // Prepare expected argument type iterators if we have an expected function type @@ -3794,8 +3964,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } } -void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, - const std::vector& argLocations) +void TypeChecker::checkArgumentList( + const ScopePtr& scope, + const AstExpr& funName, + Unifier& state, + TypePackId argPack, + TypePackId paramPack, + const std::vector& argLocations +) { /* Important terminology refresher: * A function requires parameters. @@ -3807,7 +3983,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam size_t paramIndex = 0; - auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() + { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) @@ -3819,8 +3996,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam namePath = *path; auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); - state.reportError(TypeError{location, - CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); + state.reportError(TypeError{ + location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath} + }); }; while (true) @@ -3927,7 +4106,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam namePath = *path; state.reportError(TypeError{ - funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath} + }); return; } ++paramIter; @@ -3967,7 +4147,9 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(*argIter, vtp->ty, scope, location); + state.location = location; + state.tryUnify(*argIter, vtp->ty); + ++argIter; ++argIndex; } @@ -4069,7 +4251,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope // We break this function up into a lambda here to limit our stack footprint. // The vectors used by this function aren't allocated until the lambda is actually called. - auto the_rest = [&]() -> WithPredicate { + auto the_rest = [&]() -> WithPredicate + { // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how // generic functions are being instantiated for this particular callsite. @@ -4112,7 +4295,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors + )) return *ret; } @@ -4139,7 +4323,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) + { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -4198,9 +4383,19 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st * If this was an optional, callers would have to pay the stack cost for the result. This is problematic * for functions that need to support recursion up to 600 levels deep. */ -std::unique_ptr> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, - TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) +std::unique_ptr> TypeChecker::checkCallOverload( + const ScopePtr& scope, + const AstExprCall& expr, + TypeId fn, + TypePackId retPack, + TypePackId argPack, + TypePack* args, + const std::vector* argLocations, + const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, + std::vector& overloadsThatDont, + std::vector& errors +) { LUAU_ASSERT(argLocations); @@ -4314,7 +4509,12 @@ std::unique_ptr> TypeChecker::checkCallOverload(const else overloadsThatDont.push_back(fn); - errors.emplace_back(std::move(state.errors), args->head, ftv); + errors.push_back(OverloadErrorEntry{ + std::move(state.log), + std::move(state.errors), + args->head, + ftv, + }); } else { @@ -4329,12 +4529,17 @@ std::unique_ptr> TypeChecker::checkCallOverload(const return nullptr; } -bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, - const std::vector& errors) +bool TypeChecker::handleSelfCallMismatch( + const ScopePtr& scope, + const AstExprCall& expr, + TypePack* args, + const std::vector& argLocations, + const std::vector& errors +) { // No overloads succeeded: Scan for one that would have worked had the user // used a.b() rather than a:b() or vice versa. - for (const auto& [_, argVec, ftv] : errors) + for (const auto& e : errors) { // Did you write foo:bar() when you should have written foo.bar()? if (expr.self) @@ -4345,7 +4550,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4360,7 +4565,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return true; } } - else if (ftv->hasSelf) + else if (e.fnTy->hasSelf) { // Did you write foo.bar() when you should have written foo:bar()? if (AstExprIndexName* indexName = expr.func->as()) @@ -4376,7 +4581,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4397,13 +4602,22 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, - const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, - const std::vector& errors) +void TypeChecker::reportOverloadResolutionError( + const ScopePtr& scope, + const AstExprCall& expr, + TypePackId retPack, + TypePackId argPack, + const std::vector& argLocations, + const std::vector& overloads, + const std::vector& overloadsThatMatchArgCount, + std::vector& errors +) { if (overloads.size() == 1) { - reportErrors(std::get<0>(errors.front())); + errors.front().log.commit(); + + reportErrors(errors.front().errors); return; } @@ -4424,12 +4638,20 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast const FunctionType* ftv = get(overload); - auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { - return ftv == std::get<2>(e); - }); + auto error = std::find_if( + errors.begin(), + errors.end(), + [ftv](const OverloadErrorEntry& e) + { + return ftv == e.fnTy; + } + ); LUAU_ASSERT(error != errors.end()); - reportErrors(std::get<0>(*error)); + + error->log.commit(); + + reportErrors(error->errors); // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) @@ -4470,14 +4692,21 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, - bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) +WithPredicate TypeChecker::checkExprList( + const ScopePtr& scope, + const Location& location, + const AstArray& exprs, + bool substituteFreeForNil, + const std::vector& instantiateGenerics, + const std::vector>& expectedTypes +) { bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? - auto insert = [&predicates](PredicateVec& vec) { + auto insert = [&predicates](PredicateVec& vec) + { for (Predicate& c : vec) predicates.push_back(std::move(c)); }; @@ -4603,7 +4832,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // Types of requires that transitively refer to current module have to be replaced with 'any' for (const auto& [location, path] : requireCycles) { - if (!path.empty() && path.front() == (FFlag::LuauRequirePathTrueModuleName ? moduleInfo.name : resolver->getHumanReadableModuleName(moduleInfo.name))) + if (!path.empty() && path.front() == moduleInfo.name) return anyType; } @@ -4744,20 +4973,10 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { ty = follow(ty); - if (FFlag::DebugLuauSharedSelf) - { - if (auto ftv = get(ty)) - Luau::quantify(ty, scope->level); - else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) - Luau::quantify(ty, scope->level); - } - else - { - const FunctionType* ftv = get(ty); + const FunctionType* ftv = get(ty); - if (ftv) - Luau::quantify(ty, scope->level); - } + if (ftv) + Luau::quantify(ty, scope->level); return ty; } @@ -4767,15 +4986,30 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat ty = follow(ty); const FunctionType* ftv = get(ty); - if (ftv && ftv->hasNoGenerics) + if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; - Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; + std::optional instantiated; + + if (FFlag::LuauReusableSubstitutions) + { + reusableInstantiation.resetState(log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr); - if (instantiationChildLimit) - instantiation.childLimit = *instantiationChildLimit; + if (instantiationChildLimit) + reusableInstantiation.childLimit = *instantiationChildLimit; + + instantiated = reusableInstantiation.substitute(ty); + } + else + { + Instantiation instantiation{log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; + + if (instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + + instantiated = instantiation.substitute(ty); + } - std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; else @@ -4862,24 +5096,40 @@ void TypeChecker::reportErrors(const ErrorVec& errors) reportError(err); } -void TypeChecker::ice(const std::string& message, const Location& location) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message, const Location& location) { iceHandler->ice(message, location); } -void TypeChecker::ice(const std::string& message) +LUAU_NOINLINE void TypeChecker::ice(const std::string& message) { iceHandler->ice(message); } +LUAU_NOINLINE void TypeChecker::throwTimeLimitError() +{ + throw TimeLimitError(iceHandler->moduleName); +} + +LUAU_NOINLINE void TypeChecker::throwUserCancelError() +{ + throw UserCancelError(iceHandler->moduleName); +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error - errVec.erase(std::remove_if(errVec.begin(), errVec.end(), - [](auto& err) { - return containsParseErrorName(err); - }), - errVec.end()); + errVec.erase( + std::remove_if( + errVec.begin(), + errVec.end(), + [](auto& err) + { + return containsParseErrorName(err); + } + ), + errVec.end() + ); for (auto& err : errVec) { @@ -4893,7 +5143,8 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d std::string_view sv(utk->key); std::set candidates; - auto accumulate = [&](const TableType::Props& props) { + auto accumulate = [&](const TableType::Props& props) + { for (const auto& [name, ty] : props) { if (sv != name && equalsLower(sv, name)) @@ -4947,30 +5198,35 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) { - Luau::merge(l, r, [this](TypeId a, TypeId b) { - // TODO: normalize(UnionType{{a, b}}) - std::unordered_set set; + Luau::merge( + l, + r, + [this](TypeId a, TypeId b) + { + // TODO: normalize(UnionType{{a, b}}) + std::unordered_set set; - if (auto utv = get(follow(a))) - set.insert(begin(utv), end(utv)); - else - set.insert(a); + if (auto utv = get(follow(a))) + set.insert(begin(utv), end(utv)); + else + set.insert(a); - if (auto utv = get(follow(b))) - set.insert(begin(utv), end(utv)); - else - set.insert(b); + if (auto utv = get(follow(b))) + set.insert(begin(utv), end(utv)); + else + set.insert(b); - std::vector options(set.begin(), set.end()); - if (set.size() == 1) - return options[0]; - return addType(UnionType{std::move(options)}); - }); + std::vector options(set.begin(), set.end()); + if (set.size() == 1) + return options[0]; + return addType(UnionType{std::move(options)}); + } + ); } Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; + return Unifier{NotNull{&normalizer}, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -5016,7 +5272,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { - return [this, sense, emptySetTy](TypeId ty) -> std::optional { + return [this, sense, emptySetTy](TypeId ty) -> std::optional + { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) return ty; @@ -5158,12 +5415,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno return tf->type; bool parameterCountErrorReported = false; - bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); + bool hasDefaultTypes = std::any_of( + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); + bool hasDefaultPacks = std::any_of( + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& el) + { + return el.defaultValue.has_value(); + } + ); if (!lit->hasParameterList) { @@ -5286,7 +5553,8 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno { if (!parameterCountErrorReported) reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}} + ); // Pad the types out with error recovery types while (typeParams.size() < tf->typeParams.size()) @@ -5295,13 +5563,26 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno typePackParams.push_back(errorRecoveryTypePack(scope)); } - bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; - }); + bool sameTys = std::equal( + typeParams.begin(), + typeParams.end(), + tf->typeParams.begin(), + tf->typeParams.end(), + [](auto&& itp, auto&& tp) + { + return itp == tp.ty; + } + ); bool sameTps = std::equal( - typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + typePackParams.begin(), + typePackParams.end(), + tf->typePackParams.begin(), + tf->typePackParams.end(), + [](auto&& itpp, auto&& tpp) + { return itpp == tpp.tp; - }); + } + ); // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. @@ -5316,10 +5597,28 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type)}; + { + if (prop.access == AstTableAccess::Read) + reportError(prop.accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (prop.access == AstTableAccess::Write) + reportError(prop.accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (prop.access == AstTableAccess::ReadWrite) + props[prop.name.value] = {resolveType(scope, *prop.type), /* deprecated: */ false, {}, std::nullopt, {}, std::nullopt, prop.location}; + else + ice("Unexpected property access " + std::to_string(int(prop.access))); + } if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + { + if (indexer->access == AstTableAccess::Read) + reportError(indexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (indexer->access == AstTableAccess::Write) + reportError(indexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (indexer->access == AstTableAccess::ReadWrite) + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + else + ice("Unexpected property access " + std::to_string(int(indexer->access))); + } TableType ttv{props, tableIndexer, scope->level, TableState::Sealed}; ttv.definitionModuleName = currentModule->name; @@ -5338,15 +5637,27 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno std::vector genericTys; genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + std::transform( + generics.begin(), + generics.end(), + std::back_inserter(genericTys), + [](auto&& el) + { + return el.ty; + } + ); std::vector genericTps; genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + std::transform( + genericPacks.begin(), + genericPacks.end(), + std::back_inserter(genericTps), + [](auto&& el) + { + return el.tp; + } + ); TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); @@ -5467,8 +5778,13 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return result; } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, - const std::vector& typePackParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun( + const ScopePtr& scope, + const TypeFun& tf, + const std::vector& typeParams, + const std::vector& typePackParams, + const Location& location +) { if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; @@ -5496,7 +5812,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; TypeId target = follow(instantiated); - bool needsClone = follow(tf.type) == target; + const TableType* tfTable = getTableType(tf.type); + bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target)); bool shouldMutate = getTableType(tf.type); TableType* ttv = getMutableTableType(target); @@ -5531,8 +5848,14 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, bool useCache) +GenericTypeDefinitions TypeChecker::createGenericTypes( + const ScopePtr& scope, + std::optional levelOpt, + const AstNode& node, + const AstArray& genericNames, + const AstArray& genericPackNames, + bool useCache +) { LUAU_ASSERT(scope->parent); @@ -5660,7 +5983,8 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const } } - auto intoType = [this](const std::unordered_set& s) -> std::optional { + auto intoType = [this](const std::unordered_set& s) -> std::optional + { if (s.empty()) return std::nullopt; @@ -5847,7 +6171,8 @@ void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const Sc void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { - auto predicate = [&](TypeId option) -> std::optional { + auto predicate = [&](TypeId option) -> std::optional + { // This by itself is not truly enough to determine that A is stronger than B or vice versa. bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty(); @@ -5910,8 +6235,10 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r return; } - auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { - TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { + auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) + { + TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional + { if (sense && get(ty)) return mapsTo.value_or(ty); @@ -5938,24 +6265,35 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r return refine(isBoolean, booleanType); else if (typeguardP.kind == "thread") return refine(isThread, threadType); + else if (typeguardP.kind == "buffer") + return refine(isBuffer, bufferType); else if (typeguardP.kind == "table") { - return refine([](TypeId ty) -> bool { - return isTableIntersection(ty) || get(ty) || get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return isTableIntersection(ty) || get(ty) || get(ty); + } + ); } else if (typeguardP.kind == "function") { - return refine([](TypeId ty) -> bool { - return isOverloadedFunction(ty) || get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return isOverloadedFunction(ty) || get(ty); + } + ); } else if (typeguardP.kind == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - return refine([](TypeId ty) -> bool { - return get(ty); - }); + return refine( + [](TypeId ty) -> bool + { + return get(ty); + } + ); } if (!typeguardP.isTypeof) @@ -5968,17 +6306,13 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r TypeId type = follow(typeFun->type); // You cannot refine to the top class type. - if (FFlag::LuauNegatedClassTypes) + if (type == builtinTypes->classType) { - if (type == builtinTypes->classType) - { - return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - } + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); } // We're only interested in the root class of any classes. - if (auto ctv = get(type); - !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt))) + if (auto ctv = get(type); !ctv || ctv->parent != builtinTypes->classType) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. @@ -5989,7 +6323,8 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { + auto options = [](TypeId ty) -> std::vector + { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; @@ -6000,7 +6335,8 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - auto predicate = [&](TypeId option) -> std::optional { + auto predicate = [&](TypeId option) -> std::optional + { if (!sense && isNil(eqP.type)) return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; diff --git a/third_party/luau/Analysis/src/TypeOrPack.cpp b/third_party/luau/Analysis/src/TypeOrPack.cpp new file mode 100644 index 00000000..86652141 --- /dev/null +++ b/third_party/luau/Analysis/src/TypeOrPack.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeOrPack.h" +#include "Luau/Common.h" + +namespace Luau +{ + +const void* ptr(TypeOrPack tyOrTp) +{ + if (auto ty = get(tyOrTp)) + return static_cast(*ty); + else if (auto tp = get(tyOrTp)) + return static_cast(*tp); + else + LUAU_UNREACHABLE(); +} + +TypeOrPack follow(TypeOrPack tyOrTp) +{ + if (auto ty = get(tyOrTp)) + return follow(*ty); + else if (auto tp = get(tyOrTp)) + return follow(*tp); + else + LUAU_UNREACHABLE(); +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/TypePack.cpp b/third_party/luau/Analysis/src/TypePack.cpp index 6873820a..9f3924f0 100644 --- a/third_party/luau/Analysis/src/TypePack.cpp +++ b/third_party/luau/Analysis/src/TypePack.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -255,16 +257,26 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - return follow(tp, [](TypePackId t) { - return t; - }); + return follow( + tp, + nullptr, + [](const void*, TypePackId t) + { + return t; + } + ); } -TypePackId follow(TypePackId tp, std::function mapper) +TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId)) { - auto advance = [&mapper](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(mapper(ty))) + auto advance = [context, mapper](TypePackId ty) -> std::optional + { + TypePackId mapped = mapper(context, ty); + + if (const Unifiable::Bound* btv = get>(mapped)) return btv->boundTo; + else if (const TypePack* tp = get(mapped); tp && tp->head.empty()) + return tp->tail; else return std::nullopt; }; @@ -275,6 +287,9 @@ TypePackId follow(TypePackId tp, std::function mapper) else return tp; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { auto a1 = advance(tp); @@ -456,4 +471,11 @@ bool containsNever(TypePackId tp) return false; } +template<> +LUAU_NOINLINE Unifiable::Bound* emplaceTypePack(TypePackVar* ty, TypePackId& tyArg) +{ + LUAU_ASSERT(ty != follow(tyArg)); + return &ty->ty.emplace(tyArg); +} + } // namespace Luau diff --git a/third_party/luau/Analysis/src/TypePath.cpp b/third_party/luau/Analysis/src/TypePath.cpp new file mode 100644 index 00000000..29f5cfb5 --- /dev/null +++ b/third_party/luau/Analysis/src/TypePath.cpp @@ -0,0 +1,711 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypePath.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/Type.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypeOrPack.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + +// Maximum number of steps to follow when traversing a path. May not always +// equate to the number of components in a path, depending on the traversal +// logic. +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypePathMaximumTraverseSteps, 100); + +namespace Luau +{ + +namespace TypePath +{ + +Property::Property(std::string name) + : name(std::move(name)) +{ + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); +} + +Property Property::read(std::string name) +{ + return Property(std::move(name), true); +} + +Property Property::write(std::string name) +{ + return Property(std::move(name), false); +} + +bool Property::operator==(const Property& other) const +{ + return name == other.name && isRead == other.isRead; +} + +bool Index::operator==(const Index& other) const +{ + return index == other.index; +} + +bool Reduction::operator==(const Reduction& other) const +{ + return resultType == other.resultType; +} + +Path Path::append(const Path& suffix) const +{ + std::vector joined(components); + joined.reserve(suffix.components.size()); + joined.insert(joined.end(), suffix.components.begin(), suffix.components.end()); + return Path(std::move(joined)); +} + +Path Path::push(Component component) const +{ + std::vector joined(components); + joined.push_back(component); + return Path(std::move(joined)); +} + +Path Path::push_front(Component component) const +{ + std::vector joined{}; + joined.reserve(components.size() + 1); + joined.push_back(std::move(component)); + joined.insert(joined.end(), components.begin(), components.end()); + return Path(std::move(joined)); +} + +Path Path::pop() const +{ + if (empty()) + return kEmpty; + + std::vector popped(components); + popped.pop_back(); + return Path(std::move(popped)); +} + +std::optional Path::last() const +{ + if (empty()) + return std::nullopt; + + return components.back(); +} + +bool Path::empty() const +{ + return components.empty(); +} + +bool Path::operator==(const Path& other) const +{ + return components == other.components; +} + +size_t PathHash::operator()(const Property& prop) const +{ + return std::hash()(prop.name) ^ static_cast(prop.isRead); +} + +size_t PathHash::operator()(const Index& idx) const +{ + return idx.index; +} + +size_t PathHash::operator()(const TypeField& field) const +{ + return static_cast(field); +} + +size_t PathHash::operator()(const PackField& field) const +{ + return static_cast(field); +} + +size_t PathHash::operator()(const Reduction& reduction) const +{ + return std::hash()(reduction.resultType); +} + +size_t PathHash::operator()(const Component& component) const +{ + return visit(*this, component); +} + +size_t PathHash::operator()(const Path& path) const +{ + size_t hash = 0; + + for (const Component& component : path.components) + hash ^= (*this)(component); + + return hash; +} + +Path PathBuilder::build() +{ + return Path(std::move(components)); +} + +PathBuilder& PathBuilder::readProp(std::string name) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + components.push_back(Property{std::move(name), true}); + return *this; +} + +PathBuilder& PathBuilder::writeProp(std::string name) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + components.push_back(Property{std::move(name), false}); + return *this; +} + +PathBuilder& PathBuilder::prop(std::string name) +{ + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); + components.push_back(Property{std::move(name)}); + return *this; +} + +PathBuilder& PathBuilder::index(size_t i) +{ + components.push_back(Index{i}); + return *this; +} + +PathBuilder& PathBuilder::mt() +{ + components.push_back(TypeField::Metatable); + return *this; +} + +PathBuilder& PathBuilder::lb() +{ + components.push_back(TypeField::LowerBound); + return *this; +} + +PathBuilder& PathBuilder::ub() +{ + components.push_back(TypeField::UpperBound); + return *this; +} + +PathBuilder& PathBuilder::indexKey() +{ + components.push_back(TypeField::IndexLookup); + return *this; +} + +PathBuilder& PathBuilder::indexValue() +{ + components.push_back(TypeField::IndexResult); + return *this; +} + +PathBuilder& PathBuilder::negated() +{ + components.push_back(TypeField::Negated); + return *this; +} + +PathBuilder& PathBuilder::variadic() +{ + components.push_back(TypeField::Variadic); + return *this; +} + +PathBuilder& PathBuilder::args() +{ + components.push_back(PackField::Arguments); + return *this; +} + +PathBuilder& PathBuilder::rets() +{ + components.push_back(PackField::Returns); + return *this; +} + +PathBuilder& PathBuilder::tail() +{ + components.push_back(PackField::Tail); + return *this; +} + +} // namespace TypePath + +namespace +{ + +struct TraversalState +{ + TraversalState(TypeId root, NotNull builtinTypes) + : current(root) + , builtinTypes(builtinTypes) + { + } + TraversalState(TypePackId root, NotNull builtinTypes) + : current(root) + , builtinTypes(builtinTypes) + { + } + + TypeOrPack current; + NotNull builtinTypes; + int steps = 0; + + void updateCurrent(TypeId ty) + { + LUAU_ASSERT(ty); + current = follow(ty); + } + + void updateCurrent(TypePackId tp) + { + LUAU_ASSERT(tp); + current = follow(tp); + } + + bool tooLong() + { + return ++steps > DFInt::LuauTypePathMaximumTraverseSteps; + } + + bool checkInvariants() + { + return tooLong(); + } + + bool traverse(const TypePath::Property& property) + { + auto currentType = get(current); + if (!currentType) + return false; + + if (checkInvariants()) + return false; + + const Property* prop = nullptr; + + if (auto t = get(*currentType)) + { + auto it = t->props.find(property.name); + if (it != t->props.end()) + { + prop = &it->second; + } + } + else if (auto c = get(*currentType)) + { + prop = lookupClassProp(c, property.name); + } + // For a metatable type, the table takes priority; check that before + // falling through to the metatable entry below. + else if (auto m = get(*currentType)) + { + TypeOrPack pinned = current; + updateCurrent(m->table); + + if (traverse(property)) + return true; + + // Restore the old current type if we didn't traverse the metatable + // successfully; we'll use the next branch to address this. + current = pinned; + } + + if (!prop) + { + if (auto m = getMetatable(*currentType, builtinTypes)) + { + // Weird: rather than use findMetatableEntry, which requires a lot + // of stuff that we don't have and don't want to pull in, we use the + // path traversal logic to grab __index and then re-enter the lookup + // logic there. + updateCurrent(*m); + + if (!traverse(TypePath::Property::read("__index"))) + return false; + + return traverse(property); + } + } + + if (prop) + { + std::optional maybeType; + if (FFlag::DebugLuauDeferredConstraintResolution) + maybeType = property.isRead ? prop->readTy : prop->writeTy; + else + maybeType = prop->type(); + + if (maybeType) + { + updateCurrent(*maybeType); + return true; + } + } + + return false; + } + + bool traverse(const TypePath::Index& index) + { + if (checkInvariants()) + return false; + + if (auto currentType = get(current)) + { + if (auto u = get(*currentType)) + { + auto it = begin(u); + std::advance(it, index.index); + if (it != end(u)) + { + updateCurrent(*it); + return true; + } + } + else if (auto i = get(*currentType)) + { + auto it = begin(i); + std::advance(it, index.index); + if (it != end(i)) + { + updateCurrent(*it); + return true; + } + } + } + else + { + auto currentPack = get(current); + LUAU_ASSERT(currentPack); + if (get(*currentPack)) + { + auto it = begin(*currentPack); + + for (size_t i = 0; i < index.index && it != end(*currentPack); ++i) + ++it; + + if (it != end(*currentPack)) + { + updateCurrent(*it); + return true; + } + } + } + + return false; + } + + bool traverse(TypePath::TypeField field) + { + if (checkInvariants()) + return false; + + switch (field) + { + case TypePath::TypeField::Metatable: + if (auto currentType = get(current)) + { + if (std::optional mt = getMetatable(*currentType, builtinTypes)) + { + updateCurrent(*mt); + return true; + } + } + + return false; + case TypePath::TypeField::LowerBound: + case TypePath::TypeField::UpperBound: + if (auto ft = get(current)) + { + updateCurrent(field == TypePath::TypeField::LowerBound ? ft->lowerBound : ft->upperBound); + return true; + } + + return false; + case TypePath::TypeField::IndexLookup: + case TypePath::TypeField::IndexResult: + { + const TableIndexer* indexer = nullptr; + + if (auto tt = get(current); tt && tt->indexer) + indexer = &(*tt->indexer); + else if (auto mt = get(current)) + { + if (auto mtTab = get(follow(mt->table)); mtTab && mtTab->indexer) + indexer = &(*mtTab->indexer); + else if (auto mtMt = get(follow(mt->metatable)); mtMt && mtMt->indexer) + indexer = &(*mtMt->indexer); + } + // Note: we don't appear to walk the class hierarchy for indexers + else if (auto ct = get(current); ct && ct->indexer) + indexer = &(*ct->indexer); + + if (indexer) + { + updateCurrent(field == TypePath::TypeField::IndexLookup ? indexer->indexType : indexer->indexResultType); + return true; + } + + return false; + } + case TypePath::TypeField::Negated: + if (auto nt = get(current)) + { + updateCurrent(nt->ty); + return true; + } + + return false; + case TypePath::TypeField::Variadic: + if (auto vtp = get(current)) + { + updateCurrent(vtp->ty); + return true; + } + + return false; + } + + return false; + } + + bool traverse(TypePath::Reduction reduction) + { + if (checkInvariants()) + return false; + updateCurrent(reduction.resultType); + return true; + } + + bool traverse(TypePath::PackField field) + { + if (checkInvariants()) + return false; + + switch (field) + { + case TypePath::PackField::Arguments: + case TypePath::PackField::Returns: + if (auto ft = get(current)) + { + updateCurrent(field == TypePath::PackField::Arguments ? ft->argTypes : ft->retTypes); + return true; + } + + return false; + case TypePath::PackField::Tail: + if (auto currentPack = get(current)) + { + auto it = begin(*currentPack); + while (it != end(*currentPack)) + ++it; + + if (auto tail = it.tail()) + { + updateCurrent(*tail); + return true; + } + } + + return false; + } + + return false; + } +}; + +} // namespace + +std::string toString(const TypePath::Path& path, bool prefixDot) +{ + std::stringstream result; + bool first = true; + + auto strComponent = [&](auto&& c) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + result << '['; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (c.isRead) + result << "read "; + else + result << "write "; + } + + result << '"' << c.name << '"' << ']'; + } + else if constexpr (std::is_same_v) + { + result << '[' << std::to_string(c.index) << ']'; + } + else if constexpr (std::is_same_v) + { + if (!first || prefixDot) + result << '.'; + + switch (c) + { + case TypePath::TypeField::Metatable: + result << "metatable"; + break; + case TypePath::TypeField::LowerBound: + result << "lowerBound"; + break; + case TypePath::TypeField::UpperBound: + result << "upperBound"; + break; + case TypePath::TypeField::IndexLookup: + result << "indexer"; + break; + case TypePath::TypeField::IndexResult: + result << "indexResult"; + break; + case TypePath::TypeField::Negated: + result << "negated"; + break; + case TypePath::TypeField::Variadic: + result << "variadic"; + break; + } + + result << "()"; + } + else if constexpr (std::is_same_v) + { + if (!first || prefixDot) + result << '.'; + + switch (c) + { + case TypePath::PackField::Arguments: + result << "arguments"; + break; + case TypePath::PackField::Returns: + result << "returns"; + break; + case TypePath::PackField::Tail: + result << "tail"; + break; + } + result << "()"; + } + else if constexpr (std::is_same_v) + { + // We need to rework the TypePath system to make subtyping failures easier to understand + // https://roblox.atlassian.net/browse/CLI-104422 + result << "~~>"; + } + else + { + static_assert(always_false_v, "Unhandled Component variant"); + } + + first = false; + }; + + for (const TypePath::Component& component : path.components) + Luau::visit(strComponent, component); + + return result.str(); +} + +static bool traverse(TraversalState& state, const Path& path) +{ + auto step = [&state](auto&& c) + { + return state.traverse(c); + }; + + for (const TypePath::Component& component : path.components) + { + bool stepSuccess = visit(step, component); + if (!stepSuccess) + return false; + } + + return true; +} + +std::optional traverse(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + return state.current; + else + return std::nullopt; +} + +std::optional traverse(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + return state.current; + else + return std::nullopt; +} + +std::optional traverseForType(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForType(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForPack(TypeId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +std::optional traverseForPack(TypePackId root, const Path& path, NotNull builtinTypes) +{ + TraversalState state(follow(root), builtinTypes); + if (traverse(state, path)) + { + auto ty = get(state.current); + return ty ? std::make_optional(*ty) : std::nullopt; + } + else + return std::nullopt; +} + +} // namespace Luau diff --git a/third_party/luau/Analysis/src/TypeReduction.cpp b/third_party/luau/Analysis/src/TypeReduction.cpp deleted file mode 100644 index b81cca7b..00000000 --- a/third_party/luau/Analysis/src/TypeReduction.cpp +++ /dev/null @@ -1,1200 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Luau/Common.h" -#include "Luau/Error.h" -#include "Luau/RecursionCounter.h" -#include "Luau/VisitType.h" - -#include -#include - -LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) - -namespace Luau -{ - -namespace detail -{ -bool TypeReductionMemoization::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = types.find(ty); edge && edge->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReductionMemoization::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = typePacks.find(tp); edge && edge->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type()); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - types[ty] = {reducedTy, irreducible}; - types[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - typePacks[tp] = {reducedTp, irreducible}; - typePacks[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - -std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const -{ - auto fetchContext = [this](TypeId ty) -> std::optional> { - if (auto edge = types.find(ty)) - return *edge; - else - return std::nullopt; - }; - - TypeId currentTy = ty; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTy)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTy) - return edge; - else - currentTy = edge->type; - } - - return lastEdge; -} - -std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const -{ - auto fetchContext = [this](TypePackId tp) -> std::optional> { - if (auto edge = typePacks.find(tp)) - return *edge; - else - return std::nullopt; - }; - - TypePackId currentTp = tp; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTp)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTp) - return edge; - else - currentTp = edge->type; - } - - return lastEdge; -} -} // namespace detail - -namespace -{ - -template -std::pair get2(const Thing& one, const Thing& two) -{ - const A* a = get(one); - const B* b = get(two); - return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); -} - -struct TypeReducer -{ - NotNull arena; - NotNull builtinTypes; - NotNull handle; - NotNull memoization; - DenseHashSet* cyclics; - - int depth = 0; - - TypeId reduce(TypeId ty); - TypePackId reduce(TypePackId tp); - - std::optional intersectionType(TypeId left, TypeId right); - std::optional unionType(TypeId left, TypeId right); - TypeId tableType(TypeId ty); - TypeId functionType(TypeId ty); - TypeId negationType(TypeId ty); - - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); - using UnaryFold = TypeId (TypeReducer::*)(TypeId); - - template - LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) - { - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - return {edge->type, getMutable(edge->type)}; - - // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will - // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references - // without attempting to recursively reduce it, causing copies of copies of copies of... - TypeId copiedTy = arena->addType(*t); - memoization->types[ty] = {copiedTy, true}; - memoization->types[copiedTy] = {copiedTy, true}; - return {copiedTy, getMutable(copiedTy)}; - } - - template - void foldl_impl(Iter it, Iter endIt, BinaryFold f, std::vector* result, bool* didReduce) - { - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - while (it != endIt) - { - TypeId right = reduce(*it); - *didReduce |= right != follow(*it); - - // We're hitting a case where the `currentTy` returned a type that's the same as `T`. - // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. - // We will need to recurse and traverse that first. - if (auto t = get(right)) - { - foldl_impl(begin(t), end(t), f, result, didReduce); - ++it; - continue; - } - - bool replaced = false; - auto resultIt = result->begin(); - while (resultIt != result->end()) - { - TypeId left = *resultIt; - if (left == right) - { - replaced = true; - ++resultIt; - continue; - } - - std::optional reduced = (this->*f)(left, right); - if (reduced) - { - *resultIt = *reduced; - ++resultIt; - replaced = true; - } - else - { - ++resultIt; - continue; - } - } - - if (!replaced) - result->push_back(right); - - *didReduce |= replaced; - ++it; - } - } - - template - TypeId flatten(std::vector&& types) - { - if (types.size() == 1) - return types[0]; - else - return arena->addType(T{std::move(types)}); - } - - template - TypeId foldl(Iter it, Iter endIt, std::optional ty, BinaryFold f) - { - std::vector result; - bool didReduce = false; - foldl_impl(it, endIt, f, &result, &didReduce); - - // If we've done any reduction, then we'll need to reduce it again, e.g. - // `"a" | "b" | string` is reduced into `string | string`, which is then reduced into `string`. - if (!didReduce) - return ty ? *ty : flatten(std::move(result)); - else - return reduce(flatten(std::move(result))); - } - - template - TypeId apply(BinaryFold f, TypeId left, TypeId right) - { - std::vector types{left, right}; - return foldl(begin(types), end(types), std::nullopt, f); - } - - template - TypeId distribute(TypeIterator it, TypeIterator endIt, BinaryFold f, TypeId ty) - { - std::vector result; - while (it != endIt) - { - result.push_back(apply(f, *it, ty)); - ++it; - } - return flatten(std::move(result)); - } -}; - -TypeId TypeReducer::reduce(TypeId ty) -{ - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = follow(edge->type); - } - else if (cyclics->contains(ty)) - return ty; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - TypeId result = nullptr; - if (auto i = get(ty)) - result = foldl(begin(i), end(i), ty, &TypeReducer::intersectionType); - else if (auto u = get(ty)) - result = foldl(begin(u), end(u), ty, &TypeReducer::unionType); - else if (get(ty) || get(ty)) - result = tableType(ty); - else if (get(ty)) - result = functionType(ty); - else if (get(ty)) - result = negationType(ty); - else - result = ty; - - return memoization->memoize(ty, result); -} - -TypePackId TypeReducer::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (auto edge = memoization->memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (cyclics->contains(tp)) - return tp; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - bool didReduce = false; - TypePackIterator it = begin(tp); - - std::vector head; - while (it != end(tp)) - { - TypeId reducedTy = reduce(*it); - head.push_back(reducedTy); - didReduce |= follow(*it) != follow(reducedTy); - ++it; - } - - std::optional tail = it.tail(); - if (tail) - { - if (auto vtp = get(follow(*it.tail()))) - { - TypeId reducedTy = reduce(vtp->ty); - if (follow(vtp->ty) != follow(reducedTy)) - { - tail = arena->addTypePack(VariadicTypePack{reducedTy, vtp->hidden}); - didReduce = true; - } - } - } - - if (!didReduce) - return memoization->memoize(tp, tp); - else if (head.empty() && tail) - return memoization->memoize(tp, *tail); - else - return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); -} - -std::optional TypeReducer::intersectionType(TypeId left, TypeId right) -{ - if (get(left)) - return left; // never & T ~ never - else if (get(right)) - return right; // T & never ~ never - else if (get(left)) - return right; // unknown & T ~ T - else if (get(right)) - return left; // T & unknown ~ T - else if (get(left)) - return right; // any & T ~ T - else if (get(right)) - return left; // T & any ~ T - else if (get(left)) - return std::nullopt; // 'a & T ~ 'a & T - else if (get(right)) - return std::nullopt; // T & 'a ~ T & 'a - else if (get(left)) - return std::nullopt; // G & T ~ G & T - else if (get(right)) - return std::nullopt; // T & G ~ T & G - else if (get(left)) - return std::nullopt; // error & T ~ error & T - else if (get(right)) - return std::nullopt; // T & error ~ T & error - else if (get(left)) - return std::nullopt; // *blocked* & T ~ *blocked* & T - else if (get(right)) - return std::nullopt; // T & *blocked* ~ T & *blocked* - else if (get(left)) - return std::nullopt; // *pending* & T ~ *pending* & T - else if (get(right)) - return std::nullopt; // T & *pending* ~ T & *pending* - else if (auto [utl, utr] = get2(left, right); utl && utr) - { - std::vector parts; - for (TypeId optionl : utl) - { - for (TypeId optionr : utr) - parts.push_back(apply(&TypeReducer::intersectionType, optionl, optionr)); - } - - return reduce(flatten(std::move(parts))); // (T | U) & (A | B) ~ (T & A) | (T & B) | (U & A) | (U & B) - } - else if (auto ut = get(left)) - return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) - else if (get(right)) - return intersectionType(right, left); // T & (A | B) ~ (A | B) & T - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 & P2 ~ P1 iff P1 == P2 - else - return builtinTypes->neverType; // P1 & P2 ~ never iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return right; // string & "A" ~ "A" - else if (p->type == PrimitiveType::Boolean && get(s)) - return right; // boolean & true ~ true - else - return builtinTypes->neverType; // string & true ~ never - } - else if (auto [s, p] = get2(left, right); s && p) - return intersectionType(right, left); // S & P ~ P & S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return right; // function & () -> () ~ () -> () - else - return builtinTypes->neverType; // string & () -> () ~ never - } - else if (auto [f, p] = get2(left, right); f && p) - return intersectionType(right, left); // () -> () & P ~ P & () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // {} & P ~ P & {} - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // M & P ~ P & M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" & "a" ~ "a" - else - return builtinTypes->neverType; // "a" & "b" ~ never - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return left; // Derived & Base ~ Derived - else if (isSubclass(c2, c1)) - return right; // Base & Derived ~ Derived - else - return builtinTypes->neverType; // Base & Unrelated ~ never - } - else if (auto [f1, f2] = get2(left, right); f1 && f2) - return std::nullopt; // TODO - else if (auto [t1, t2] = get2(left, right); t1 && t2) - { - if (t1->state == TableState::Free || t2->state == TableState::Free) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - else if (t1->state == TableState::Generic || t2->state == TableState::Generic) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - - if (cyclics->contains(left)) - return std::nullopt; // (t1 where t1 = { p: t1 }) & {} ~ t1 & {} - else if (cyclics->contains(right)) - return std::nullopt; // {} & (t1 where t1 = { p: t1 }) ~ {} & t1 - - TypeId resultTy = arena->addType(TableType{}); - TableType* table = getMutable(resultTy); - table->state = t1->state == TableState::Sealed || t2->state == TableState::Sealed ? TableState::Sealed : TableState::Unsealed; - - for (const auto& [name, prop] : t1->props) - { - // TODO: when t1 has properties, we should also intersect that with the indexer in t2 if it exists, - // even if we have the corresponding property in the other one. - if (auto other = t2->props.find(name); other != t2->props.end()) - { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type(), other->second.type()); - if (get(propTy)) - return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never - else - table->props[name] = {propTy}; // { p : string } & { p : ~"a" } ~ { p : string & ~"a" } - } - else - table->props[name] = prop; // { p : string } & {} ~ { p : string } - } - - for (const auto& [name, prop] : t2->props) - { - // TODO: And vice versa, t2 properties against t1 indexer if it exists, - // even if we have the corresponding property in the other one. - if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type())}; // {} & { p : string & string } ~ { p : string } - } - - if (t1->indexer && t2->indexer) - { - TypeId keyTy = apply(&TypeReducer::intersectionType, t1->indexer->indexType, t2->indexer->indexType); - if (get(keyTy)) - return std::nullopt; // { [string]: _ } & { [number]: _ } ~ { [string]: _ } & { [number]: _ } - - TypeId valueTy = apply(&TypeReducer::intersectionType, t1->indexer->indexResultType, t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [string]: number } & { [string]: string } ~ { [string]: never } - } - else if (t1->indexer) - { - TypeId keyTy = reduce(t1->indexer->indexType); - TypeId valueTy = reduce(t1->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } - } - else if (t2->indexer) - { - TypeId keyTy = reduce(t2->indexer->indexType); - TypeId valueTy = reduce(t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } - } - - return resultTy; - } - else if (auto [mt, tt] = get2(left, right); mt && tt) - return std::nullopt; // TODO - else if (auto [tt, mt] = get2(left, right); tt && mt) - return intersectionType(right, left); // T & M ~ M & T - else if (auto [m1, m2] = get2(left, right); m1 && m2) - return std::nullopt; // TODO - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 & ~P2 ~ ~P1 iff P1 == P2 - else - return std::nullopt; // ~P1 & ~P2 ~ ~P1 & ~P2 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" & ~"A" ~ ~"A" - else - return std::nullopt; // ~"A" & ~"B" ~ ~"A" & ~"B" - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return right; // ~"A" & ~string ~ ~string - else if (get(ns) && np->type == PrimitiveType::Boolean) - return right; // ~false & ~boolean ~ ~boolean - else - return std::nullopt; // ~"A" | ~P ~ ~"A" & ~P - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return intersectionType(right, left); // ~P & ~S ~ ~S & ~P - else - return std::nullopt; // ~T & ~U ~ ~T & ~U - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->neverType; // ~P1 & P2 ~ never iff P1 == P2 - else - return right; // ~P1 & P2 ~ P2 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->neverType; // ~"A" & "A" ~ never - else - return right; // ~"A" & "B" ~ "B" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return std::nullopt; // ~"A" & string ~ ~"A" & string - else if (get(ns) && p->type == PrimitiveType::Boolean) - { - // Because booleans contain a fixed amount of values (2), we can do something cooler with this one. - const BooleanSingleton* b = get(ns); - return arena->addType(SingletonType{BooleanSingleton{!b->value}}); // ~false & boolean ~ true - } - else - return right; // ~"A" & number ~ number - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return builtinTypes->neverType; // ~string & "A" ~ never - else if (np->type == PrimitiveType::Boolean && get(s)) - return builtinTypes->neverType; // ~boolean & true ~ never - else - return right; // ~P & "A" ~ "A" - } - else if (auto [np, f] = get2(nlTy, right); np && f) - { - if (np->type == PrimitiveType::Function) - return builtinTypes->neverType; // ~function & () -> () ~ never - else - return right; // ~string & () -> () ~ () -> () - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return builtinTypes->neverType; // ~Base & Derived ~ never - else if (isSubclass(nc, c)) - return std::nullopt; // ~Derived & Base ~ ~Derived & Base - else - return right; // ~Base & Unrelated ~ Unrelated - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else - return right; // ~T & U ~ U - } - else if (get(right)) - return intersectionType(right, left); // T & ~U ~ ~U & T - else - return builtinTypes->neverType; // for all T and U except the ones handled above, T & U ~ never -} - -std::optional TypeReducer::unionType(TypeId left, TypeId right) -{ - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - - if (get(left)) - return right; // never | T ~ T - else if (get(right)) - return left; // T | never ~ T - else if (get(left)) - return left; // unknown | T ~ unknown - else if (get(right)) - return right; // T | unknown ~ unknown - else if (get(left)) - return left; // any | T ~ any - else if (get(right)) - return right; // T | any ~ any - else if (get(left)) - return std::nullopt; // error | T ~ error | T - else if (get(right)) - return std::nullopt; // T | error ~ T | error - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 | P2 ~ P1 iff P1 == P2 - else - return std::nullopt; // P1 | P2 ~ P1 | P2 iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return left; // string | "A" ~ string - else if (p->type == PrimitiveType::Boolean && get(s)) - return left; // boolean | true ~ boolean - else - return std::nullopt; // string | true ~ string | true - } - else if (auto [s, p] = get2(left, right); s && p) - return unionType(right, left); // S | P ~ P | S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return left; // function | () -> () ~ function - else - return std::nullopt; // P | () -> () ~ P | () -> () - } - else if (auto [f, p] = get2(left, right); f && p) - return unionType(right, left); // () -> () | P ~ P | () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // {} | P ~ P | {} - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // M | P ~ P | M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" | "a" ~ "a" - else - return std::nullopt; // "a" | "b" ~ "a" | "b" - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return right; // Derived | Base ~ Base - else if (isSubclass(c2, c1)) - return left; // Base | Derived ~ Base - else - return std::nullopt; // Base | Unrelated ~ Base | Unrelated - } - else if (auto [nt, it] = get2(left, right); nt && it) - return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) - else if (auto [it, nt] = get2(left, right); it && nt) - return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) - else if (auto it = get(left)) - { - bool didReduce = false; - std::vector parts; - for (TypeId part : it) - { - auto nt = get(part); - if (!nt) - { - parts.push_back(part); - continue; - } - - auto redex = unionType(part, right); - if (redex && get(*redex)) - { - didReduce = true; - continue; - } - - parts.push_back(part); - } - - if (didReduce) - return flatten(std::move(parts)); // (T & ~nil) | nil ~ T - else - return std::nullopt; // (T & ~nil) | U - } - else if (get(right)) - return unionType(right, left); // A | (T & U) ~ (T & U) | A - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 | ~P2 ~ ~P1 iff P1 == P2 - else - return builtinTypes->unknownType; // ~P1 | ~P2 ~ ~P1 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" | ~"A" ~ ~"A" - else - return builtinTypes->unknownType; // ~"A" | ~"B" ~ unknown - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return left; // ~"A" | ~string ~ ~"A" - else if (get(ns) && np->type == PrimitiveType::Boolean) - return left; // ~false | ~boolean ~ ~false - else - return builtinTypes->unknownType; // ~"A" | ~P ~ unknown - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return unionType(right, left); // ~P | ~S ~ ~S | ~P - else - return std::nullopt; // TODO! - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->unknownType; // ~P1 | P2 ~ unknown iff P1 == P2 - else - return left; // ~P1 | P2 ~ ~P1 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->unknownType; // ~"A" | "A" ~ unknown - else - return left; // ~"A" | "B" ~ ~"A" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return builtinTypes->unknownType; // ~"A" | string ~ unknown - else if (get(ns) && p->type == PrimitiveType::Boolean) - return builtinTypes->unknownType; // ~false | boolean ~ unknown - else - return left; // ~"A" | T ~ ~"A" - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return std::nullopt; // ~string | "A" ~ ~string | "A" - else if (np->type == PrimitiveType::Boolean && get(s)) - { - const BooleanSingleton* b = get(s); - return negationType(arena->addType(SingletonType{BooleanSingleton{!b->value}})); // ~boolean | false ~ ~true - } - else - return left; // ~P | "A" ~ ~P - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return std::nullopt; // ~Base | Derived ~ ~Base | Derived - else if (isSubclass(nc, c)) - return builtinTypes->unknownType; // ~Derived | Base ~ unknown - else - return left; // ~Base | Unrelated ~ ~Base - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | {} ~ ~P | {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | M ~ ~P | M - } - else - return std::nullopt; // TODO - } - else if (get(right)) - return unionType(right, left); // T | ~U ~ ~U | T - else - return std::nullopt; // for all T and U except the ones handled above, T | U ~ T | U -} - -TypeId TypeReducer::tableType(TypeId ty) -{ - if (auto mt = get(ty)) - { - auto [copiedTy, copied] = copy(ty, mt); - copied->table = reduce(mt->table); - copied->metatable = reduce(mt->metatable); - return copiedTy; - } - else if (auto tt = get(ty)) - { - // Because of `typeof()`, we need to preserve pointer identity of free/unsealed tables so that - // all mutations that occurs on this will be applied without leaking the implementation details. - // As a result, we'll just use the type instead of cloning it if it's free/unsealed. - // - // We could choose to do in-place reductions here, but to be on the safer side, I propose that we do not. - if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - return ty; - - auto [copiedTy, copied] = copy(ty, tt); - - for (auto& [name, prop] : copied->props) - { - TypeId propTy = reduce(prop.type()); - if (get(propTy)) - return builtinTypes->neverType; - else - prop.setType(propTy); - } - - if (copied->indexer) - { - TypeId keyTy = reduce(copied->indexer->indexType); - TypeId valueTy = reduce(copied->indexer->indexResultType); - copied->indexer = TableIndexer{keyTy, valueTy}; - } - - for (TypeId& ty : copied->instantiatedTypeParams) - ty = reduce(ty); - - for (TypePackId& tp : copied->instantiatedTypePackParams) - tp = reduce(tp); - - return copiedTy; - } - else - handle->ice("TypeReducer::tableType expects a TableType or MetatableType"); -} - -TypeId TypeReducer::functionType(TypeId ty) -{ - const FunctionType* f = get(ty); - if (!f) - handle->ice("TypeReducer::functionType expects a FunctionType"); - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - auto [copiedTy, copied] = copy(ty, f); - copied->argTypes = reduce(f->argTypes); - copied->retTypes = reduce(f->retTypes); - return copiedTy; -} - -TypeId TypeReducer::negationType(TypeId ty) -{ - const NegationType* n = get(ty); - if (!n) - return arena->addType(NegationType{ty}); - - TypeId negatedTy = follow(n->ty); - - if (auto nn = get(negatedTy)) - return nn->ty; // ~~T ~ T - else if (get(negatedTy)) - return builtinTypes->unknownType; // ~never ~ unknown - else if (get(negatedTy)) - return builtinTypes->neverType; // ~unknown ~ never - else if (get(negatedTy)) - return builtinTypes->anyType; // ~any ~ any - else if (auto ni = get(negatedTy)) - { - std::vector options; - for (TypeId part : ni) - options.push_back(negationType(arena->addType(NegationType{part}))); - return reduce(flatten(std::move(options))); // ~(T & U) ~ (~T | ~U) - } - else if (auto nu = get(negatedTy)) - { - std::vector parts; - for (TypeId option : nu) - parts.push_back(negationType(arena->addType(NegationType{option}))); - return reduce(flatten(std::move(parts))); // ~(T | U) ~ (~T & ~U) - } - else - return ty; // for all T except the ones handled above, ~T ~ ~T -} - -struct MarkCycles : TypeVisitor -{ - DenseHashSet cyclics{nullptr}; - - void cycle(TypeId ty) override - { - cyclics.insert(follow(ty)); - } - - void cycle(TypePackId tp) override - { - cyclics.insert(follow(tp)); - } - - bool visit(TypeId ty) override - { - return !cyclics.find(follow(ty)); - } - - bool visit(TypePackId tp) override - { - return !cyclics.find(follow(tp)); - } -}; -} // namespace - -TypeReduction::TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts) - : arena(arena) - , builtinTypes(builtinTypes) - , handle(handle) - , options(opts) -{ -} - -std::optional TypeReduction::reduce(TypeId ty) -{ - ty = follow(ty); - - if (FFlag::DebugLuauDontReduceTypes) - return ty; - else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) - return ty; - else if (auto edge = memoization.memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = edge->type; - } - else if (hasExceededCartesianProductLimit(ty)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(ty); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(ty); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (FFlag::DebugLuauDontReduceTypes) - return tp; - else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) - return tp; - else if (auto edge = memoization.memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (hasExceededCartesianProductLimit(tp)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(tp); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(tp); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(const TypeFun& fun) -{ - if (FFlag::DebugLuauDontReduceTypes) - return fun; - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - if (auto reducedTy = reduce(fun.type)) - return TypeFun{fun.typeParams, fun.typePackParams, *reducedTy}; - - return std::nullopt; -} - -size_t TypeReduction::cartesianProductSize(TypeId ty) const -{ - ty = follow(ty); - - auto it = get(follow(ty)); - if (!it) - return 1; - - return std::accumulate(begin(it), end(it), size_t(1), [](size_t acc, TypeId ty) { - if (auto ut = get(ty)) - return acc * std::distance(begin(ut), end(ut)); - else if (get(ty)) - return acc * 0; - else - return acc * 1; - }); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypeId ty) const -{ - return cartesianProductSize(ty) >= size_t(FInt::LuauTypeReductionCartesianProductLimit); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const -{ - TypePackIterator it = begin(tp); - - while (it != end(tp)) - { - if (hasExceededCartesianProductLimit(*it)) - return true; - - ++it; - } - - if (auto tail = it.tail()) - { - if (auto vtp = get(follow(*tail))) - { - if (hasExceededCartesianProductLimit(vtp->ty)) - return true; - } - } - - return false; -} - -} // namespace Luau diff --git a/third_party/luau/Analysis/src/TypeUtils.cpp b/third_party/luau/Analysis/src/TypeUtils.cpp index 9124e2fc..b40805e9 100644 --- a/third_party/luau/Analysis/src/TypeUtils.cpp +++ b/third_party/luau/Analysis/src/TypeUtils.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Common.h" #include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" @@ -8,11 +9,96 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { +bool inConditional(const TypeContext& context) +{ + return context == TypeContext::Condition; +} + +bool occursCheck(TypeId needle, TypeId haystack) +{ + LUAU_ASSERT(get(needle) || get(needle)); + haystack = follow(haystack); + + auto checkHaystack = [needle](TypeId haystack) + { + return occursCheck(needle, haystack); + }; + + if (needle == haystack) + return true; + else if (auto ut = get(haystack)) + return std::any_of(begin(ut), end(ut), checkHaystack); + else if (auto it = get(haystack)) + return std::any_of(begin(it), end(it), checkHaystack); + + return false; +} + +// FIXME: Property is quite large. +// +// Returning it on the stack like this isn't great. We'd like to just return a +// const Property*, but we mint a property of type any if the subject type is +// any. +std::optional findTableProperty(NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) +{ + if (get(ty)) + return Property::rw(ty); + + if (const TableType* tableType = getTableType(ty)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second; + } + + std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); + int count = 0; + while (mtIndex) + { + TypeId index = follow(*mtIndex); + + if (count >= 100) + return std::nullopt; + + ++count; + + if (const auto& itt = getTableType(index)) + { + const auto& fit = itt->props.find(name); + if (fit != itt->props.end()) + return fit->second.type(); + } + else if (const auto& itf = get(index)) + { + std::optional r = first(follow(itf->retTypes)); + if (!r) + return builtinTypes->nilType; + else + return *r; + } + else if (get(index)) + return builtinTypes->anyType; + else + errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); + + mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location); + } + + return std::nullopt; +} + std::optional findMetatableEntry( - NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) + NotNull builtinTypes, + ErrorVec& errors, + TypeId type, + const std::string& entry, + Location location +) { type = follow(type); @@ -40,7 +126,24 @@ std::optional findMetatableEntry( } std::optional findTablePropertyRespectingMeta( - NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + Location location +) +{ + return findTablePropertyRespectingMeta(builtinTypes, errors, ty, name, ValueContext::RValue, location); +} + +std::optional findTablePropertyRespectingMeta( + NotNull builtinTypes, + ErrorVec& errors, + TypeId ty, + const std::string& name, + ValueContext context, + Location location +) { if (get(ty)) return ty; @@ -49,7 +152,20 @@ std::optional findTablePropertyRespectingMeta( { const auto& it = tableType->props.find(name); if (it != tableType->props.end()) - return it->second.type(); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + switch (context) + { + case ValueContext::RValue: + return it->second.readTy; + case ValueContext::LValue: + return it->second.writeTy; + } + } + else + return it->second.type(); + } } std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); @@ -118,7 +234,12 @@ std::pair> getParameterExtents(const TxnLog* log, } TypePack extendTypePack( - TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides) + TypeArena& arena, + NotNull builtinTypes, + TypePackId pack, + size_t length, + std::vector> overrides +) { TypePack result; @@ -190,7 +311,13 @@ TypePack extendTypePack( } else { - t = arena.freshType(ftp->scope); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; + t = arena.addType(ft); + } + else + t = arena.freshType(ftp->scope); } newPack.head.push_back(t); @@ -295,4 +422,59 @@ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) return follow(ty); } +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + std::shared_ptr normType = normalizer->normalize(ty); + + if (!normType) + return ErrorSuppression::NormalizationFailed; + + return (normType->shouldSuppressErrors()) ? ErrorSuppression::Suppress : ErrorSuppression::DoNotSuppress; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp) +{ + auto [tys, tail] = flatten(tp); + + // check the head, one type at a time + for (TypeId ty : tys) + { + auto result = shouldSuppressErrors(normalizer, ty); + if (result != ErrorSuppression::DoNotSuppress) + return result; + } + + // check the tail if we have one and it's finite + if (tail && tp != tail && finite(*tail)) + return shouldSuppressErrors(normalizer, *tail); + + return ErrorSuppression::DoNotSuppress; +} + +// This is a useful helper because it is often the case that we are looking at specifically a pair of types that might suppress. +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1, TypeId ty2) +{ + auto result = shouldSuppressErrors(normalizer, ty1); + + // if ty1 is do not suppress, ty2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, ty2); + + // otherwise, ty1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + +ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2) +{ + auto result = shouldSuppressErrors(normalizer, tp1); + + // if tp1 is do not suppress, tp2 determines our overall behavior + if (result == ErrorSuppression::DoNotSuppress) + return shouldSuppressErrors(normalizer, tp2); + + // otherwise, tp1 is either suppress or normalization failure which are both the appropriate overarching result + return result; +} + } // namespace Luau diff --git a/third_party/luau/Analysis/src/Unifier.cpp b/third_party/luau/Analysis/src/Unifier.cpp index 3ca93591..3dc66d1d 100644 --- a/third_party/luau/Analysis/src/Unifier.cpp +++ b/third_party/luau/Analysis/src/Unifier.cpp @@ -18,16 +18,11 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) -LUAU_FASTFLAGVARIABLE(LuauVariadicAnyCanBeGeneric, false) -LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) -LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAG(LuauNegatedTableTypes) +LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false) namespace Luau { @@ -53,7 +48,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor template void promote(TID ty, T* t) { - if (FFlag::DebugLuauDeferredConstraintResolution && !t) + if (useScopes && !t) return; LUAU_ASSERT(t); @@ -315,7 +310,7 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + if (get(ty)) return true; return !log->is(ty); @@ -333,7 +328,8 @@ TypePackId Widen::operator()(TypePackId tp) std::optional hasUnificationTooComplex(const ErrorVec& errors) { - auto isUnificationTooComplex = [](const TypeError& te) { + auto isUnificationTooComplex = [](const TypeError& te) + { return nullptr != get(te); }; @@ -344,6 +340,20 @@ std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +std::optional hasCountMismatch(const ErrorVec& errors) +{ + auto isCountMismatch = [](const TypeError& te) + { + return nullptr != get(te); + }; + + auto it = std::find_if(errors.begin(), errors.end(), isCountMismatch); + if (it == errors.end()) + return std::nullopt; + else + return *it; +} + // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { @@ -359,7 +369,6 @@ static std::optional> getTableMatchT return std::nullopt; } -// TODO: Inline and clip with FFlag::DebugLuauDeferredConstraintResolution template static bool subsumes(bool useScopes, TY_A* left, TY_B* right) { @@ -383,11 +392,10 @@ TypeMismatch::Context Unifier::mismatchContext() } } -Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) +Unifier::Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) : types(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , mode(mode) , scope(scope) , log(parentLog) , location(location) @@ -395,13 +403,16 @@ Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope , sharedState(*normalizer->sharedState) { LUAU_ASSERT(sharedState.iceHandler); + + // Unifier is not usable when this flag is enabled! Please consider using Subtyping instead. + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } -void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { sharedState.counters.iterationCount = 0; - tryUnify_(subTy, superTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection, literalProperties); } static bool isBlocked(const TxnLog& log, TypeId ty) @@ -410,7 +421,13 @@ static bool isBlocked(const TxnLog& log, TypeId ty) return get(ty) || get(ty); } -void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +static bool isBlocked(const TxnLog& log, TypePackId tp) +{ + tp = log.follow(tp); + return get(tp); +} + +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -428,10 +445,26 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (isBlocked(log, subTy) && isBlocked(log, superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (isBlocked(log, subTy)) + blockedTypes.push_back(subTy); + else if (isBlocked(log, superTy)) + blockedTypes.push_back(superTy); + + if (log.get(superTy)) + ice("Unexpected TypeFunctionInstanceType superTy"); + + if (log.get(subTy)) + ice("Unexpected TypeFunctionInstanceType subTy"); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); - if (superFree && subFree && subsumes(useScopes, superFree, subFree)) + if (superFree && subFree && subsumes(useNewSolver, superFree, subFree)) { if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); @@ -442,7 +475,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - if (subsumes(useScopes, superFree, subFree)) + if (subsumes(useNewSolver, superFree, subFree)) { log.changeLevel(subTy, superFree->level); } @@ -456,7 +489,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Unification can't change the level of a generic. auto subGeneric = log.getMutable(subTy); - if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) + if (subGeneric && !subsumes(useNewSolver, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic subtype escaping scope"}); @@ -465,7 +498,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); + promoteTypeLevels(log, types, superFree->level, superFree->scope, useNewSolver, subTy); Widen widen{types, builtinTypes}; log.replace(superTy, BoundType(widen(subTy))); @@ -482,7 +515,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); - if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) + if (superGeneric && !subsumes(useNewSolver, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic supertype escaping scope"}); @@ -491,28 +524,69 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(subTy, superTy, /* reversed = */ false)) { - promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); + promoteTypeLevels(log, types, subFree->level, subFree->scope, useNewSolver, superTy); log.replace(subTy, BoundType(superTy)); } return; } - if (log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->anyType); + if (hideousFixMeGenericsAreActuallyFree) + { + auto superGeneric = log.getMutable(superTy); + auto subGeneric = log.getMutable(subTy); - if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->errorType); + if (superGeneric && subGeneric && subsumes(useNewSolver, superGeneric, subGeneric)) + { + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); - if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->unknownType); + return; + } + else if (superGeneric && subGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + log.replace(superTy, BoundType(subTy)); + + return; + } + else if (superGeneric) + { + if (!occursCheck(superTy, subTy, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + log.replace(superTy, BoundType(widen(subTy))); + } + + return; + } + else if (subGeneric) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; + + if (!occursCheck(subTy, superTy, /* reversed = */ false)) + log.replace(subTy, BoundType(superTy)); + + return; + } + } + + if (log.get(superTy)) + return tryUnifyWithAny(subTy, builtinTypes->anyType); if (log.get(subTy)) { - if (FFlag::LuauTransitiveSubtyping && normalize) + if (normalize) { // TODO: there are probably cheaper ways to check if any <: T. - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); + if (!log.get(superNorm->tops)) failure = true; } @@ -521,9 +595,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(superTy, builtinTypes->anyType); } - if (!FFlag::LuauTransitiveSubtyping && log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->errorType); - if (log.get(subTy)) return tryUnifyWithAny(superTy, builtinTypes->neverType); @@ -555,16 +626,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (isBlocked(log, subTy) && isBlocked(log, superTy)) - { - blockedTypes.push_back(subTy); - blockedTypes.push_back(superTy); - } - else if (isBlocked(log, subTy)) - blockedTypes.push_back(subTy); - else if (isBlocked(log, superTy)) - blockedTypes.push_back(superTy); - else if (const UnionType* subUnion = log.getMutable(subTy)) + if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -580,32 +642,32 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + else if (log.get(subTy)) { tryUnifyWithAny(superTy, builtinTypes->unknownType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy) && log.get(superTy)) + else if (log.get(subTy) && log.get(superTy)) { // error <: error } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { tryUnifyWithAny(subTy, builtinTypes->errorType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + else if (log.get(subTy)) { tryUnifyWithAny(superTy, builtinTypes->errorType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { // At this point, all the supertypes of `error` have been handled, // and if `error unknownType); } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { tryUnifyWithAny(subTy, builtinTypes->unknownType); } @@ -620,7 +682,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Ok. Do nothing. forall functions F, F <: function } - else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) + else if (isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) { // Ok, do nothing: forall tables T, T <: table } @@ -635,7 +697,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { - tryUnifyTables(subTy, superTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection, literalProperties); } else if (log.get(superTy) && (log.get(subTy) || log.get(subTy))) { @@ -663,10 +725,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(superTy) || log.get(subTy)) tryUnifyNegations(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + // If the normalizer hits resource limits, we can't show it's uninhabited, so, we should error. + else if (checkInhabited && normalizer->isInhabited(subTy) == NormalizationResult::False) { } - else reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); @@ -691,15 +753,15 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ Unifier innerState = makeChildUnifier(); innerState.tryUnify_(type, superTy); - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; - else if (FFlag::LuauTransitiveSubtyping ? innerState.failure : !innerState.errors.empty()) + else if (innerState.failure) { // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. - if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + if (innerState.errors.empty()) logs.push_back(std::move(innerState.log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' else if (!firstFailedOption && !isNil(type)) @@ -710,47 +772,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ } } - if (FFlag::DebugLuauDeferredConstraintResolution) - log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); - else - { - // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); - - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; - - // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype - // test is successful. - if (auto subUnion = get(subTy)) - { - if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) - return; - } - - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - }; - - if (auto superUnion = log.getMutable(superTy)) - { - for (TypeId ty : superUnion) - tryBind(ty); - } - else - tryBind(superTy); - } + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -758,38 +780,12 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { if (firstFailedOption) reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); - else if (!FFlag::LuauTransitiveSubtyping || !errorsSuppressed) + else if (!errorsSuppressed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); failure = true; } } -struct DEPRECATED_BlockedTypeFinder : TypeOnceVisitor -{ - std::unordered_set blockedTypes; - - bool visit(TypeId ty, const BlockedType&) override - { - blockedTypes.insert(ty); - return true; - } -}; - -bool Unifier::DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauNormalizeBlockedTypes); - DEPRECATED_BlockedTypeFinder blockedTypeFinder; - blockedTypeFinder.traverse(subTy); - blockedTypeFinder.traverse(superTy); - if (!blockedTypeFinder.blockedTypes.empty()) - { - blockedTypes.insert(end(blockedTypes), begin(blockedTypeFinder.blockedTypes), end(blockedTypeFinder.blockedTypes)); - return true; - } - - return false; -} - void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall) { // T <: A | B if T <: A or T <: B @@ -831,7 +827,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } - if (FFlag::LuauTransitiveSubtyping && !foundHeuristic) + if (!foundHeuristic) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -871,10 +867,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); - if (FFlag::LuauTransitiveSubtyping ? !innerState.failure : innerState.errors.empty()) + if (!innerState.failure) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); else { @@ -882,7 +878,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp break; } } - else if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + else if (innerState.errors.empty()) { errorsSuppressed = true; } @@ -899,28 +895,31 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); if (unificationTooComplex) { reportError(*unificationTooComplex); } - else if (FFlag::LuauTransitiveSubtyping && !found && normalize) + else if (!found && normalize) { // It is possible that T <: A | B even though T normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); Unifier innerState = makeChildUnifier(); + + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(location, UnificationTooComplex{}); + return reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption + ); else innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + if (!innerState.failure) log.concat(std::move(innerState.log)); else if (errorsSuppressed || innerState.errors.empty()) @@ -930,18 +929,13 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } else if (!found && normalize) { - // We cannot normalize a type that contains blocked types. We have to - // stop for now if we find any. - if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) - return; - // It is possible that T <: A | B even though T normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(location, UnificationTooComplex{}); + reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -949,11 +943,12 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } else if (!found) { - if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + if (errorsSuppressed) failure = true; else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) reportError( - location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}); + location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()} + ); else reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()}); } @@ -980,14 +975,14 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I firstFailedOption = {innerState.errors.front()}; } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) logs.push_back(std::move(innerState.log)); else log.concat(std::move(innerState.log)); failure |= innerState.failure; } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); if (unificationTooComplex) @@ -1037,13 +1032,8 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (FFlag::DebugLuauDeferredConstraintResolution && normalize) + if (useNewSolver && normalize) { - // We cannot normalize a type that contains blocked types. We have to - // stop for now if we find any. - if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) - return; - // Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }. NegationTypeFinder finder; finder.traverse(subTy); @@ -1053,12 +1043,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(location, UnificationTooComplex{}); + reportError(location, NormalizationTooComplex{}); return; } @@ -1080,7 +1070,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { found = true; errorsSuppressed = innerState.failure; - if (FFlag::DebugLuauDeferredConstraintResolution || (FFlag::LuauTransitiveSubtyping && innerState.failure)) + if (useNewSolver || innerState.failure) logs.push_back(std::move(innerState.log)); else { @@ -1095,29 +1085,25 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); - else if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + else if (errorsSuppressed) log.concat(std::move(logs.front())); if (unificationTooComplex) reportError(*unificationTooComplex); else if (!found && normalize) { - // We cannot normalize a type that contains blocked types. We have to - // stop for now if we find any. - if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) - return; - // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(location, UnificationTooComplex{}); + reportError(location, NormalizationTooComplex{}); } else if (!found) { @@ -1128,33 +1114,33 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } void Unifier::tryUnifyNormalizedTypes( - TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) -{ - if (!FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) - return; - else if (get(superNorm.tops)) + TypeId subTy, + TypeId superTy, + const NormalizedType& subNorm, + const NormalizedType& superNorm, + std::string reason, + std::optional error +) +{ + if (get(superNorm.tops)) return; else if (get(subNorm.tops)) { failure = true; return; } - else if (!FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.errors)) if (!get(superNorm.errors)) { failure = true; - if (!FFlag::LuauTransitiveSubtyping) - reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return; } - if (FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) + if (get(superNorm.tops)) return; - if (FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) + if (get(subNorm.tops)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) @@ -1183,81 +1169,59 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (FFlag::LuauNegatedClassTypes) + for (const auto& [subClass, _] : subNorm.classes.classes) { - for (const auto& [subClass, _] : subNorm.classes.classes) + bool found = false; + const ClassType* subCtv = get(subClass); + LUAU_ASSERT(subCtv); + + for (const auto& [superClass, superNegations] : superNorm.classes.classes) { - bool found = false; - const ClassType* subCtv = get(subClass); - LUAU_ASSERT(subCtv); + const ClassType* superCtv = get(superClass); + LUAU_ASSERT(superCtv); - for (const auto& [superClass, superNegations] : superNorm.classes.classes) + if (isSubclass(subCtv, superCtv)) { - const ClassType* superCtv = get(superClass); - LUAU_ASSERT(superCtv); - - if (isSubclass(subCtv, superCtv)) - { - found = true; - - for (TypeId negation : superNegations) - { - const ClassType* negationCtv = get(negation); - LUAU_ASSERT(negationCtv); - - if (isSubclass(subCtv, negationCtv)) - { - found = false; - break; - } - } - - if (found) - break; - } - } + found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) - { - for (TypeId superTable : superNorm.tables) + for (TypeId negation : superNegations) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); + const ClassType* negationCtv = get(negation); + LUAU_ASSERT(negationCtv); - if (innerState.errors.empty()) + if (isSubclass(subCtv, negationCtv)) { - found = true; - log.concat(std::move(innerState.log)); + found = false; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - return reportError(*e); } - } - if (!found) - { - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + if (found) + break; } } - } - else - { - for (TypeId subClass : subNorm.DEPRECATED_classes) + + if (useNewSolver) { - bool found = false; - const ClassType* subCtv = get(subClass); - for (TypeId superClass : superNorm.DEPRECATED_classes) + for (TypeId superTable : superNorm.tables) { - const ClassType* superCtv = get(superClass); - if (isSubclass(subCtv, superCtv)) + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(subClass, superTable); + + if (innerState.errors.empty()) { found = true; + log.concat(std::move(innerState.log)); break; } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); } - if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (!found) + { + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } @@ -1266,7 +1230,7 @@ void Unifier::tryUnifyNormalizedTypes( bool found = false; for (TypeId superTable : superNorm.tables) { - if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) + if (isPrim(superTable, PrimitiveType::Table)) { found = true; break; @@ -1399,7 +1363,8 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) if (subTyInfo && *subTyInfo) return false; - auto skipCacheFor = [this](TypeId ty) { + auto skipCacheFor = [this](TypeId ty) + { SkipCacheForType visitor{sharedState.skipCacheForType, types}; visitor.traverse(ty); @@ -1468,6 +1433,15 @@ struct WeirdIter return pack != nullptr && index < pack->head.size(); } + std::optional tail() const + { + if (!pack) + return packId; + + LUAU_ASSERT(index == pack->head.size()); + return pack->tail; + } + bool advance() { if (!pack) @@ -1502,7 +1476,7 @@ struct WeirdIter auto freePack = log.getMutable(packId); level = freePack->level; - if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) + if (freePack->scope != nullptr) scope = freePack->scope; log.replace(packId, BoundTypePack(newTail)); packId = newTail; @@ -1529,6 +1503,12 @@ struct WeirdIter } }; +void Unifier::enableNewSolver() +{ + useNewSolver = true; + log.useScopes = true; +} + ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); @@ -1593,15 +1573,44 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; - if (log.getMutable(superTp)) + if (isBlocked(log, subTp) && isBlocked(log, superTp)) + { + blockedTypePacks.push_back(subTp); + blockedTypePacks.push_back(superTp); + } + else if (isBlocked(log, subTp)) + blockedTypePacks.push_back(subTp); + else if (isBlocked(log, superTp)) + blockedTypePacks.push_back(superTp); + + if (auto superFree = log.getMutable(superTp)) { if (!occursCheck(superTp, subTp, /* reversed = */ true)) { Widen widen{types, builtinTypes}; + if (useNewSolver) + promoteTypeLevels(log, types, superFree->level, superFree->scope, /*useScopes*/ true, subTp); log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (auto subFree = log.getMutable(subTp)) + { + if (!occursCheck(subTp, superTp, /* reversed = */ false)) + { + if (useNewSolver) + promoteTypeLevels(log, types, subFree->level, subFree->scope, /*useScopes*/ true, superTp); + log.replace(subTp, Unifiable::Bound(superTp)); + } + } + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(superTp)) + { + if (!occursCheck(superTp, subTp, /* reversed = */ true)) + { + Widen widen{types, builtinTypes}; + log.replace(superTp, Unifiable::Bound(widen(subTp))); + } + } + else if (hideousFixMeGenericsAreActuallyFree && log.getMutable(subTp)) { if (!occursCheck(subTp, superTp, /* reversed = */ false)) { @@ -1632,14 +1641,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); - if (FFlag::LuauMaintainScopesInUnifier) - { - superIter.scope = scope.get(); - subIter.scope = scope.get(); - } + superIter.scope = scope.get(); + subIter.scope = scope.get(); - auto mkFreshType = [this](Scope* scope, TypeLevel level) { - return types->freshType(scope, level); + auto mkFreshType = [this](Scope* scope, TypeLevel level) + { + if (FFlag::DebugLuauDeferredConstraintResolution) + return freshType(NotNull{types}, builtinTypes, scope); + else + return types->freshType(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -1678,28 +1688,74 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - } - else if (lFreeTail) + if (useNewSolver) { - tryUnify_(emptyTp, *superTpv->tail); - } - else if (rFreeTail) - { - tryUnify_(emptyTp, *subTpv->tail); + if (subIter.tail() && superIter.tail()) + tryUnify_(*subIter.tail(), *superIter.tail()); + else if (subIter.tail()) + { + const TypePackId subTail = log.follow(*subIter.tail()); + + if (log.get(subTail)) + tryUnify_(subTail, emptyTp); + else if (log.get(subTail)) + reportError(location, TypePackMismatch{subTail, emptyTp}); + else if (log.get(subTail) || log.get(subTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected subtype tail pack " + toString(subTail), location); + } + } + else if (superIter.tail()) + { + const TypePackId superTail = log.follow(*superIter.tail()); + + if (log.get(superTail)) + tryUnify_(emptyTp, superTail); + else if (log.get(superTail)) + reportError(location, TypePackMismatch{emptyTp, superTail}); + else if (log.get(superTail) || log.get(superTail)) + { + // Nothing. This is ok. + } + else + { + ice("Unexpected supertype tail pack " + toString(superTail), location); + } + } + else + { + // Nothing. This is ok. + } } - else if (subTpv->tail && superTpv->tail) + else { - if (log.getMutable(superIter.packId)) - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - else if (log.getMutable(subIter.packId)) - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - else + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + { tryUnify_(*subTpv->tail, *superTpv->tail); + } + else if (lFreeTail) + { + tryUnify_(emptyTp, *superTpv->tail); + } + else if (rFreeTail) + { + tryUnify_(emptyTp, *subTpv->tail); + } + else if (subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else + tryUnify_(*subTpv->tail, *superTpv->tail); + } } break; @@ -1829,7 +1885,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal // generic methods in tables to be marked read-only. if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - Instantiation instantiation{&log, types, scope->level, scope}; + Instantiation instantiation{&log, types, builtinTypes, scope->level, scope}; std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) @@ -1885,8 +1941,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), mismatchContext()}); + reportError( + location, + TypeMismatch{ + superTy, + subTy, + format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front(), + mismatchContext() + } + ); else if (!innerState.errors.empty()) reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); @@ -1900,8 +1964,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), mismatchContext()}); + reportError( + location, + TypeMismatch{ + superTy, + subTy, + format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front(), + mismatchContext() + } + ); else if (!innerState.errors.empty()) reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); } @@ -1956,7 +2028,7 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, const LiteralProperties* literalProperties) { if (isPrim(log.follow(subTy), PrimitiveType::Table)) subTy = builtinTypes->emptyTableType; @@ -1978,7 +2050,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) { - Instantiation instantiation{&log, types, subTable->level, scope}; + Instantiation instantiation{&log, types, builtinTypes, subTable->level, scope}; std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) @@ -2041,7 +2113,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(r->second.type(), prop.type()); @@ -2057,7 +2130,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); @@ -2093,7 +2167,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); @@ -2102,9 +2187,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + + return; } } @@ -2120,10 +2208,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. Resetter resetter{&variance}; - variance = Invariant; + if (!literalProperties || !literalProperties->contains(name)) + variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) + innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); + else + { + // Incredibly, the old solver depends on this bug somehow. + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + } checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2162,7 +2257,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated @@ -2173,9 +2279,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + + return; } } @@ -2268,7 +2377,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) + // If the normalizer hits resource limits, we can't show it's uninhabited, so, we should continue. + if (checkInhabited && normalizer->isInhabited(subTy) == NormalizationResult::False) return; if (reversed) @@ -2279,7 +2389,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (!superTable || superTable->state != TableState::Free) return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); - auto fail = [&](std::optional e) { + auto fail = [&](std::optional e) + { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()}); @@ -2374,7 +2485,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) reportError(*e); else if (!innerState.errors.empty()) reportError( - location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} + ); log.concat(std::move(innerState.log)); failure |= innerState.failure; @@ -2385,7 +2497,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { case TableState::Free: { - if (FFlag::DebugLuauDeferredConstraintResolution) + if (useNewSolver) { Unifier innerState = makeChildUnifier(); bool missingProperty = false; @@ -2412,8 +2524,10 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, - TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}}); + reportError(TypeError{ + location, + TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} + }); else if (!missingProperty) { log.concat(std::move(innerState.log)); @@ -2451,7 +2565,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(superTy, subTy); - auto fail = [&]() { + auto fail = [&]() + { if (!reversed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); else @@ -2542,15 +2657,10 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); - // We cannot normalize a type that contains blocked types. We have to - // stop for now if we find any. - if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) - return; - - const NormalizedType* subNorm = normalizer->normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(location, UnificationTooComplex{}); + return reportError(location, NormalizationTooComplex{}); // T maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); - if (get(tail)) + + if (isBlocked(log, tail)) + { + blockedTypePacks.push_back(tail); + } + else if (get(tail)) { log.replace(tail, BoundTypePack(superTp)); } @@ -2622,7 +2737,10 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + if (!hideousFixMeGenericsAreActuallyFree) + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); + else + log.replace(tail, BoundTypePack{superTp}); } else if (get(tail)) { @@ -2634,7 +2752,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } } } - else if (FFlag::LuauVariadicAnyCanBeGeneric && get(variadicTy) && log.get(subTp)) + else if (get(variadicTy) && log.get(subTp)) { // Nothing to do. This is ok. } @@ -2644,8 +2762,15 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } } -static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, - const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) +static void tryUnifyWithAny( + std::vector& queue, + Unifier& state, + DenseHashSet& seen, + DenseHashSet& seenTypePacks, + const TypeArena* typeArena, + TypeId anyType, + TypePackId anyTypePack +) { while (!queue.empty()) { @@ -2742,8 +2867,8 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + LUAU_ASSERT(useNewSolver); + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsIntersections(std::move(log), NotNull{types}); return result; @@ -2751,8 +2876,7 @@ TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); - TxnLog result; + TxnLog result(useNewSolver); for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); return result; @@ -2764,7 +2888,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) bool occurs = occursCheck(sharedState.tempSeenTy, needle, haystack); - if (occurs && FFlag::LuauOccursIsntAlwaysFailure) + if (occurs) { Unifier innerState = makeChildUnifier(); if (const UnionType* ut = get(haystack)) @@ -2789,7 +2913,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) if (innerState.failure) { reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryType()); + log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); } } @@ -2802,7 +2926,8 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays bool occurrence = false; - auto check = [&](TypeId tv) { + auto check = [&](TypeId tv) + { if (occursCheck(seen, needle, tv)) occurrence = true; }; @@ -2818,21 +2943,13 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle to be free"); if (needle == haystack) - { - if (!FFlag::LuauOccursIsntAlwaysFailure) - { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryType()); - } - return true; - } - if (log.getMutable(haystack)) + if (log.getMutable(haystack) || (hideousFixMeGenericsAreActuallyFree && log.is(haystack))) return false; else if (auto a = log.getMutable(haystack)) { @@ -2854,10 +2971,13 @@ bool Unifier::occursCheck(TypePackId needle, TypePackId haystack, bool reversed) bool occurs = occursCheck(sharedState.tempSeenTp, needle, haystack); - if (occurs && FFlag::LuauOccursIsntAlwaysFailure) + if (occurs) { reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryTypePack()); + if (FFlag::LuauUnifierShouldNotCopyError) + log.replace(needle, BoundTypePack{builtinTypes->errorRecoveryTypePack()}); + else + log.replace(needle, *builtinTypes->errorRecoveryTypePack()); } return occurs; @@ -2876,7 +2996,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle) && !(hideousFixMeGenericsAreActuallyFree && log.is(needle))) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -2884,15 +3004,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ while (!log.getMutable(haystack)) { if (needle == haystack) - { - if (!FFlag::LuauOccursIsntAlwaysFailure) - { - reportError(location, OccursCheckFailed{}); - log.replace(needle, *builtinTypes->errorRecoveryTypePack()); - } - return true; - } if (auto a = get(haystack); a && a->tail) { @@ -2908,10 +3020,13 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; + Unifier u = Unifier{normalizer, scope, location, variance, &log}; u.normalize = normalize; u.checkInhabited = checkInhabited; - u.useScopes = useScopes; + + if (useNewSolver) + u.enableNewSolver(); + return u; } @@ -2936,12 +3051,6 @@ void Unifier::reportError(TypeError err) failure = true; } - -bool Unifier::isNonstrictMode() const -{ - return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); -} - void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) @@ -2955,8 +3064,10 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, - TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}}); + reportError(TypeError{ + location, + TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()} + }); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/third_party/luau/Analysis/src/Unifier2.cpp b/third_party/luau/Analysis/src/Unifier2.cpp new file mode 100644 index 00000000..5ea11ad0 --- /dev/null +++ b/third_party/luau/Analysis/src/Unifier2.cpp @@ -0,0 +1,928 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Unifier2.h" + +#include "Luau/Instantiation.h" +#include "Luau/Scope.h" +#include "Luau/Simplify.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/VisitType.h" + +#include +#include + +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +namespace Luau +{ + +static bool areCompatible(TypeId left, TypeId right) +{ + auto p = get2(follow(left), follow(right)); + if (!p) + return true; + + const TableType* leftTable = p.first; + LUAU_ASSERT(leftTable); + const TableType* rightTable = p.second; + LUAU_ASSERT(rightTable); + + const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable) + { + // Two tables may be compatible even if their shapes aren't exactly the + // same if the extra property is optional, free (and therefore + // potentially optional), or if the right table has an indexer. Or if + // the right table is free (and therefore potentially has an indexer or + // a compatible property) + + LUAU_ASSERT(leftProp.isReadOnly() || leftProp.isShared()); + + const TypeId leftType = follow(leftProp.isReadOnly() ? *leftProp.readTy : leftProp.type()); + + if (isOptional(leftType) || get(leftType) || rightTable->state == TableState::Free || rightTable->indexer.has_value()) + return true; + + return false; + }; + + for (const auto& [name, leftProp] : leftTable->props) + { + auto it = rightTable->props.find(name); + if (it == rightTable->props.end()) + { + if (!missingPropIsCompatible(leftProp, rightTable)) + return false; + } + } + + for (const auto& [name, rightProp] : rightTable->props) + { + auto it = leftTable->props.find(name); + if (it == leftTable->props.end()) + { + if (!missingPropIsCompatible(rightProp, leftTable)) + return false; + } + } + + return true; +} + +// returns `true` if `ty` is irressolvable and should be added to `incompleteSubtypes`. +static bool isIrresolvable(TypeId ty) +{ + return get(ty) || get(ty); +} + +// returns `true` if `tp` is irressolvable and should be added to `incompleteSubtypes`. +static bool isIrresolvable(TypePackId tp) +{ + return get(tp) || get(tp); +} + +Unifier2::Unifier2(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull ice) + : arena(arena) + , builtinTypes(builtinTypes) + , scope(scope) + , ice(ice) + , limits(TypeCheckLimits{}) // TODO: typecheck limits in unifier2 + , recursionLimit(FInt::LuauTypeInferRecursionLimit) + , uninhabitedTypeFunctions(nullptr) +{ +} + +Unifier2::Unifier2( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + NotNull ice, + DenseHashSet* uninhabitedTypeFunctions +) + : arena(arena) + , builtinTypes(builtinTypes) + , scope(scope) + , ice(ice) + , limits(TypeCheckLimits{}) // TODO: typecheck limits in unifier2 + , recursionLimit(FInt::LuauTypeInferRecursionLimit) + , uninhabitedTypeFunctions(uninhabitedTypeFunctions) +{ +} + +bool Unifier2::unify(TypeId subTy, TypeId superTy) +{ + subTy = follow(subTy); + superTy = follow(superTy); + + if (auto subGen = genericSubstitutions.find(subTy)) + return unify(*subGen, superTy); + + if (auto superGen = genericSubstitutions.find(superTy)) + return unify(subTy, *superGen); + + if (seenTypePairings.contains({subTy, superTy})) + return true; + seenTypePairings.insert({subTy, superTy}); + + if (subTy == superTy) + return true; + + // We have potentially done some unifications while dispatching either `SubtypeConstraint` or `PackSubtypeConstraint`, + // so rather than implementing backtracking or traversing the entire type graph multiple times, we could push + // additional constraints as we discover blocked types along with their proper bounds. + // + // But we exclude these two subtyping patterns, they are tautological: + // - never <: *blocked* + // - *blocked* <: unknown + if ((isIrresolvable(subTy) || isIrresolvable(superTy)) && !get(subTy) && !get(superTy)) + { + if (uninhabitedTypeFunctions && (uninhabitedTypeFunctions->contains(subTy) || uninhabitedTypeFunctions->contains(superTy))) + return true; + + incompleteSubtypes.push_back(SubtypeConstraint{subTy, superTy}); + return true; + } + + FreeType* subFree = getMutable(subTy); + FreeType* superFree = getMutable(superTy); + + if (superFree) + { + superFree->lowerBound = mkUnion(superFree->lowerBound, subTy); + } + + if (subFree) + { + return unifyFreeWithType(subTy, superTy); + } + + if (subFree || superFree) + return true; + + auto subFn = get(subTy); + auto superFn = get(superTy); + if (subFn && superFn) + return unify(subTy, superFn); + + auto subUnion = get(subTy); + auto superUnion = get(superTy); + if (subUnion) + return unify(subUnion, superTy); + else if (superUnion) + return unify(subTy, superUnion); + + auto subIntersection = get(subTy); + auto superIntersection = get(superTy); + if (subIntersection) + return unify(subIntersection, superTy); + else if (superIntersection) + return unify(subTy, superIntersection); + + auto subNever = get(subTy); + auto superNever = get(superTy); + if (subNever && superNever) + return true; + else if (subNever && superFn) + { + // If `never` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->neverTypePack); + bool retResult = unify(builtinTypes->neverTypePack, superFn->retTypes); + return argResult && retResult; + } + else if (subFn && superNever) + { + // If `never` is the supertype, then we can propagate that inward. + bool argResult = unify(builtinTypes->neverTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->neverTypePack); + return argResult && retResult; + } + + auto subAny = get(subTy); + auto superAny = get(superTy); + + auto subTable = getMutable(subTy); + auto superTable = get(superTy); + + if (subAny && superAny) + return true; + else if (subAny && superFn) + return unify(subAny, superFn); + else if (subFn && superAny) + return unify(subFn, superAny); + else if (subAny && superTable) + return unify(subAny, superTable); + else if (subTable && superAny) + return unify(subTable, superAny); + + if (subTable && superTable) + { + // `boundTo` works like a bound type, and therefore we'd replace it + // with the `boundTo` and try unification again. + // + // However, these pointers should have been chased already by follow(). + LUAU_ASSERT(!subTable->boundTo); + LUAU_ASSERT(!superTable->boundTo); + + return unify(subTable, superTable); + } + + auto subMetatable = get(subTy); + auto superMetatable = get(superTy); + if (subMetatable && superMetatable) + return unify(subMetatable, superMetatable); + else if (subMetatable) // if we only have one metatable, unify with the inner table + return unify(subMetatable->table, superTy); + else if (superMetatable) // if we only have one metatable, unify with the inner table + return unify(subTy, superMetatable->table); + + auto [subNegation, superNegation] = get2(subTy, superTy); + if (subNegation && superNegation) + return unify(subNegation->ty, superNegation->ty); + + // The unification failed, but we're not doing type checking. + return true; +} + +// If superTy is a function and subTy already has a +// potentially-compatible function in its upper bound, we assume that +// the function is not overloaded and attempt to combine superTy into +// subTy's existing function bound. +bool Unifier2::unifyFreeWithType(TypeId subTy, TypeId superTy) +{ + FreeType* subFree = getMutable(subTy); + LUAU_ASSERT(subFree); + + auto doDefault = [&]() + { + subFree->upperBound = mkIntersection(subFree->upperBound, superTy); + expandedFreeTypes[subTy].push_back(superTy); + return true; + }; + + TypeId upperBound = follow(subFree->upperBound); + + if (get(upperBound)) + return unify(subFree->upperBound, superTy); + + const FunctionType* superFunction = get(superTy); + if (!superFunction) + return doDefault(); + + const auto [superArgHead, superArgTail] = flatten(superFunction->argTypes); + if (superArgTail) + return doDefault(); + + const IntersectionType* upperBoundIntersection = get(subFree->upperBound); + if (!upperBoundIntersection) + return doDefault(); + + bool ok = true; + bool foundOne = false; + + for (TypeId part : upperBoundIntersection->parts) + { + const FunctionType* ft = get(follow(part)); + if (!ft) + continue; + + const auto [subArgHead, subArgTail] = flatten(ft->argTypes); + + if (!subArgTail && subArgHead.size() == superArgHead.size()) + { + foundOne = true; + ok &= unify(part, superTy); + } + } + + if (foundOne) + return ok; + else + return doDefault(); +} + +bool Unifier2::unify(TypeId subTy, const FunctionType* superFn) +{ + const FunctionType* subFn = get(subTy); + + bool shouldInstantiate = + (superFn->generics.empty() && !subFn->generics.empty()) || (superFn->genericPacks.empty() && !subFn->genericPacks.empty()); + + if (shouldInstantiate) + { + for (auto generic : subFn->generics) + genericSubstitutions[generic] = freshType(arena, builtinTypes, scope); + + for (auto genericPack : subFn->genericPacks) + genericPackSubstitutions[genericPack] = arena->freshTypePack(scope); + } + + bool argResult = unify(superFn->argTypes, subFn->argTypes); + bool retResult = unify(subFn->retTypes, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const UnionType* subUnion, TypeId superTy) +{ + bool result = true; + + // if the occurs check fails for any option, it fails overall + for (auto subOption : subUnion->options) + { + if (areCompatible(subOption, superTy)) + result &= unify(subOption, superTy); + } + + return result; +} + +bool Unifier2::unify(TypeId subTy, const UnionType* superUnion) +{ + bool result = true; + + // if the occurs check fails for any option, it fails overall + for (auto superOption : superUnion->options) + { + if (areCompatible(subTy, superOption)) + result &= unify(subTy, superOption); + } + + return result; +} + +bool Unifier2::unify(const IntersectionType* subIntersection, TypeId superTy) +{ + bool result = true; + + // if the occurs check fails for any part, it fails overall + for (auto subPart : subIntersection->parts) + result &= unify(subPart, superTy); + + return result; +} + +bool Unifier2::unify(TypeId subTy, const IntersectionType* superIntersection) +{ + bool result = true; + + // if the occurs check fails for any part, it fails overall + for (auto superPart : superIntersection->parts) + result &= unify(subTy, superPart); + + return result; +} + +bool Unifier2::unify(TableType* subTable, const TableType* superTable) +{ + bool result = true; + + // It suffices to only check one direction of properties since we'll only ever have work to do during unification + // if the property is present in both table types. + for (const auto& [propName, subProp] : subTable->props) + { + auto superPropOpt = superTable->props.find(propName); + + if (superPropOpt != superTable->props.end()) + { + const Property& superProp = superPropOpt->second; + + if (subProp.isReadOnly() && superProp.isReadOnly()) + result &= unify(*subProp.readTy, *superPropOpt->second.readTy); + else if (subProp.isReadOnly()) + result &= unify(*subProp.readTy, superProp.type()); + else if (superProp.isReadOnly()) + result &= unify(subProp.type(), *superProp.readTy); + else + { + result &= unify(subProp.type(), superProp.type()); + result &= unify(superProp.type(), subProp.type()); + } + } + } + + auto subTypeParamsIter = subTable->instantiatedTypeParams.begin(); + auto superTypeParamsIter = superTable->instantiatedTypeParams.begin(); + + while (subTypeParamsIter != subTable->instantiatedTypeParams.end() && superTypeParamsIter != superTable->instantiatedTypeParams.end()) + { + result &= unify(*subTypeParamsIter, *superTypeParamsIter); + + subTypeParamsIter++; + superTypeParamsIter++; + } + + auto subTypePackParamsIter = subTable->instantiatedTypePackParams.begin(); + auto superTypePackParamsIter = superTable->instantiatedTypePackParams.begin(); + + while (subTypePackParamsIter != subTable->instantiatedTypePackParams.end() && + superTypePackParamsIter != superTable->instantiatedTypePackParams.end()) + { + result &= unify(*subTypePackParamsIter, *superTypePackParamsIter); + + subTypePackParamsIter++; + superTypePackParamsIter++; + } + + if (subTable->selfTy && superTable->selfTy) + result &= unify(*subTable->selfTy, *superTable->selfTy); + + if (subTable->indexer && superTable->indexer) + { + result &= unify(subTable->indexer->indexType, superTable->indexer->indexType); + result &= unify(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + } + + if (!subTable->indexer && subTable->state == TableState::Unsealed && superTable->indexer) + { + /* + * Unsealed tables are always created from literal table expressions. We + * can't be completely certain whether such a table has an indexer just + * by the content of the expression itself, so we need to be a bit more + * flexible here. + * + * If we are trying to reconcile an unsealed table with a table that has + * an indexer, we therefore conclude that the unsealed table has the + * same indexer. + */ + + TypeId indexType = superTable->indexer->indexType; + if (TypeId* subst = genericSubstitutions.find(indexType)) + indexType = *subst; + + TypeId indexResultType = superTable->indexer->indexResultType; + if (TypeId* subst = genericSubstitutions.find(indexResultType)) + indexResultType = *subst; + + subTable->indexer = TableIndexer{indexType, indexResultType}; + } + + return result; +} + +bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* superMetatable) +{ + return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table); +} + +bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn) +{ + // If `any` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); + bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny) +{ + // If `any` is the supertype, then we can propagate that inward. + bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); + return argResult && retResult; +} + +bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) +{ + for (const auto& [propName, prop] : superTable->props) + { + if (prop.readTy) + unify(builtinTypes->anyType, *prop.readTy); + + if (prop.writeTy) + unify(*prop.writeTy, builtinTypes->anyType); + } + + if (superTable->indexer) + { + unify(builtinTypes->anyType, superTable->indexer->indexType); + unify(builtinTypes->anyType, superTable->indexer->indexResultType); + } + + return true; +} + +bool Unifier2::unify(const TableType* subTable, const AnyType* superAny) +{ + for (const auto& [propName, prop] : subTable->props) + { + if (prop.readTy) + unify(*prop.readTy, builtinTypes->anyType); + + if (prop.writeTy) + unify(builtinTypes->anyType, *prop.writeTy); + } + + if (subTable->indexer) + { + unify(subTable->indexer->indexType, builtinTypes->anyType); + unify(subTable->indexer->indexResultType, builtinTypes->anyType); + } + + return true; +} + +// FIXME? This should probably return an ErrorVec or an optional +// rather than a boolean to signal an occurs check failure. +bool Unifier2::unify(TypePackId subTp, TypePackId superTp) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + if (auto subGen = genericPackSubstitutions.find(subTp)) + return unify(*subGen, superTp); + + if (auto superGen = genericPackSubstitutions.find(superTp)) + return unify(subTp, *superGen); + + if (seenTypePackPairings.contains({subTp, superTp})) + return true; + seenTypePackPairings.insert({subTp, superTp}); + + if (subTp == superTp) + return true; + + if (isIrresolvable(subTp) || isIrresolvable(superTp)) + { + if (uninhabitedTypeFunctions && (uninhabitedTypeFunctions->contains(subTp) || uninhabitedTypeFunctions->contains(superTp))) + return true; + + incompleteSubtypes.push_back(PackSubtypeConstraint{subTp, superTp}); + return true; + } + + const FreeTypePack* subFree = get(subTp); + const FreeTypePack* superFree = get(superTp); + + if (subFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, subTp, superTp)) + { + emplaceTypePack(asMutable(subTp), builtinTypes->errorTypePack); + return false; + } + + emplaceTypePack(asMutable(subTp), superTp); + return true; + } + + if (superFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, superTp, subTp)) + { + emplaceTypePack(asMutable(superTp), builtinTypes->errorTypePack); + return false; + } + + emplaceTypePack(asMutable(superTp), subTp); + return true; + } + + size_t maxLength = std::max(flatten(subTp).first.size(), flatten(superTp).first.size()); + + auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength); + auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength); + + // right-pad the subpack with nils if `superPack` is larger since that's what a function call does + if (subTypes.size() < maxLength) + { + for (size_t i = 0; i <= maxLength - subTypes.size(); i++) + subTypes.push_back(builtinTypes->nilType); + } + + if (subTypes.size() < maxLength || superTypes.size() < maxLength) + return true; + + for (size_t i = 0; i < maxLength; ++i) + unify(subTypes[i], superTypes[i]); + + if (subTail && superTail) + { + TypePackId followedSubTail = follow(*subTail); + TypePackId followedSuperTail = follow(*superTail); + + if (get(followedSubTail) || get(followedSuperTail)) + return unify(followedSubTail, followedSuperTail); + } + else if (subTail) + { + TypePackId followedSubTail = follow(*subTail); + if (get(followedSubTail)) + emplaceTypePack(asMutable(followedSubTail), builtinTypes->emptyTypePack); + } + else if (superTail) + { + TypePackId followedSuperTail = follow(*superTail); + if (get(followedSuperTail)) + emplaceTypePack(asMutable(followedSuperTail), builtinTypes->emptyTypePack); + } + + return true; +} + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + + explicit FreeTypeSearcher(NotNull scope) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. + DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; + +TypeId Unifier2::mkUnion(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyUnion(builtinTypes, arena, left, right).result; +} + +TypeId Unifier2::mkIntersection(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyIntersection(builtinTypes, arena, left, right).result; +} + +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) +{ + RecursionLimiter _ra(&recursionCount, recursionLimit); + + OccursCheckResult occurrence = OccursCheckResult::Pass; + + auto check = [&](TypeId ty) + { + if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail) + occurrence = OccursCheckResult::Fail; + }; + + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (get(needle)) + return OccursCheckResult::Pass; + + if (!get(needle)) + ice->ice("Expected needle to be free"); + + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto haystackFree = get(haystack)) + { + check(haystackFree->lowerBound); + check(haystackFree->upperBound); + } + else if (auto ut = get(haystack)) + { + for (TypeId ty : ut->options) + check(ty); + } + else if (auto it = get(haystack)) + { + for (TypeId ty : it->parts) + check(ty); + } + + return occurrence; +} + +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) +{ + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (getMutable(needle)) + return OccursCheckResult::Pass; + + if (!getMutable(needle)) + ice->ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&recursionCount, recursionLimit); + + while (!getMutable(haystack)) + { + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto a = get(haystack); a && a->tail) + { + haystack = follow(*a->tail); + continue; + } + + break; + } + + return OccursCheckResult::Pass; +} + +} // namespace Luau diff --git a/third_party/luau/Ast/include/Luau/Ast.h b/third_party/luau/Ast/include/Luau/Ast.h index 8309e8f3..099ece2b 100644 --- a/third_party/luau/Ast/include/Luau/Ast.h +++ b/third_party/luau/Ast/include/Luau/Ast.h @@ -1,14 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include #include "Luau/Location.h" +#include #include #include #include #include +#include namespace Luau { @@ -59,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -91,10 +94,21 @@ struct AstArray { return data; } + const T* end() const { return data + size; } + + std::reverse_iterator rbegin() const + { + return std::make_reverse_iterator(end()); + } + + std::reverse_iterator rend() const + { + return std::make_reverse_iterator(begin()); + } }; struct AstTypeList @@ -160,6 +174,10 @@ class AstNode { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -181,6 +199,29 @@ class AstNode Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + Native, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -249,6 +290,7 @@ class AstExprConstantBool : public AstExpr enum class ConstantNumberParseResult { Ok, + Imprecise, Malformed, BinOverflow, HexOverflow, @@ -272,11 +314,18 @@ class AstExprConstantString : public AstExpr public: LUAU_RTTI(AstExprConstantString) - AstExprConstantString(const Location& location, const AstArray& value); + enum QuoteStyle + { + Quoted, + Unquoted + }; + + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); void visit(AstVisitor* visitor) override; AstArray value; + QuoteStyle quoteStyle = Quoted; }; class AstExprLocal : public AstExpr @@ -335,7 +384,13 @@ class AstExprIndexName : public AstExpr LUAU_RTTI(AstExprIndexName) AstExprIndexName( - const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op); + const Location& location, + AstExpr* expr, + const AstName& index, + const Location& indexLocation, + const Position& opPosition, + char op + ); void visit(AstVisitor* visitor) override; @@ -364,13 +419,28 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, - bool hasEnd = false, const std::optional& argLocation = std::nullopt); + AstExprFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + AstLocal* self, + const AstArray& args, + bool vararg, + const Location& varargLocation, + AstStatBlock* body, + size_t functionDepth, + const AstName& debugname, + const std::optional& returnAnnotation = {}, + AstTypePack* varargAnnotation = nullptr, + const std::optional& argLocation = std::nullopt + ); void visit(AstVisitor* visitor) override; + bool hasNativeAttribute() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -386,7 +456,6 @@ class AstExprFunction : public AstExpr AstName debugname; - bool hasEnd = false; std::optional argLocation; }; @@ -450,6 +519,7 @@ class AstExprBinary : public AstExpr Sub, Mul, Div, + FloorDiv, Mod, Pow, Concat, @@ -460,7 +530,9 @@ class AstExprBinary : public AstExpr CompareGt, CompareGe, And, - Or + Or, + + Op__Count }; AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right); @@ -524,11 +596,23 @@ class AstStatBlock : public AstStat public: LUAU_RTTI(AstStatBlock) - AstStatBlock(const Location& location, const AstArray& body); + AstStatBlock(const Location& location, const AstArray& body, bool hasEnd = true); void visit(AstVisitor* visitor) override; AstArray body; + + /* Indicates whether or not this block has been terminated in a + * syntactically valid way. + * + * This is usually but not always done with the 'end' keyword. AstStatIf + * and AstStatRepeat are the two main exceptions to this. + * + * The 'then' clause of an if statement can properly be closed by the + * keywords 'else' or 'elseif'. A 'repeat' loop's body is closed with the + * 'until' keyword. + */ + bool hasEnd = false; }; class AstStatIf : public AstStat @@ -536,8 +620,14 @@ class AstStatIf : public AstStat public: LUAU_RTTI(AstStatIf) - AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, - const std::optional& elseLocation, bool hasEnd); + AstStatIf( + const Location& location, + AstExpr* condition, + AstStatBlock* thenbody, + AstStat* elsebody, + const std::optional& thenLocation, + const std::optional& elseLocation + ); void visit(AstVisitor* visitor) override; @@ -549,8 +639,6 @@ class AstStatIf : public AstStat // Active for 'elseif' as well std::optional elseLocation; - - bool hasEnd = false; }; class AstStatWhile : public AstStat @@ -558,7 +646,7 @@ class AstStatWhile : public AstStat public: LUAU_RTTI(AstStatWhile) - AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd); + AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -567,8 +655,6 @@ class AstStatWhile : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatRepeat : public AstStat @@ -576,14 +662,14 @@ class AstStatRepeat : public AstStat public: LUAU_RTTI(AstStatRepeat) - AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil); + AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool DEPRECATED_hasUntil); void visit(AstVisitor* visitor) override; AstExpr* condition; AstStatBlock* body; - bool hasUntil = false; + bool DEPRECATED_hasUntil = false; }; class AstStatBreak : public AstStat @@ -635,8 +721,12 @@ class AstStatLocal : public AstStat public: LUAU_RTTI(AstStatLocal) - AstStatLocal(const Location& location, const AstArray& vars, const AstArray& values, - const std::optional& equalsSignLocation); + AstStatLocal( + const Location& location, + const AstArray& vars, + const AstArray& values, + const std::optional& equalsSignLocation + ); void visit(AstVisitor* visitor) override; @@ -651,8 +741,16 @@ class AstStatFor : public AstStat public: LUAU_RTTI(AstStatFor) - AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool hasEnd); + AstStatFor( + const Location& location, + AstLocal* var, + AstExpr* from, + AstExpr* to, + AstExpr* step, + AstStatBlock* body, + bool hasDo, + const Location& doLocation + ); void visit(AstVisitor* visitor) override; @@ -664,8 +762,6 @@ class AstStatFor : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatForIn : public AstStat @@ -673,8 +769,16 @@ class AstStatForIn : public AstStat public: LUAU_RTTI(AstStatForIn) - AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, bool hasIn, - const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd); + AstStatForIn( + const Location& location, + const AstArray& vars, + const AstArray& values, + AstStatBlock* body, + bool hasIn, + const Location& inLocation, + bool hasDo, + const Location& doLocation + ); void visit(AstVisitor* visitor) override; @@ -687,8 +791,6 @@ class AstStatForIn : public AstStat bool hasDo = false; Location doLocation; - - bool hasEnd = false; }; class AstStatAssign : public AstStat @@ -749,8 +851,15 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported); + AstStatTypeAlias( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + AstType* type, + bool exported + ); void visit(AstVisitor* visitor) override; @@ -762,16 +871,31 @@ class AstStatTypeAlias : public AstStat bool exported; }; +class AstStatTypeFunction : public AstStat +{ +public: + LUAU_RTTI(AstStatTypeFunction); + + AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); + + void visit(AstVisitor* visitor) override; + + AstName name; + Location nameLocation; + AstExprFunction* body; +}; + class AstStatDeclareGlobal : public AstStat { public: LUAU_RTTI(AstStatDeclareGlobal) - AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type); + AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type); void visit(AstVisitor* visitor) override; AstName name; + Location nameLocation; AstType* type; }; @@ -780,25 +904,74 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes); + AstStatDeclareFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes + ); + + AstStatDeclareFunction( + const Location& location, + const AstArray& attributes, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes + ); + void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; + Location nameLocation; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; + bool vararg = false; + Location varargLocation; AstTypeList retTypes; }; struct AstDeclaredClassProp { AstName name; + Location nameLocation; AstType* ty = nullptr; bool isMethod = false; + Location location; +}; + +enum class AstTableAccess +{ + Read = 0b01, + Write = 0b10, + ReadWrite = 0b11, +}; + +struct AstTableIndexer +{ + AstType* indexType; + AstType* resultType; + Location location; + + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; }; class AstStatDeclareClass : public AstStat @@ -806,7 +979,13 @@ class AstStatDeclareClass : public AstStat public: LUAU_RTTI(AstStatDeclareClass) - AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props); + AstStatDeclareClass( + const Location& location, + const AstName& name, + std::optional superName, + const AstArray& props, + AstTableIndexer* indexer = nullptr + ); void visit(AstVisitor* visitor) override; @@ -814,6 +993,7 @@ class AstStatDeclareClass : public AstStat std::optional superName; AstArray props; + AstTableIndexer* indexer; }; class AstType : public AstNode @@ -842,8 +1022,15 @@ class AstTypeReference : public AstType public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, std::optional prefixLocation, - const Location& nameLocation, bool hasParameterList = false, const AstArray& parameters = {}); + AstTypeReference( + const Location& location, + std::optional prefix, + AstName name, + std::optional prefixLocation, + const Location& nameLocation, + bool hasParameterList = false, + const AstArray& parameters = {} + ); void visit(AstVisitor* visitor) override; @@ -860,13 +1047,8 @@ struct AstTableProp AstName name; Location location; AstType* type; -}; - -struct AstTableIndexer -{ - AstType* indexType; - AstType* resultType; - Location location; + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; }; class AstTypeTable : public AstType @@ -887,11 +1069,30 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction( + const Location& location, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes + ); + + AstTypeFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes + ); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; @@ -1055,6 +1256,11 @@ class AstVisitor return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); @@ -1314,4 +1520,4 @@ struct hash } }; -} // namespace std +} // namespace std \ No newline at end of file diff --git a/third_party/luau/Ast/include/Luau/Lexer.h b/third_party/luau/Ast/include/Luau/Lexer.h index 929402b3..f6ac28ad 100644 --- a/third_party/luau/Ast/include/Luau/Lexer.h +++ b/third_party/luau/Ast/include/Luau/Lexer.h @@ -62,6 +62,7 @@ struct Lexeme Dot3, SkinnyArrow, DoubleColon, + FloorDiv, InterpStringBegin, InterpStringMid, @@ -73,6 +74,7 @@ struct Lexeme SubAssign, MulAssign, DivAssign, + FloorDivAssign, ModAssign, PowAssign, ConcatAssign, @@ -85,11 +87,12 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, BrokenInterpDoubleBrace, - Error, Reserved_BEGIN, @@ -119,8 +122,15 @@ struct Lexeme Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -133,9 +143,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: @@ -204,7 +218,9 @@ class Lexer Position position() const; + // consume() assumes current character is not a newline for performance; when that is not known, consumeAny() should be used instead. void consume(); + void consumeAny(); Lexeme readCommentBody(); diff --git a/third_party/luau/Ast/include/Luau/Location.h b/third_party/luau/Ast/include/Luau/Location.h index dbe36bec..3fc8921a 100644 --- a/third_party/luau/Ast/include/Luau/Location.h +++ b/third_party/luau/Ast/include/Luau/Location.h @@ -8,7 +8,11 @@ struct Position { unsigned int line, column; - Position(unsigned int line, unsigned int column); + Position(unsigned int line, unsigned int column) + : line(line) + , column(column) + { + } bool operator==(const Position& rhs) const; bool operator!=(const Position& rhs) const; @@ -24,10 +28,29 @@ struct Location { Position begin, end; - Location(); - Location(const Position& begin, const Position& end); - Location(const Position& begin, unsigned int length); - Location(const Location& begin, const Location& end); + Location() + : begin(0, 0) + , end(0, 0) + { + } + + Location(const Position& begin, const Position& end) + : begin(begin) + , end(end) + { + } + + Location(const Position& begin, unsigned int length) + : begin(begin) + , end(begin.line, begin.column + length) + { + } + + Location(const Location& begin, const Location& end) + : begin(begin.begin) + , end(end.end) + { + } bool operator==(const Location& rhs) const; bool operator!=(const Location& rhs) const; diff --git a/third_party/luau/Ast/include/Luau/ParseOptions.h b/third_party/luau/Ast/include/Luau/ParseOptions.h index 89e79528..01f2a74f 100644 --- a/third_party/luau/Ast/include/Luau/ParseOptions.h +++ b/third_party/luau/Ast/include/Luau/ParseOptions.h @@ -14,8 +14,6 @@ enum class Mode struct ParseOptions { - bool allowTypeAnnotations = true; - bool supportContinueStatement = true; bool allowDeclarationSyntax = false; bool captureComments = false; }; diff --git a/third_party/luau/Ast/include/Luau/Parser.h b/third_party/luau/Ast/include/Luau/Parser.h index 0b9d8c46..4e49028a 100644 --- a/third_party/luau/Ast/include/Luau/Parser.h +++ b/third_party/luau/Ast/include/Luau/Parser.h @@ -55,7 +55,12 @@ class Parser { public: static ParseResult parse( - const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); + const char* buffer, + std::size_t bufferSize, + AstNameTable& names, + Allocator& allocator, + ParseOptions options = ParseOptions() + ); private: struct Name; @@ -82,8 +87,8 @@ class Parser // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,11 +119,25 @@ class Parser AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* parseReturn(); @@ -126,11 +145,14 @@ class Parser // type Name `=' Type AstStat* parseTypeAlias(const Location& start, bool exported); + // type function Name ... end + AstStat* parseTypeFunction(const Location& start); + AstDeclaredClassProp parseDeclaredClassMethod(); // `declare global' Name: Type | // `declare function' Name`(' [parlist] `)' [`:` Type] - AstStat* parseDeclaration(const Location& start); + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -143,7 +165,12 @@ class Parser // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, + const Lexeme& matchFunction, + const AstName& debugname, + const Name* localName, + const AstArray& attributes + ); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -174,17 +201,24 @@ class Parser std::optional parseOptionalReturnType(); std::pair parseReturnType(); - AstTableIndexer* parseTableIndexer(); + AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionType(bool allowPack); - AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail( + const Lexeme& begin, + const AstArray& attributes, + AstArray generics, + AstArray genericPacks, + AstArray params, + AstArray> paramNames, + AstTypePack* varargAnnotation + ); - AstType* parseTableType(); - AstTypeOrPack parseSimpleType(bool allowPack); + AstType* parseTableType(bool inDeclarationContext = false); + AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); AstTypeOrPack parseTypeOrPack(); - AstType* parseType(); + AstType* parseType(bool inDeclarationContext = false); AstTypePack* parseTypePack(); AstTypePack* parseVariadicArgumentTypePack(); @@ -219,7 +253,7 @@ class Parser // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); - // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String @@ -300,8 +334,13 @@ class Parser void reportNameError(const char* context); - AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, - const char* format, ...) LUAU_PRINTF_ATTR(5, 6); + AstStatError* reportStatError( + const Location& location, + const AstArray& expressions, + const AstArray& statements, + const char* format, + ... + ) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstTypeError* reportTypeError(const Location& location, const AstArray& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error @@ -392,6 +431,7 @@ class Parser std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; @@ -412,4 +452,4 @@ class Parser std::string scratchData; }; -} // namespace Luau +} // namespace Luau \ No newline at end of file diff --git a/third_party/luau/Ast/include/Luau/StringUtils.h b/third_party/luau/Ast/include/Luau/StringUtils.h index fcea1874..6345fde4 100644 --- a/third_party/luau/Ast/include/Luau/StringUtils.h +++ b/third_party/luau/Ast/include/Luau/StringUtils.h @@ -3,7 +3,6 @@ #include "Luau/Common.h" -#include #include #include diff --git a/third_party/luau/Ast/include/Luau/TimeTrace.h b/third_party/luau/Ast/include/Luau/TimeTrace.h index be282827..bd2ca86b 100644 --- a/third_party/luau/Ast/include/Luau/TimeTrace.h +++ b/third_party/luau/Ast/include/Luau/TimeTrace.h @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include +#include #include @@ -54,7 +55,7 @@ struct Event struct GlobalContext; struct ThreadContext; -GlobalContext& getGlobalContext(); +std::shared_ptr getGlobalContext(); uint16_t createToken(GlobalContext& context, const char* name, const char* category); uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); @@ -66,7 +67,7 @@ struct ThreadContext ThreadContext() : globalContext(getGlobalContext()) { - threadId = createThread(globalContext, this); + threadId = createThread(*globalContext, this); } ~ThreadContext() @@ -74,16 +75,16 @@ struct ThreadContext if (!events.empty()) flushEvents(); - releaseThread(globalContext, this); + releaseThread(*globalContext, this); } void flushEvents() { - static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + static uint16_t flushToken = createToken(*globalContext, "flushEvents", "TimeTrace"); events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); - TimeTrace::flushEvents(globalContext, threadId, events, data); + TimeTrace::flushEvents(*globalContext, threadId, events, data); events.clear(); data.clear(); @@ -125,7 +126,7 @@ struct ThreadContext events.push_back({EventType::ArgValue, 0, {pos}}); } - GlobalContext& globalContext; + std::shared_ptr globalContext; uint32_t threadId; std::vector events; std::vector data; @@ -133,6 +134,14 @@ struct ThreadContext static constexpr size_t kEventFlushLimit = 8192; }; +using ThreadContextProvider = ThreadContext& (*)(); + +inline ThreadContextProvider& threadContextProvider() +{ + static ThreadContextProvider handler = nullptr; + return handler; +} + ThreadContext& getThreadContext(); struct Scope diff --git a/third_party/luau/Ast/src/Ast.cpp b/third_party/luau/Ast/src/Ast.cpp index d2c552a3..ff7c7cc6 100644 --- a/third_party/luau/Ast/src/Ast.cpp +++ b/third_party/luau/Ast/src/Ast.cpp @@ -3,6 +3,8 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauNativeAttribute); + namespace Luau { @@ -15,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -62,9 +75,10 @@ void AstExprConstantNumber::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value) +AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle) : AstExpr(ClassIndex(), location) , value(value) + , quoteStyle(quoteStyle) { } @@ -127,7 +141,13 @@ void AstExprCall::visit(AstVisitor* visitor) } AstExprIndexName::AstExprIndexName( - const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op) + const Location& location, + AstExpr* expr, + const AstName& index, + const Location& indexLocation, + const Position& opPosition, + char op +) : AstExpr(ClassIndex(), location) , expr(expr) , index(index) @@ -159,11 +179,24 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, - const std::optional& argLocation) +AstExprFunction::AstExprFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + AstLocal* self, + const AstArray& args, + bool vararg, + const Location& varargLocation, + AstStatBlock* body, + size_t functionDepth, + const AstName& debugname, + const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, + const std::optional& argLocation +) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -175,7 +208,6 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArraytype == AstAttr::Type::Native) + return true; + } + return false; +} + AstExprTable::AstExprTable(const Location& location, const AstArray& items) : AstExpr(ClassIndex(), location) , items(items) @@ -278,6 +322,8 @@ std::string toString(AstExprBinary::Op op) return "*"; case AstExprBinary::Div: return "/"; + case AstExprBinary::FloorDiv: + return "//"; case AstExprBinary::Mod: return "%"; case AstExprBinary::Pow: @@ -374,9 +420,10 @@ void AstExprError::visit(AstVisitor* visitor) } } -AstStatBlock::AstStatBlock(const Location& location, const AstArray& body) +AstStatBlock::AstStatBlock(const Location& location, const AstArray& body, bool hasEnd) : AstStat(ClassIndex(), location) , body(body) + , hasEnd(hasEnd) { } @@ -389,15 +436,20 @@ void AstStatBlock::visit(AstVisitor* visitor) } } -AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, - const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd) +AstStatIf::AstStatIf( + const Location& location, + AstExpr* condition, + AstStatBlock* thenbody, + AstStat* elsebody, + const std::optional& thenLocation, + const std::optional& elseLocation +) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) , thenLocation(thenLocation) , elseLocation(elseLocation) - , hasEnd(hasEnd) { } @@ -413,13 +465,12 @@ void AstStatIf::visit(AstVisitor* visitor) } } -AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd) +AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , condition(condition) , body(body) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -432,11 +483,11 @@ void AstStatWhile::visit(AstVisitor* visitor) } } -AstStatRepeat::AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil) +AstStatRepeat::AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool DEPRECATED_hasUntil) : AstStat(ClassIndex(), location) , condition(condition) , body(body) - , hasUntil(hasUntil) + , DEPRECATED_hasUntil(DEPRECATED_hasUntil) { } @@ -497,7 +548,11 @@ void AstStatExpr::visit(AstVisitor* visitor) } AstStatLocal::AstStatLocal( - const Location& location, const AstArray& vars, const AstArray& values, const std::optional& equalsSignLocation) + const Location& location, + const AstArray& vars, + const AstArray& values, + const std::optional& equalsSignLocation +) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -520,8 +575,16 @@ void AstStatLocal::visit(AstVisitor* visitor) } } -AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool hasEnd) +AstStatFor::AstStatFor( + const Location& location, + AstLocal* var, + AstExpr* from, + AstExpr* to, + AstExpr* step, + AstStatBlock* body, + bool hasDo, + const Location& doLocation +) : AstStat(ClassIndex(), location) , var(var) , from(from) @@ -530,7 +593,6 @@ AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, A , body(body) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -551,8 +613,16 @@ void AstStatFor::visit(AstVisitor* visitor) } } -AstStatForIn::AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, - bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd) +AstStatForIn::AstStatForIn( + const Location& location, + const AstArray& vars, + const AstArray& values, + AstStatBlock* body, + bool hasIn, + const Location& inLocation, + bool hasDo, + const Location& doLocation +) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -561,7 +631,6 @@ AstStatForIn::AstStatForIn(const Location& location, const AstArray& , inLocation(inLocation) , hasDo(hasDo) , doLocation(doLocation) - , hasEnd(hasEnd) { } @@ -647,8 +716,15 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + AstType* type, + bool exported +) : AstStat(ClassIndex(), location) , name(name) , nameLocation(nameLocation) @@ -679,9 +755,24 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) +AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) + : AstStat(ClassIndex(), location) + , name(name) + , nameLocation(nameLocation) + , body(body) +{ +} + +void AstStatTypeFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + body->visit(visitor); +} + +AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type) : AstStat(ClassIndex(), location) , name(name) + , nameLocation(nameLocation) , type(type) { } @@ -692,15 +783,55 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes +) : AstStat(ClassIndex(), location) + , attributes() , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) + , retTypes(retTypes) +{ +} + +AstStatDeclareFunction::AstStatDeclareFunction( + const Location& location, + const AstArray& attributes, + const AstName& name, + const Location& nameLocation, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& params, + const AstArray& paramNames, + bool vararg, + const Location& varargLocation, + const AstTypeList& retTypes +) + : AstStat(ClassIndex(), location) + , attributes(attributes) + , name(name) + , nameLocation(nameLocation) + , generics(generics) + , genericPacks(genericPacks) + , params(params) + , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) { } @@ -714,12 +845,29 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass( - const Location& location, const AstName& name, std::optional superName, const AstArray& props) + const Location& location, + const AstName& name, + std::optional superName, + const AstArray& props, + AstTableIndexer* indexer +) : AstStat(ClassIndex(), location) , name(name) , superName(superName) , props(props) + , indexer(indexer) { } @@ -733,7 +881,11 @@ void AstStatDeclareClass::visit(AstVisitor* visitor) } AstStatError::AstStatError( - const Location& location, const AstArray& expressions, const AstArray& statements, unsigned messageIndex) + const Location& location, + const AstArray& expressions, + const AstArray& statements, + unsigned messageIndex +) : AstStat(ClassIndex(), location) , expressions(expressions) , statements(statements) @@ -753,8 +905,15 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, std::optional prefixLocation, - const Location& nameLocation, bool hasParameterList, const AstArray& parameters) +AstTypeReference::AstTypeReference( + const Location& location, + std::optional prefix, + AstName name, + std::optional prefixLocation, + const Location& nameLocation, + bool hasParameterList, + const AstArray& parameters +) : AstType(ClassIndex(), location) , hasParameterList(hasParameterList) , prefix(prefix) @@ -801,9 +960,36 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) +AstTypeFunction::AstTypeFunction( + const Location& location, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes +) + : AstType(ClassIndex(), location) + , attributes() + , generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , argNames(argNames) + , returnTypes(returnTypes) +{ + LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); +} + +AstTypeFunction::AstTypeFunction( + const Location& location, + const AstArray& attributes, + const AstArray& generics, + const AstArray& genericPacks, + const AstTypeList& argTypes, + const AstArray>& argNames, + const AstTypeList& returnTypes +) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) @@ -822,6 +1008,17 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) @@ -968,4 +1165,4 @@ Location getLocation(const AstTypeList& typeList) return result; } -} // namespace Luau +} // namespace Luau \ No newline at end of file diff --git a/third_party/luau/Ast/src/Confusables.cpp b/third_party/luau/Ast/src/Confusables.cpp index 1c792156..8f7fb56c 100644 --- a/third_party/luau/Ast/src/Confusables.cpp +++ b/third_party/luau/Ast/src/Confusables.cpp @@ -1808,9 +1808,15 @@ static const Confusable kConfusables[] = const char* findConfusable(uint32_t codepoint) { - auto it = std::lower_bound(std::begin(kConfusables), std::end(kConfusables), codepoint, [](const Confusable& lhs, uint32_t rhs) { - return lhs.codepoint < rhs; - }); + auto it = std::lower_bound( + std::begin(kConfusables), + std::end(kConfusables), + codepoint, + [](const Confusable& lhs, uint32_t rhs) + { + return lhs.codepoint < rhs; + } + ); return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr; } diff --git a/third_party/luau/Ast/src/Lexer.cpp b/third_party/luau/Ast/src/Lexer.cpp index 75b4fe30..a5e1d40e 100644 --- a/third_party/luau/Ast/src/Lexer.cpp +++ b/third_party/luau/Ast/src/Lexer.cpp @@ -1,11 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Lexer.h" +#include "Luau/Common.h" #include "Luau/Confusables.h" #include "Luau/StringUtils.h" #include +LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) + namespace Luau { @@ -89,8 +92,10 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || - type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + LUAU_ASSERT( + type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment + ); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -99,11 +104,21 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT( + type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment + ); + + return length; } -static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while"}; +static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", + "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -136,6 +151,9 @@ std::string Lexeme::toString() const case DoubleColon: return "'::'"; + case FloorDiv: + return "'//'"; + case AddAssign: return "'+='"; @@ -148,6 +166,9 @@ std::string Lexeme::toString() const case DivAssign: return "'/='"; + case FloorDivAssign: + return "'//='"; + case ModAssign: return "'%='"; @@ -182,6 +203,9 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -269,7 +293,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -373,7 +397,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // consume whitespace before the token while (isSpace(peekch())) - consume(); + consumeAny(); if (updatePrevLocation) prevLocation = lexeme.location; @@ -400,6 +424,8 @@ Lexeme Lexer::lookahead() unsigned int currentLineOffset = lineOffset; Lexeme currentLexeme = lexeme; Location currentPrevLocation = prevLocation; + size_t currentBraceStackSize = braceStack.size(); + BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); Lexeme result = next(); @@ -408,6 +434,13 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; + if (FFlag::LuauLexerLookaheadRemembersBraceType) + { + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); + } return result; } @@ -438,7 +471,17 @@ Position Lexer::position() const return Position(line, offset - lineOffset); } +LUAU_FORCEINLINE void Lexer::consume() +{ + // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed + LUAU_ASSERT(!isNewline(buffer[offset])); + + offset++; +} + +LUAU_FORCEINLINE +void Lexer::consumeAny() { if (isNewline(buffer[offset])) { @@ -524,7 +567,7 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le } else { - consume(); + consumeAny(); } } @@ -540,7 +583,7 @@ void Lexer::readBackslashInString() case '\r': consume(); if (peekch() == '\n') - consume(); + consumeAny(); break; case 0: @@ -549,11 +592,11 @@ void Lexer::readBackslashInString() case 'z': consume(); while (isSpace(peekch())) - consume(); + consumeAny(); break; default: - consume(); + consumeAny(); } } @@ -681,7 +724,7 @@ Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) std::pair Lexer::readName() { - LUAU_ASSERT(isAlpha(peekch()) || peekch() == '_'); + LUAU_ASSERT(isAlpha(peekch()) || peekch() == '_' || peekch() == '@'); unsigned int startOffset = offset; @@ -878,15 +921,31 @@ Lexeme Lexer::readNext() return Lexeme(Location(start, 1), '+'); case '/': + { consume(); - if (peekch() == '=') + char ch = peekch(); + + if (ch == '=') { consume(); return Lexeme(Location(start, 2), Lexeme::DivAssign); } + else if (ch == '/') + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 3), Lexeme::FloorDivAssign); + } + else + return Lexeme(Location(start, 2), Lexeme::FloorDiv); + } else return Lexeme(Location(start, 1), '/'); + } case '*': consume(); @@ -939,13 +998,20 @@ Lexeme Lexer::readNext() case ';': case ',': case '#': + case '?': + case '&': + case '|': { char ch = peekch(); consume(); return Lexeme(Location(start, 1), ch); } - + case '@': + { + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); + } default: if (isDigit(peekch())) { diff --git a/third_party/luau/Ast/src/Location.cpp b/third_party/luau/Ast/src/Location.cpp index d01d8a18..c2c66d9f 100644 --- a/third_party/luau/Ast/src/Location.cpp +++ b/third_party/luau/Ast/src/Location.cpp @@ -4,12 +4,6 @@ namespace Luau { -Position::Position(unsigned int line, unsigned int column) - : line(line) - , column(column) -{ -} - bool Position::operator==(const Position& rhs) const { return this->column == rhs.column && this->line == rhs.line; @@ -60,30 +54,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi } } -Location::Location() - : begin(0, 0) - , end(0, 0) -{ -} - -Location::Location(const Position& begin, const Position& end) - : begin(begin) - , end(end) -{ -} - -Location::Location(const Position& begin, unsigned int length) - : begin(begin) - , end(begin.line, begin.column + length) -{ -} - -Location::Location(const Location& begin, const Location& end) - : begin(begin.begin) - , end(end.end) -{ -} - bool Location::operator==(const Location& rhs) const { return this->begin == rhs.begin && this->end == rhs.end; diff --git a/third_party/luau/Ast/src/Parser.cpp b/third_party/luau/Ast/src/Parser.cpp index 6a76eda2..d58fe1e8 100644 --- a/third_party/luau/Ast/src/Parser.cpp +++ b/third_party/luau/Ast/src/Parser.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Common.h" #include "Luau/TimeTrace.h" #include @@ -8,19 +9,30 @@ #include #include -// Warning: If you are introducing new syntax, ensure that it is behind a separate -// flag so that we don't break production games by reverting syntax changes. -// See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) +LUAU_FASTINTVARIABLE(LuauTypeLengthLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) - -#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" +// Warning: If you are introducing new syntax, ensure that it is behind a separate +// flag so that we don't break production games by reverting syntax changes. +// See docs/SyntaxChanges.md for an explanation. +LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) +LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -244,13 +256,13 @@ AstStatBlock* Parser::parseBlockNoScope() while (!blockFollow(lexer.current())) { - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("block"); AstStat* stat = parseStat(); - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; if (lexer.current().type == ';') { @@ -279,7 +291,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -298,13 +312,15 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + return parseAttributeStat(); default:; } @@ -327,25 +343,22 @@ AstStat* Parser::parseStat() // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` AstName ident = getIdentifier(expr); - if (options.allowTypeAnnotations) - { - if (ident == "type") - return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "type") + return parseTypeAlias(expr->location, /* exported= */ false); - if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") - { - nextLexeme(); - return parseTypeAlias(expr->location, /* exported= */ true); - } + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") + { + nextLexeme(); + return parseTypeAlias(expr->location, /* exported= */ true); } - if (options.supportContinueStatement && ident == "continue") + if (ident == "continue") return parseContinue(expr->location); - if (options.allowTypeAnnotations && options.allowDeclarationSyntax) + if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -374,17 +387,16 @@ AstStat* Parser::parseIf() AstStat* elsebody = nullptr; Location end = start; std::optional elseLocation; - bool hasEnd = false; if (lexer.current().type == Lexeme::ReservedElseif) { - unsigned int recursionCounterOld = recursionCounter; + thenbody->hasEnd = true; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("elseif"); elseLocation = lexer.current().location; elsebody = parseIf(); end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; } else { @@ -392,6 +404,7 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElse) { + thenbody->hasEnd = true; elseLocation = lexer.current().location; matchThenElse = lexer.current(); nextLexeme(); @@ -402,10 +415,18 @@ AstStat* Parser::parseIf() end = lexer.current().location; - hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); + + if (elsebody) + { + if (AstStatBlock* elseBlock = elsebody->as()) + elseBlock->hasEnd = hasEnd; + } + else + thenbody->hasEnd = hasEnd; } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation); } // while exp do block end @@ -429,8 +450,9 @@ AstStat* Parser::parseWhile() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location); } // repeat block until exp @@ -450,6 +472,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); + body->hasEnd = hasUntil; AstExpr* cond = parseExpr(); @@ -466,11 +489,11 @@ AstStat* Parser::parseDo() Lexeme matchDo = lexer.current(); nextLexeme(); // do - AstStat* body = parseBlock(); + AstStatBlock* body = parseBlock(); body->location.begin = start.begin; - expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); return body; } @@ -546,8 +569,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); } else { @@ -588,9 +612,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = hasEnd; - return allocator.alloc( - Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); } } @@ -603,7 +627,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // parse funcname into a chain of indexing operators AstExpr* expr = parseNameExpr("function name"); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (lexer.current().type == '.') { @@ -621,7 +645,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug incrementRecursionCounter("function name"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; // finish with : if (lexer.current().type == ':') @@ -643,7 +667,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = lexer.current().location; @@ -656,16 +680,125 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + + if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) + found = false; + + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + AstArray attributes = parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + default: + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str() + ); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -685,7 +818,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -695,6 +828,17 @@ AstStat* Parser::parseLocal() } else { + if (attributes.size != 0) + { + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected 'function' after local declaration with attribute, but got %s instead", + lexer.current().toString().c_str() + ); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -746,6 +890,15 @@ AstStat* Parser::parseReturn() // type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { + // parsing a type function + if (FFlag::LuauUserDefinedTypeFunctions) + { + if (lexer.current().type == Lexeme::ReservedFunction) + return parseTypeFunction(start); + } + + // parsing a type alias + // note: `type` token is already parsed for us, so we just need to parse the rest std::optional name = parseNameOpt("type name"); @@ -763,10 +916,38 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } +// type function Name `(' arglist `)' `=' funcbody `end' +AstStat* Parser::parseTypeFunction(const Location& start) +{ + Lexeme matchFn = lexer.current(); + nextLexeme(); + + // parse the name of the type function + std::optional fnName = parseNameOpt("type function name"); + if (!fnName) + fnName = Name(nameError, lexer.current().location); + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; + + AstExprFunction* body = parseFunctionBody(/* hasself */ false, matchFn, fnName->name, nullptr, AstArray({nullptr, 0})).first; + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; + + return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body); +} + AstDeclaredClassProp Parser::parseDeclaredClassMethod() { + Location start; + + if (FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + nextLexeme(); - Location start = lexer.current().location; + + if (!FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 @@ -791,7 +972,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); - Location end = lexer.current().location; + Location end = FFlag::LuauDeclarationExtraPropData ? lexer.previousLocation() : lexer.current().location; TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); @@ -799,7 +980,11 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { return AstDeclaredClassProp{ - fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + fnName.name, + FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, + reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), + true + }; } // Skip the first index. @@ -817,19 +1002,36 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() report(start, "All declaration parameters aside from 'self' must be annotated"); AstType* fnType = allocator.alloc( - Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); - - return AstDeclaredClassProp{fnName.name, fnType, true}; + Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes + ); + + return AstDeclaredClassProp{ + fnName.name, + FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, + fnType, + true, + FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{} + }; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if ((attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError( + lexer.current().location, + {}, + {}, + "Expected a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str() + ); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - Name globalName = parseName("global function name"); + Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); MatchLexeme matchParen = lexer.current(); @@ -865,8 +1067,34 @@ AstStat* Parser::parseDeclaration(const Location& start) if (vararg && !varargAnnotation) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); - return allocator.alloc( - Location(start, end), globalName.name, generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + if (FFlag::LuauDeclarationExtraPropData) + return allocator.alloc( + Location(start, end), + attributes, + globalName.name, + globalName.location, + generics, + genericPacks, + AstTypeList{copy(vars), varargAnnotation}, + copy(varNames), + vararg, + varargLocation, + retTypes + ); + else + return allocator.alloc( + Location(start, end), + attributes, + globalName.name, + Location{}, + generics, + genericPacks, + AstTypeList{copy(vars), varargAnnotation}, + copy(varNames), + false, + Location{}, + retTypes + ); } else if (AstName(lexer.current().name) == "class") { @@ -882,6 +1110,7 @@ AstStat* Parser::parseDeclaration(const Location& start) } TempVector props(scratchDeclaredClassProps); + AstTableIndexer* indexer = nullptr; while (lexer.current().type != Lexeme::ReservedEnd) { @@ -890,45 +1119,96 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } - else if (lexer.current().type == '[') + else if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ - std::optional> chars = parseCharArray(); + if (FFlag::LuauDeclarationExtraPropData) + { + const Location nameBegin = lexer.current().location; + std::optional> chars = parseCharArray(); + + const Location nameEnd = lexer.previousLocation(); - expectMatchAndConsume(']', begin); - expectAndConsume(':', "property type annotation"); - AstType* type = parseType(); + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null - bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); - if (chars && !containsNull) - props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{ + AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation()) + }); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } + else + { + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); + + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{AstName(chars->data), Location{}, type, false}); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } + } + else if (lexer.current().type == '[') + { + if (indexer) + { + // maybe we don't need to parse the entire badIndexer... + // however, we either have { or [ to lint, not the entire table type or the bad indexer. + AstTableIndexer* badIndexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + + // we lose all additional indexer expressions from the AST after error recovery here + report(badIndexer->location, "Cannot have more than one class indexer"); + } else - report(begin.location, "String literal contains malformed escape sequence"); + { + indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + } + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Location propStart = lexer.current().location; + Name propName = parseName("property name"); + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back(AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())} + ); } else { Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); AstType* propType = parseType(); - props.push_back(AstDeclaredClassProp{propName.name, propType, false}); + props.push_back(AstDeclaredClassProp{propName.name, Location{}, propType, false}); } } Location classEnd = lexer.current().location; nextLexeme(); // skip past `end` - return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); + return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props), indexer); } else if (std::optional globalName = parseNameOpt("global variable name")) { expectAndConsume(':', "global variable declaration"); - AstType* type = parseType(); - return allocator.alloc(Location(start, type->location), globalName->name, type); + AstType* type = parseType(/* in declaration context */ true); + return allocator.alloc( + Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type + ); } else { @@ -1003,7 +1283,12 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, + const Lexeme& matchFunction, + const AstName& debugname, + const Name* localName, + const AstArray& attributes +) { Location start = matchFunction.location; @@ -1053,10 +1338,27 @@ std::pair Parser::parseFunctionBody( Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, - functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), - funLocal}; + body->hasEnd = hasEnd; + + return { + allocator.alloc( + Location(start, end), + attributes, + generics, + genericPacks, + self, + vars, + vararg, + varargLocation, + body, + functionStack.size(), + debugname, + typelist, + varargAnnotation, + argLocation + ), + funLocal + }; } // explist ::= {exp `,'} exp @@ -1123,7 +1425,7 @@ std::tuple Parser::parseBindingList(TempVector& result, TempVector Parser::parseOptionalReturnType() { - if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) + if (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow) { if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); @@ -1263,13 +1565,13 @@ std::pair Parser::parseReturnType() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } // TableIndexer ::= `[' Type `]' `:' Type -AstTableIndexer* Parser::parseTableIndexer() +AstTableIndexer* Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1282,14 +1584,14 @@ AstTableIndexer* Parser::parseTableIndexer() AstType* result = parseType(); - return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); + return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location), access, accessLocation}); } // TableProp ::= Name `:' Type // TablePropOrIndexer ::= TableProp | TableIndexer // PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] // TableType ::= `{' PropList `}' -AstType* Parser::parseTableType() +AstType* Parser::parseTableType(bool inDeclarationContext) { incrementRecursionCounter("type annotation"); @@ -1303,6 +1605,25 @@ AstType* Parser::parseTableType() while (lexer.current().type != '}') { + AstTableAccess access = AstTableAccess::ReadWrite; + std::optional accessLocation; + + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') + { + if (AstName(lexer.current().name) == "read") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Read; + lexer.next(); + } + else if (AstName(lexer.current().name) == "write") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Write; + lexer.next(); + } + } + if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); @@ -1314,13 +1635,13 @@ AstType* Parser::parseTableType() AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) - props.push_back({AstName(chars->data), begin.location, type}); + props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[') { @@ -1328,14 +1649,14 @@ AstType* Parser::parseTableType() { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexer(); + AstTableIndexer* badIndexer = parseTableIndexer(access, accessLocation); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexer(); + indexer = parseTableIndexer(access, accessLocation); } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) @@ -1344,7 +1665,7 @@ AstType* Parser::parseTableType() // array-like table type: {T} desugars into {[number]: T} AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); - indexer = allocator.alloc(AstTableIndexer{index, type, type->location}); + indexer = allocator.alloc(AstTableIndexer{index, type, type->location, access, accessLocation}); break; } @@ -1357,9 +1678,9 @@ AstType* Parser::parseTableType() expectAndConsume(':', "table field"); - AstType* type = parseType(); + AstType* type = parseType(inDeclarationContext); - props.push_back({name->name, name->location, type}); + props.push_back(AstTableProp{name->name, name->location, type, access, accessLocation}); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -1383,7 +1704,7 @@ AstType* Parser::parseTableType() // ReturnType ::= Type | `(' TypeList `)' // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionType(bool allowPack) +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1431,11 +1752,18 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack) AstArray> paramNames = copy(names); - return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation) +AstType* Parser::parseFunctionTypeTail( + const Lexeme& begin, + const AstArray& attributes, + AstArray generics, + AstArray genericPacks, + AstArray params, + AstArray> paramNames, + AstTypePack* varargAnnotation +) { incrementRecursionCounter("type annotation"); @@ -1459,7 +1787,9 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); + return allocator.alloc( + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList + ); } // Type ::= @@ -1471,12 +1801,15 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray parts(scratchType); - parts.push_back(type); + + if (type != nullptr) + parts.push_back(type); incrementRecursionCounter("type annotation"); bool isUnion = false; bool isIntersection = false; + bool hasOptional = false; Location location = begin; @@ -1486,20 +1819,34 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) if (c == '|') { nextLexeme(); + + unsigned int oldRecursionCount = recursionCounter; parts.push_back(parseSimpleType(/* allowPack= */ false).type); + recursionCounter = oldRecursionCount; + isUnion = true; } else if (c == '?') { + LUAU_ASSERT(parts.size() >= 1); + Location loc = lexer.current().location; nextLexeme(); - parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); + + if (!hasOptional) + parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); + isUnion = true; + hasOptional = true; } else if (c == '&') { nextLexeme(); + + unsigned int oldRecursionCount = recursionCounter; parts.push_back(parseSimpleType(/* allowPack= */ false).type); + recursionCounter = oldRecursionCount; + isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1509,15 +1856,21 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } else break; + + if (parts.size() > unsigned(FInt::LuauTypeLengthLimit) + hasOptional) + ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } if (parts.size() == 1) - return type; + return parts[0]; if (isUnion && isIntersection) { - return reportTypeError(Location(begin, parts.back()->location), copy(parts), - "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + return reportTypeError( + Location(begin, parts.back()->location), + copy(parts), + "Mixing union and intersection types is not allowed; consider wrapping in parentheses." + ); } location.end = parts.back()->location.end; @@ -1535,7 +1888,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + // recursion counter is incremented in parseSimpleType Location begin = lexer.current().location; @@ -1552,29 +1905,51 @@ AstTypeOrPack Parser::parseTypeOrPack() return {parseTypeSuffix(type, begin), {}}; } -AstType* Parser::parseType() +AstType* Parser::parseType(bool inDeclarationContext) { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + // recursion counter is incremented in parseSimpleType and/or parseTypeSuffix Location begin = lexer.current().location; - AstType* type = parseSimpleType(/* allowPack= */ false).type; + AstType* type = nullptr; + Lexeme::Type c = lexer.current().type; + if (c != '|' && c != '&') + { + type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + recursionCounter = oldRecursionCount; + } + + AstType* typeWithSuffix = parseTypeSuffix(type, begin); recursionCounter = oldRecursionCount; - return parseTypeSuffix(type, begin); + return typeWithSuffix; } // Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseSimpleType(bool allowPack) +AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { incrementRecursionCounter("type annotation"); Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; @@ -1608,7 +1983,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string; did you forget to finish it?")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1656,24 +2031,30 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) Location end = lexer.previousLocation(); return { - allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}}; + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {} + }; } else if (lexer.current().type == '{') { - return {parseTableType(), {}}; + return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionType(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - return {reportTypeError(start, {}, - "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " - "...any'"), - {}}; + return { + reportTypeError( + start, + {}, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'" + ), + {} + }; } else { @@ -1727,7 +2108,8 @@ AstTypePack* Parser::parseTypePack() return allocator.alloc(Location(name.location, end), name.name); } - // No type pack annotation exists here. + // TODO: shouldParseTypePack can be removed and parseTypePack can be called unconditionally instead + LUAU_ASSERT(!"parseTypePack can't be called if shouldParseTypePack() returned false"); return nullptr; } @@ -1753,6 +2135,8 @@ std::optional Parser::parseBinaryOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == '/') return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDiv) + return AstExprBinary::FloorDiv; else if (l.type == '%') return AstExprBinary::Mod; else if (l.type == '^') @@ -1789,6 +2173,8 @@ std::optional Parser::parseCompoundOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == Lexeme::DivAssign) return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDivAssign) + return AstExprBinary::FloorDiv; else if (l.type == Lexeme::ModAssign) return AstExprBinary::Mod; else if (l.type == Lexeme::PowAssign) @@ -1812,7 +2198,7 @@ std::optional Parser::checkUnaryConfusables() if (curr.type == '!') { - report(start, "Unexpected '!', did you mean 'not'?"); + report(start, "Unexpected '!'; did you mean 'not'?"); return AstExprUnary::Not; } @@ -1834,20 +2220,19 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?"); + report(Location(start, next.location), "Unexpected '&&'; did you mean 'and'?"); return AstExprBinary::And; } else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '||', did you mean 'or'?"); + report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } - else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && - binaryPriority[AstExprBinary::CompareNe].left > limit) + else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '!=', did you mean '~='?"); + report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); return AstExprBinary::CompareNe; } @@ -1858,15 +2243,19 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr // where `binop' is any binary operator with a priority higher than `limit' AstExpr* Parser::parseExpr(unsigned int limit) { + // clang-format off static const BinaryOpPriority binaryPriority[] = { - {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `%' - {10, 9}, {5, 4}, // power and concat (right associative) - {3, 3}, {3, 3}, // equality and inequality - {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order - {2, 2}, {1, 1} // logical (and/or) + {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%' + {10, 9}, {5, 4}, // power and concat (right associative) + {3, 3}, {3, 3}, // equality and inequality + {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order + {2, 2}, {1, 1} // logical (and/or) }; + // clang-format on - unsigned int recursionCounterOld = recursionCounter; + static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op"); + + unsigned int oldRecursionCount = recursionCounter; // this handles recursive calls to parseSubExpr/parseExpr incrementRecursionCounter("expression"); @@ -1918,7 +2307,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } @@ -1985,7 +2374,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* expr = parsePrefixExpr(); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (true) { @@ -2045,7 +2434,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } @@ -2056,7 +2445,7 @@ AstExpr* Parser::parseAssertionExpr() Location start = lexer.current().location; AstExpr* expr = parseSimpleExpr(); - if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) + if (lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); AstType* annotation = parseType(); @@ -2089,6 +2478,9 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data, return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow; } + if (value >= (1ull << 53) && static_cast(result) != value) + return ConstantNumberParseResult::Imprecise; + return ConstantNumberParseResult::Ok; } @@ -2105,15 +2497,45 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data) char* end = nullptr; double value = strtod(data, &end); + // trailing non-numeric characters + if (*end != 0) + return ConstantNumberParseResult::Malformed; + result = value; - return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed; + + // for linting, we detect integer constants that are parsed imprecisely + // since the check is expensive we only perform it when the number is larger than the precise integer range + if (value >= double(1ull << 53) && strspn(data, "0123456789") == strlen(data)) + { + char repr[512]; + snprintf(repr, sizeof(repr), "%.0f", value); + + if (strcmp(repr, data) != 0) + return ConstantNumberParseResult::Imprecise; + } + + return ConstantNumberParseResult::Ok; } -// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { Location start = lexer.current().location; + AstArray attributes{nullptr, 0}; + + if (FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + { + attributes = parseAttributes(); + + if (lexer.current().type != Lexeme::ReservedFunction) + { + return reportExprError( + start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str() + ); + } + } + if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); @@ -2137,14 +2559,13 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first; } else if (lexer.current().type == Lexeme::Number) { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || - lexer.current().type == Lexeme::InterpStringSimple) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { return parseString(); } @@ -2155,12 +2576,12 @@ AstExpr* Parser::parseSimpleExpr() else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return reportExprError(start, {}, "Malformed string"); + return reportExprError(start, {}, "Malformed string; did you forget to finish it?"); } else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) { nextLexeme(); - return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(start, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); } else if (lexer.current().type == Lexeme::Dot3) { @@ -2244,15 +2665,22 @@ LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self) } else { - return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), - "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); + return reportExprError( + Location(func->location.begin, lexer.current().location.begin), + copy({func}), + "Expected '(', '{' or when parsing function call, got %s", + lexer.current().toString().c_str() + ); } } LUAU_NOINLINE void Parser::reportAmbiguousCallError() { - report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); + report( + lexer.current().location, + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements" + ); } // tableconstructor ::= `{' [fieldlist] `}' @@ -2298,7 +2726,7 @@ AstExpr* Parser::parseTableConstructor() nameString.data = const_cast(name.name.value); nameString.size = strlen(name.name.value); - AstExpr* key = allocator.alloc(name.location, nameString); + AstExpr* key = allocator.alloc(name.location, nameString, AstExprConstantString::Unquoted); AstExpr* value = parseExpr(); if (AstExprFunction* func = value->as()) @@ -2449,24 +2877,13 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - Lexeme packBegin = lexer.current(); - if (shouldParseTypePack(lexer)) { AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } - else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPack(); - - if (type) - report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); - - namePacks.push_back({name, nameLocation, typePack}); - } - else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + else { auto [type, typePack] = parseTypeOrPack(); @@ -2575,10 +2992,12 @@ AstArray Parser::parseTypeParams() std::optional> Parser::parseCharArray() { - LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || - lexer.current().type == Lexeme::InterpStringSimple); + LUAU_ASSERT( + lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || + lexer.current().type == Lexeme::InterpStringSimple + ); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2618,12 +3037,14 @@ AstExpr* Parser::parseInterpString() do { Lexeme currentLexeme = lexer.current(); - LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || - currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); + LUAU_ASSERT( + currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || + currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple + ); endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2658,7 +3079,7 @@ AstExpr* Parser::parseInterpString() { errorWhileChecking = true; nextLexeme(); - expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '`'?")); break; } default: @@ -2678,10 +3099,10 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '}'?"); default: return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } @@ -2696,7 +3117,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -2720,7 +3141,8 @@ AstLocal* Parser::pushLocal(const Binding& binding) AstLocal*& local = localMap[name.name]; local = allocator.alloc( - name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation); + name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation + ); localStack.push_back(local); @@ -2853,11 +3275,25 @@ LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Ma std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString(); if (lexer.current().location.begin.line == begin.position.line) - report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(), - begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + report( + lexer.current().location, + "Expected %s (to close %s at column %d), got %s%s", + typeString.c_str(), + matchString.c_str(), + begin.position.column + 1, + lexer.current().toString().c_str(), + extra ? extra : "" + ); else - report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(), - begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); + report( + lexer.current().location, + "Expected %s (to close %s at line %d), got %s%s", + typeString.c_str(), + matchString.c_str(), + begin.position.line + 1, + lexer.current().toString().c_str(), + extra ? extra : "" + ); } bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin) @@ -2994,7 +3430,12 @@ LUAU_NOINLINE void Parser::reportNameError(const char* context) } AstStatError* Parser::reportStatError( - const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) + const Location& location, + const AstArray& expressions, + const AstArray& statements, + const char* format, + ... +) { va_list args; va_start(args, format); @@ -3051,11 +3492,11 @@ void Parser::nextLexeme() return; // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; @@ -3066,4 +3507,4 @@ void Parser::nextLexeme() } } -} // namespace Luau +} // namespace Luau \ No newline at end of file diff --git a/third_party/luau/Ast/src/StringUtils.cpp b/third_party/luau/Ast/src/StringUtils.cpp index 343c553c..d3a4cea2 100644 --- a/third_party/luau/Ast/src/StringUtils.cpp +++ b/third_party/luau/Ast/src/StringUtils.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace Luau { @@ -140,7 +141,8 @@ size_t editDistance(std::string_view a, std::string_view b) size_t maxDistance = a.size() + b.size(); std::vector distances((a.size() + 2) * (b.size() + 2), 0); - auto getPos = [b](size_t x, size_t y) -> size_t { + auto getPos = [b](size_t x, size_t y) -> size_t + { return (x * (b.size() + 2)) + y; }; diff --git a/third_party/luau/Ast/src/TimeTrace.cpp b/third_party/luau/Ast/src/TimeTrace.cpp index b9d40e99..4782b25c 100644 --- a/third_party/luau/Ast/src/TimeTrace.cpp +++ b/third_party/luau/Ast/src/TimeTrace.cpp @@ -90,17 +90,8 @@ namespace TimeTrace { struct GlobalContext { - GlobalContext() = default; ~GlobalContext() { - // Ideally we would want all ThreadContext destructors to run - // But in VS, not all thread_local object instances are destroyed - for (ThreadContext* context : threads) - { - if (!context->events.empty()) - context->flushEvents(); - } - if (traceFile) fclose(traceFile); } @@ -110,11 +101,15 @@ struct GlobalContext uint32_t nextThreadId = 0; std::vector tokens; FILE* traceFile = nullptr; + +private: + friend std::shared_ptr getGlobalContext(); + GlobalContext() = default; }; -GlobalContext& getGlobalContext() +std::shared_ptr getGlobalContext() { - static GlobalContext context; + static std::shared_ptr context = std::shared_ptr{new GlobalContext}; return context; } @@ -189,8 +184,14 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector +#include +#include +#include +#include +#include +#include + #ifdef CALLGRIND #include #endif LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) enum class ReportFormat { @@ -55,8 +64,13 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, humanReadableName.c_str(), error.location, "TypeError", - Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); + report( + format, + humanReadableName.c_str(), + error.location, + "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str() + ); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -64,26 +78,29 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); } -static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) +static bool reportModuleResult(Luau::Frontend& frontend, const Luau::ModuleName& name, ReportFormat format, bool annotate) { - Luau::CheckResult cr; + std::optional cr = frontend.getCheckResult(name, false); - if (frontend.isDirty(name)) - cr = frontend.check(name); + if (!cr) + { + fprintf(stderr, "Failed to find result for %s\n", name.c_str()); + return false; + } if (!frontend.getSourceModule(name)) { - fprintf(stderr, "Error opening %s\n", name); + fprintf(stderr, "Error opening %s\n", name.c_str()); return false; } - for (auto& error : cr.errors) + for (auto& error : cr->errors) reportError(frontend, format, error); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : cr.lintResult.errors) + for (auto& error : cr->lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : cr.lintResult.warnings) + for (auto& warning : cr->lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -98,7 +115,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && cr.lintResult.errors.empty(); + return cr->errors.empty() && cr->lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -216,6 +233,77 @@ struct CliConfigResolver : Luau::ConfigResolver } }; +struct TaskScheduler +{ + TaskScheduler(unsigned threadCount) + : threadCount(threadCount) + { + for (unsigned i = 0; i < threadCount; i++) + { + workers.emplace_back( + [this] + { + workerFunction(); + } + ); + } + } + + ~TaskScheduler() + { + for (unsigned i = 0; i < threadCount; i++) + push({}); + + for (std::thread& worker : workers) + worker.join(); + } + + std::function pop() + { + std::unique_lock guard(mtx); + + cv.wait( + guard, + [this] + { + return !tasks.empty(); + } + ); + + std::function task = tasks.front(); + tasks.pop(); + return task; + } + + void push(std::function task) + { + { + std::unique_lock guard(mtx); + tasks.push(std::move(task)); + } + + cv.notify_one(); + } + + static unsigned getThreadCount() + { + return std::max(std::thread::hardware_concurrency(), 1u); + } + +private: + void workerFunction() + { + while (std::function task = pop()) + task(); + } + + unsigned threadCount = 1; + std::mutex mtx; + std::condition_variable cv; + std::vector workers; + std::queue> tasks; +}; + int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -231,6 +319,8 @@ int main(int argc, char** argv) ReportFormat format = ReportFormat::Default; Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; + int threadCount = 0; + std::string basePath = ""; for (int i = 1; i < argc; ++i) { @@ -249,6 +339,10 @@ int main(int argc, char** argv) FFlag::DebugLuauTimeTracing.value = true; else if (strncmp(argv[i], "--fflags=", 9) == 0) setLuauFlags(argv[i] + 9); + else if (strncmp(argv[i], "-j", 2) == 0) + threadCount = int(strtol(argv[i] + 2, nullptr, 10)); + else if (strncmp(argv[i], "--logbase=", 10) == 0) + basePath = std::string{argv[i] + 10}; } #if !defined(LUAU_ENABLE_TIME_TRACE) @@ -267,6 +361,25 @@ int main(int argc, char** argv) CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); + if (FFlag::DebugLuauLogSolverToJsonFile) + { + frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log) + { + std::string path = moduleName + ".log.json"; + size_t pos = moduleName.find_last_of('/'); + if (pos != std::string::npos) + path = moduleName.substr(pos + 1); + + if (!basePath.empty()) + path = joinPaths(basePath, path); + + std::ofstream os(path); + + os << log << std::endl; + printf("Wrote JSON log to %s\n", path.c_str()); + }; + } + Luau::registerBuiltinGlobals(frontend, frontend.globals); Luau::freeze(frontend.globals.globalTypes); @@ -276,10 +389,51 @@ int main(int argc, char** argv) std::vector files = getSourceFiles(argc, argv); + for (const std::string& path : files) + frontend.queueModuleCheck(path); + + std::vector checkedModules; + + // If thread count is not set, try to use HW thread count, but with an upper limit + // When we improve scalability of typechecking, upper limit can be adjusted/removed + if (threadCount <= 0) + threadCount = std::min(TaskScheduler::getThreadCount(), 8u); + + try + { + TaskScheduler scheduler(threadCount); + + checkedModules = frontend.checkQueuedModules( + std::nullopt, + [&](std::function f) + { + scheduler.push(std::move(f)); + } + ); + } + catch (const Luau::InternalCompilerError& ice) + { + Luau::Location location = ice.location ? *ice.location : Luau::Location(); + + std::string moduleName = ice.moduleName ? *ice.moduleName : ""; + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(moduleName); + + Luau::TypeError error(location, moduleName, Luau::InternalError{ice.message}); + + report( + format, + humanReadableName.c_str(), + location, + "InternalCompilerError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str() + ); + return 1; + } + int failed = 0; - for (const std::string& path : files) - failed += !analyzeFile(frontend, path.c_str(), format, annotate); + for (const Luau::ModuleName& name : checkedModules) + failed += !reportModuleResult(frontend, name, format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/third_party/luau/CLI/Ast.cpp b/third_party/luau/CLI/Ast.cpp index 99c58393..b5a922aa 100644 --- a/third_party/luau/CLI/Ast.cpp +++ b/third_party/luau/CLI/Ast.cpp @@ -64,8 +64,6 @@ int main(int argc, char** argv) Luau::ParseOptions options; options.captureComments = true; - options.supportContinueStatement = true; - options.allowTypeAnnotations = true; options.allowDeclarationSyntax = true; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); diff --git a/third_party/luau/CLI/Bytecode.cpp b/third_party/luau/CLI/Bytecode.cpp new file mode 100644 index 00000000..2da9570b --- /dev/null +++ b/third_party/luau/CLI/Bytecode.cpp @@ -0,0 +1,299 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" +#include "Luau/BytecodeSummary.h" +#include "FileUtils.h" +#include "Flags.h" + +#include + +using Luau::CodeGen::FunctionBytecodeSummary; + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = 1; + + return result; +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [options] [file list]\n", argv0); + printf("\n"); + printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --fflags=: flags to be enabled.\n"); + printf(" --summary-file=: file in which bytecode analysis summary will be recorded (default 'bytecode-summary.json').\n"); + + exit(0); +} + +static bool parseArgs(int argc, char** argv, std::string& summaryFile) +{ + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return false; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return false; + } + globalOptions.debugLevel = level; + } + else if (strncmp(argv[i], "--summary-file=", 15) == 0) + { + summaryFile = argv[i] + 15; + + if (summaryFile.size() == 0) + { + fprintf(stderr, "Error: filename missing for '--summary-file'.\n\n"); + return false; + } + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + } + } + + return true; +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +static bool analyzeFile(const char* name, const unsigned nestingLimit, std::vector& summaries) +{ + std::optional source = readFile(name); + + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + try + { + Luau::BytecodeBuilder bcb; + + compileOrThrow(bcb, *source, copts()); + + const std::string& bytecode = bcb.getBytecode(); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + { + summaries = Luau::CodeGen::summarizeBytecode(L, -1, nestingLimit); + return true; + } + else + { + fprintf(stderr, "Error loading bytecode %s\n", name); + return false; + } + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } + + return true; +} + +static std::string escapeFilename(const std::string& filename) +{ + std::string escaped; + escaped.reserve(filename.size()); + + for (const char ch : filename) + { + switch (ch) + { + case '\\': + escaped.push_back('/'); + break; + case '"': + escaped.push_back('\\'); + escaped.push_back(ch); + break; + default: + escaped.push_back(ch); + } + } + + return escaped; +} + +static void serializeFunctionSummary(const FunctionBytecodeSummary& summary, FILE* fp) +{ + const unsigned nestingLimit = summary.getNestingLimit(); + const unsigned opLimit = summary.getOpLimit(); + + fprintf(fp, " {\n"); + fprintf(fp, " \"source\": \"%s\",\n", summary.getSource().c_str()); + fprintf(fp, " \"name\": \"%s\",\n", summary.getName().c_str()); + fprintf(fp, " \"line\": %d,\n", summary.getLine()); + fprintf(fp, " \"nestingLimit\": %u,\n", nestingLimit); + fprintf(fp, " \"counts\": ["); + + for (unsigned nesting = 0; nesting <= nestingLimit; ++nesting) + { + fprintf(fp, "\n ["); + + for (unsigned i = 0; i < opLimit; ++i) + { + fprintf(fp, "%d", summary.getCount(nesting, uint8_t(i))); + if (i < opLimit - 1) + fprintf(fp, ", "); + } + + fprintf(fp, "]"); + if (nesting < nestingLimit) + fprintf(fp, ","); + } + + fprintf(fp, "\n ]"); + fprintf(fp, "\n }"); +} + +static void serializeScriptSummary(const std::string& file, const std::vector& scriptSummary, FILE* fp) +{ + std::string escaped(escapeFilename(file)); + const size_t functionCount = scriptSummary.size(); + + fprintf(fp, " \"%s\": [\n", escaped.c_str()); + + for (size_t i = 0; i < functionCount; ++i) + { + serializeFunctionSummary(scriptSummary[i], fp); + fprintf(fp, i == (functionCount - 1) ? "\n" : ",\n"); + } + + fprintf(fp, " ]"); +} + +static bool serializeSummaries( + const std::vector& files, + const std::vector>& scriptSummaries, + const std::string& summaryFile +) +{ + + FILE* fp = fopen(summaryFile.c_str(), "w"); + const size_t fileCount = files.size(); + + if (!fp) + { + fprintf(stderr, "Unable to open '%s'.\n", summaryFile.c_str()); + return false; + } + + fprintf(fp, "{\n"); + + for (size_t i = 0; i < fileCount; ++i) + { + serializeScriptSummary(files[i], scriptSummaries[i], fp); + fprintf(fp, i < (fileCount - 1) ? ",\n" : "\n"); + } + + fprintf(fp, "}"); + fclose(fp); + + return true; +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + setLuauFlagsDefault(); + + std::string summaryFile("bytecode-summary.json"); + unsigned nestingLimit = 0; + + if (!parseArgs(argc, argv, summaryFile)) + return 1; + + const std::vector files = getSourceFiles(argc, argv); + size_t fileCount = files.size(); + + std::vector> scriptSummaries; + scriptSummaries.reserve(fileCount); + + for (size_t i = 0; i < fileCount; ++i) + { + if (!analyzeFile(files[i].c_str(), nestingLimit, scriptSummaries[i])) + return 1; + } + + if (!serializeSummaries(files, scriptSummaries, summaryFile)) + return 1; + + fprintf(stdout, "Bytecode summary written to '%s'\n", summaryFile.c_str()); + + return 0; +} diff --git a/third_party/luau/CLI/Compile.cpp b/third_party/luau/CLI/Compile.cpp new file mode 100644 index 00000000..7d95387c --- /dev/null +++ b/third_party/luau/CLI/Compile.cpp @@ -0,0 +1,697 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/CodeGen.h" +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + +#include "FileUtils.h" +#include "Flags.h" + +#include + +#ifdef _WIN32 +#include +#include +#endif + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CompileFormat +{ + Text, + Binary, + Remarks, + Codegen, // Prints annotated native code including IR and assembly + CodegenAsm, // Prints annotated native code assembly + CodegenIr, // Prints annotated native code IR + CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code + CodegenNull, + Null +}; + +enum class RecordStats +{ + None, + Total, + File, + Function +}; + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; + int typeInfoLevel = 0; + + const char* vectorLib = nullptr; + const char* vectorCtor = nullptr; + const char* vectorType = nullptr; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = globalOptions.typeInfoLevel; + + result.vectorLib = globalOptions.vectorLib; + result.vectorCtor = globalOptions.vectorCtor; + result.vectorType = globalOptions.vectorType; + + return result; +} + +static std::optional getCompileFormat(const char* name) +{ + if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "binary") == 0) + return CompileFormat::Binary; + else if (strcmp(name, "text") == 0) + return CompileFormat::Text; + else if (strcmp(name, "remarks") == 0) + return CompileFormat::Remarks; + else if (strcmp(name, "codegen") == 0) + return CompileFormat::Codegen; + else if (strcmp(name, "codegenasm") == 0) + return CompileFormat::CodegenAsm; + else if (strcmp(name, "codegenir") == 0) + return CompileFormat::CodegenIr; + else if (strcmp(name, "codegenverbose") == 0) + return CompileFormat::CodegenVerbose; + else if (strcmp(name, "codegennull") == 0) + return CompileFormat::CodegenNull; + else if (strcmp(name, "null") == 0) + return CompileFormat::Null; + else + return std::nullopt; +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +static std::string getCodegenAssembly( + const char* name, + const std::string& bytecode, + Luau::CodeGen::AssemblyOptions options, + Luau::CodeGen::LoweringStats* stats +) +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssembly(L, -1, options, stats); + + fprintf(stderr, "Error loading bytecode %s\n", name); + return ""; +} + +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instpos); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t bytecodeInstructionCount; + size_t codegen; + + double readTime; + double miscTime; + double parseTime; + double compileTime; + double codegenTime; + + Luau::CodeGen::LoweringStats lowerStats; + + CompileStats& operator+=(const CompileStats& that) + { + this->lines += that.lines; + this->bytecode += that.bytecode; + this->bytecodeInstructionCount += that.bytecodeInstructionCount; + this->codegen += that.codegen; + this->readTime += that.readTime; + this->miscTime += that.miscTime; + this->parseTime += that.parseTime; + this->compileTime += that.compileTime; + this->codegenTime += that.codegenTime; + this->lowerStats += that.lowerStats; + + return *this; + } + + CompileStats operator+(const CompileStats& other) const + { + CompileStats result(*this); + result += other; + return result; + } +}; + +#define WRITE_NAME(INDENT, NAME) fprintf(fp, INDENT "\"" #NAME "\": ") +#define WRITE_PAIR(INDENT, NAME, FORMAT) fprintf(fp, INDENT "\"" #NAME "\": " FORMAT, stats.NAME) +#define WRITE_PAIR_STRING(INDENT, NAME, FORMAT) fprintf(fp, INDENT "\"" #NAME "\": " FORMAT, stats.NAME.c_str()) + +void serializeFunctionStats(FILE* fp, const Luau::CodeGen::FunctionStats& stats) +{ + fprintf(fp, " {\n"); + WRITE_PAIR_STRING(" ", name, "\"%s\",\n"); + WRITE_PAIR(" ", line, "%d,\n"); + WRITE_PAIR(" ", bcodeCount, "%u,\n"); + WRITE_PAIR(" ", irCount, "%u,\n"); + WRITE_PAIR(" ", asmCount, "%u,\n"); + WRITE_PAIR(" ", asmSize, "%u,\n"); + + WRITE_NAME(" ", bytecodeSummary); + const size_t nestingLimit = stats.bytecodeSummary.size(); + + if (nestingLimit == 0) + fprintf(fp, "[]"); + else + { + fprintf(fp, "[\n"); + for (size_t i = 0; i < nestingLimit; ++i) + { + const std::vector& counts = stats.bytecodeSummary[i]; + fprintf(fp, " ["); + for (size_t j = 0; j < counts.size(); ++j) + { + fprintf(fp, "%u", counts[j]); + if (j < counts.size() - 1) + fprintf(fp, ", "); + } + fprintf(fp, "]"); + if (i < stats.bytecodeSummary.size() - 1) + fprintf(fp, ",\n"); + } + fprintf(fp, "\n ]"); + } + + fprintf(fp, "\n }"); +} + +void serializeBlockLinearizationStats(FILE* fp, const Luau::CodeGen::BlockLinearizationStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", constPropInstructionCount, "%u,\n"); + WRITE_PAIR(" ", timeSeconds, "%f\n"); + + fprintf(fp, " }"); +} + +void serializeLoweringStats(FILE* fp, const Luau::CodeGen::LoweringStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", totalFunctions, "%u,\n"); + WRITE_PAIR(" ", skippedFunctions, "%u,\n"); + WRITE_PAIR(" ", spillsToSlot, "%d,\n"); + WRITE_PAIR(" ", spillsToRestore, "%d,\n"); + WRITE_PAIR(" ", maxSpillSlotsUsed, "%u,\n"); + WRITE_PAIR(" ", blocksPreOpt, "%u,\n"); + WRITE_PAIR(" ", blocksPostOpt, "%u,\n"); + WRITE_PAIR(" ", maxBlockInstructions, "%u,\n"); + WRITE_PAIR(" ", regAllocErrors, "%d,\n"); + WRITE_PAIR(" ", loweringErrors, "%d,\n"); + + WRITE_NAME(" ", blockLinearizationStats); + serializeBlockLinearizationStats(fp, stats.blockLinearizationStats); + fprintf(fp, ",\n"); + + WRITE_NAME(" ", functions); + const size_t functionCount = stats.functions.size(); + + if (functionCount == 0) + fprintf(fp, "[]"); + else + { + fprintf(fp, "[\n"); + for (size_t i = 0; i < functionCount; ++i) + { + serializeFunctionStats(fp, stats.functions[i]); + if (i < functionCount - 1) + fprintf(fp, ",\n"); + } + fprintf(fp, "\n ]"); + } + + fprintf(fp, "\n }"); +} + +void serializeCompileStats(FILE* fp, const CompileStats& stats) +{ + fprintf(fp, "{\n"); + + WRITE_PAIR(" ", lines, "%zu,\n"); + WRITE_PAIR(" ", bytecode, "%zu,\n"); + WRITE_PAIR(" ", bytecodeInstructionCount, "%zu,\n"); + WRITE_PAIR(" ", codegen, "%zu,\n"); + WRITE_PAIR(" ", readTime, "%f,\n"); + WRITE_PAIR(" ", miscTime, "%f,\n"); + WRITE_PAIR(" ", parseTime, "%f,\n"); + WRITE_PAIR(" ", compileTime, "%f,\n"); + WRITE_PAIR(" ", codegenTime, "%f,\n"); + + WRITE_NAME(" ", lowerStats); + serializeLoweringStats(fp, stats.lowerStats); + + fprintf(fp, "\n }"); +} + +#undef WRITE_NAME +#undef WRITE_PAIR +#undef WRITE_PAIR_STRING + +static double recordDeltaTime(double& timer) +{ + double now = Luau::TimeTrace::getClock(); + double delta = now - timer; + timer = now; + return delta; +} + +static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::AssemblyOptions::Target assemblyTarget, CompileStats& stats) +{ + double currts = Luau::TimeTrace::getClock(); + + std::optional source = readFile(name); + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + stats.readTime += recordDeltaTime(currts); + + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces + + try + { + Luau::BytecodeBuilder bcb; + + Luau::CodeGen::AssemblyOptions options; + options.target = assemblyTarget; + options.outputBinary = format == CompileFormat::CodegenNull; + + if (!options.outputBinary) + { + options.includeAssembly = format != CompileFormat::CodegenIr; + options.includeIr = format != CompileFormat::CodegenAsm; + options.includeIrTypes = format != CompileFormat::CodegenAsm; + options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; + } + + options.annotator = annotateInstruction; + options.annotatorContext = &bcb; + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags( + Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types + ); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Remarks) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags( + Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks + ); + bcb.setDumpSource(*source); + } + + stats.miscTime += recordDeltaTime(currts); + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + stats.parseTime += recordDeltaTime(currts); + + Luau::compileOrThrow(bcb, result, names, copts()); + stats.bytecode += bcb.getBytecode().size(); + stats.bytecodeInstructionCount = bcb.getTotalInstructionCount(); + stats.compileTime += recordDeltaTime(currts); + + switch (format) + { + case CompileFormat::Text: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Remarks: + printf("%s", bcb.dumpSourceRemarks().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + case CompileFormat::Codegen: + case CompileFormat::CodegenAsm: + case CompileFormat::CodegenIr: + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).size(); + stats.codegenTime += recordDeltaTime(currts); + break; + case CompileFormat::Null: + break; + } + + return true; + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("\n"); + printf("Available modes:\n"); + printf(" binary, text, remarks, codegen\n"); + printf("\n"); + printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); + printf(" --target=: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); + printf(" --record-stats=: granularity of compilation stats (total, file, function).\n"); + printf(" --bytecode-summary: Compute bytecode operation distribution.\n"); + printf(" --stats-file=: file in which compilation stats will be recored (default 'stats.json').\n"); + printf(" --vector-lib=: name of the library providing vector type operations.\n"); + printf(" --vector-ctor=: name of the function constructing a vector value.\n"); + printf(" --vector-type=: name of the vector type.\n"); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +std::string escapeFilename(const std::string& filename) +{ + std::string escaped; + escaped.reserve(filename.size()); + + for (const char ch : filename) + { + switch (ch) + { + case '\\': + escaped.push_back('/'); + break; + case '"': + escaped.push_back('\\'); + escaped.push_back(ch); + break; + default: + escaped.push_back(ch); + } + } + + return escaped; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + setLuauFlagsDefault(); + + CompileFormat compileFormat = CompileFormat::Text; + Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host; + RecordStats recordStats = RecordStats::None; + std::string statsFile("stats.json"); + bool bytecodeSummary = false; + + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } + else if (strncmp(argv[i], "-t", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 1) + { + fprintf(stderr, "Error: Type info level must be between 0 and 1 inclusive.\n"); + return 1; + } + globalOptions.typeInfoLevel = level; + } + else if (strncmp(argv[i], "--target=", 9) == 0) + { + const char* value = argv[i] + 9; + + if (strcmp(value, "a64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64; + else if (strcmp(value, "a64_nf") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::A64_NoFeatures; + else if (strcmp(value, "x64") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_SystemV; + else if (strcmp(value, "x64_ms") == 0) + assemblyTarget = Luau::CodeGen::AssemblyOptions::X64_Windows; + else + { + fprintf(stderr, "Error: unknown target\n"); + return 1; + } + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; + } + else if (strncmp(argv[i], "--record-stats=", 15) == 0) + { + const char* value = argv[i] + 15; + + if (strcmp(value, "total") == 0) + recordStats = RecordStats::Total; + else if (strcmp(value, "file") == 0) + recordStats = RecordStats::File; + else if (strcmp(value, "function") == 0) + recordStats = RecordStats::Function; + else + { + fprintf(stderr, "Error: unknown 'granularity' for '--record-stats'.\n"); + return 1; + } + } + else if (strncmp(argv[i], "--bytecode-summary", 18) == 0) + { + bytecodeSummary = true; + } + else if (strncmp(argv[i], "--stats-file=", 13) == 0) + { + statsFile = argv[i] + 13; + + if (statsFile.size() == 0) + { + fprintf(stderr, "Error: filename missing for '--stats-file'.\n\n"); + return 1; + } + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (strncmp(argv[i], "--vector-lib=", 13) == 0) + { + globalOptions.vectorLib = argv[i] + 13; + } + else if (strncmp(argv[i], "--vector-ctor=", 14) == 0) + { + globalOptions.vectorCtor = argv[i] + 14; + } + else if (strncmp(argv[i], "--vector-type=", 14) == 0) + { + globalOptions.vectorType = argv[i] + 14; + } + else if (argv[i][0] == '-' && argv[i][1] == '-' && getCompileFormat(argv[i] + 2)) + { + compileFormat = *getCompileFormat(argv[i] + 2); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } + + if (bytecodeSummary && (recordStats != RecordStats::Function)) + { + fprintf(stderr, "'Error: Required '--record-stats=function' for '--bytecode-summary'.\n"); + return 1; + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + + const std::vector files = getSourceFiles(argc, argv); + +#ifdef _WIN32 + if (compileFormat == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + + const size_t fileCount = files.size(); + CompileStats stats = {}; + + std::vector fileStats; + if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + fileStats.reserve(fileCount); + + int failed = 0; + unsigned functionStats = (recordStats == RecordStats::Function ? Luau::CodeGen::FunctionStats_Enable : 0) | + (bytecodeSummary ? Luau::CodeGen::FunctionStats_BytecodeSummary : 0); + for (const std::string& path : files) + { + CompileStats fileStat = {}; + fileStat.lowerStats.functionStatsFlags = functionStats; + failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, fileStat); + stats += fileStat; + if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + fileStats.push_back(fileStat); + } + + if (compileFormat == CompileFormat::Null) + { + printf( + "Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", + int(stats.lines / 1000), + int(stats.bytecode / 1024), + stats.readTime, + stats.parseTime, + stats.compileTime + ); + } + else if (compileFormat == CompileFormat::CodegenNull) + { + printf( + "Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", + int(stats.lines / 1000), + int(stats.bytecode / 1024), + int(stats.codegen / 1024), + stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), + stats.readTime, + stats.parseTime, + stats.compileTime, + stats.codegenTime + ); + + printf( + "Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n", + stats.lowerStats.regAllocErrors, + stats.lowerStats.loweringErrors, + stats.lowerStats.spillsToSlot, + stats.lowerStats.spillsToRestore, + stats.lowerStats.maxSpillSlotsUsed + ); + } + + if (recordStats != RecordStats::None) + { + FILE* fp = fopen(statsFile.c_str(), "w"); + + if (!fp) + { + fprintf(stderr, "Unable to open 'stats.json'\n"); + return 1; + } + + if (recordStats == RecordStats::Total) + { + serializeCompileStats(fp, stats); + } + else if (recordStats == RecordStats::File || recordStats == RecordStats::Function) + { + fprintf(fp, "{\n"); + for (size_t i = 0; i < fileCount; ++i) + { + std::string escaped(escapeFilename(files[i])); + fprintf(fp, " \"%s\": ", escaped.c_str()); + serializeCompileStats(fp, fileStats[i]); + fprintf(fp, i == (fileCount - 1) ? "\n" : ",\n"); + } + fprintf(fp, "}"); + } + + fclose(fp); + } + + return failed ? 1 : 0; +} diff --git a/third_party/luau/CLI/FileUtils.cpp b/third_party/luau/CLI/FileUtils.cpp index 71e536c7..daa7c295 100644 --- a/third_party/luau/CLI/FileUtils.cpp +++ b/third_party/luau/CLI/FileUtils.cpp @@ -10,6 +10,7 @@ #ifndef NOMINMAX #define NOMINMAX #endif +#include #include #else #include @@ -44,6 +45,148 @@ static std::string toUtf8(const std::wstring& path) } #endif +bool isAbsolutePath(std::string_view path) +{ +#ifdef _WIN32 + // Must either begin with "X:/", "X:\", "/", or "\", where X is a drive letter + return (path.size() >= 3 && isalpha(path[0]) && path[1] == ':' && (path[2] == '/' || path[2] == '\\')) || + (path.size() >= 1 && (path[0] == '/' || path[0] == '\\')); +#else + // Must begin with '/' + return path.size() >= 1 && path[0] == '/'; +#endif +} + +bool isExplicitlyRelative(std::string_view path) +{ + return (path == ".") || (path == "..") || (path.size() >= 2 && path[0] == '.' && path[1] == '/') || + (path.size() >= 3 && path[0] == '.' && path[1] == '.' && path[2] == '/'); +} + +std::optional getCurrentWorkingDirectory() +{ + // 2^17 - derived from the Windows path length limit + constexpr size_t maxPathLength = 131072; + constexpr size_t initialPathLength = 260; + + std::string directory(initialPathLength, '\0'); + char* cstr = nullptr; + + while (!cstr && directory.size() <= maxPathLength) + { +#ifdef _WIN32 + cstr = _getcwd(directory.data(), static_cast(directory.size())); +#else + cstr = getcwd(directory.data(), directory.size()); +#endif + if (cstr) + { + directory.resize(strlen(cstr)); + return directory; + } + else if (errno != ERANGE || directory.size() * 2 > maxPathLength) + { + return std::nullopt; + } + else + { + directory.resize(directory.size() * 2); + } + } + return std::nullopt; +} + +// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau") +std::string normalizePath(std::string_view path) +{ + return resolvePath(path, ""); +} + +// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath. +// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path: +// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe"). +std::string resolvePath(std::string_view path, std::string_view baseFilePath) +{ + std::vector pathComponents; + std::vector baseFilePathComponents; + + // Dependent on whether the final resolved path is absolute or relative + // - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty + // - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc. + std::string resolvedPathPrefix; + + if (isAbsolutePath(path)) + { + // path is absolute, we use path's prefix and ignore baseFilePath + size_t afterPrefix = path.find_first_of("\\/") + 1; + resolvedPathPrefix = path.substr(0, afterPrefix); + pathComponents = splitPath(path.substr(afterPrefix)); + } + else + { + pathComponents = splitPath(path); + if (isAbsolutePath(baseFilePath)) + { + // path is relative and baseFilePath is absolute, we use baseFilePath's prefix + size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1; + resolvedPathPrefix = baseFilePath.substr(0, afterPrefix); + baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix)); + } + else + { + // path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative) + baseFilePathComponents = splitPath(baseFilePath); + } + } + + // Remove filename from components + if (!baseFilePathComponents.empty()) + baseFilePathComponents.pop_back(); + + // Resolve the path by applying pathComponents to baseFilePathComponents + int numPrependedParents = 0; + for (std::string_view component : pathComponents) + { + if (component == "..") + { + if (baseFilePathComponents.empty()) + { + if (resolvedPathPrefix.empty()) // only when final resolved path will be relative + numPrependedParents++; // "../" will later be added to the beginning of the resolved path + } + else if (baseFilePathComponents.back() != "..") + { + baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file" + } + } + else if (component != "." && !component.empty()) + { + baseFilePathComponents.push_back(component); + } + } + + // Join baseFilePathComponents to form the resolved path + std::string resolvedPath = resolvedPathPrefix; + // Only when resolvedPath will be relative + for (int i = 0; i < numPrependedParents; i++) + { + resolvedPath += "../"; + } + for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter) + { + if (iter != baseFilePathComponents.begin()) + resolvedPath += "/"; + + resolvedPath += *iter; + } + if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/') + { + // Remove trailing '/' if present + resolvedPath.pop_back(); + } + return resolvedPath; +} + std::optional readFile(const std::string& name) { #ifdef _WIN32 @@ -165,11 +308,14 @@ static bool traverseDirectoryRec(const std::string& path, const std::function splitPath(std::string_view path) +{ + std::vector components; + + size_t pos = 0; + size_t nextPos = path.find_first_of("\\/", pos); + + while (nextPos != std::string::npos) + { + components.push_back(path.substr(pos, nextPos - pos)); + pos = nextPos + 1; + nextPos = path.find_first_of("\\/", pos); + } + components.push_back(path.substr(pos)); + + return components; +} + std::string joinPaths(const std::string& lhs, const std::string& rhs) { std::string result = lhs; @@ -267,6 +431,10 @@ std::vector getSourceFiles(int argc, char** argv) for (int i = 1; i < argc; ++i) { + // Early out once we reach --program-args,-a since the remaining args are passed to lua + if (strcmp(argv[i], "--program-args") == 0 || strcmp(argv[i], "-a") == 0) + return files; + // Treat '-' as a special file whose source is read from stdin // All other arguments that start with '-' are skipped if (argv[i][0] == '-' && argv[i][1] != '\0') @@ -274,12 +442,16 @@ std::vector getSourceFiles(int argc, char** argv) if (isDirectory(argv[i])) { - traverseDirectory(argv[i], [&](const std::string& name) { - std::string ext = getExtension(name); - - if (ext == ".lua" || ext == ".luau") - files.push_back(name); - }); + traverseDirectory( + argv[i], + [&](const std::string& name) + { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + } + ); } else { diff --git a/third_party/luau/CLI/FileUtils.h b/third_party/luau/CLI/FileUtils.h index 97471cdc..2004a2eb 100644 --- a/third_party/luau/CLI/FileUtils.h +++ b/third_party/luau/CLI/FileUtils.h @@ -3,15 +3,24 @@ #include #include +#include #include #include +std::optional getCurrentWorkingDirectory(); + +std::string normalizePath(std::string_view path); +std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath); + std::optional readFile(const std::string& name); std::optional readStdin(); +bool isAbsolutePath(std::string_view path); +bool isExplicitlyRelative(std::string_view path); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); +std::vector splitPath(std::string_view path); std::string joinPaths(const std::string& lhs, const std::string& rhs); std::optional getParentPath(const std::string& path); diff --git a/third_party/luau/CLI/Flags.cpp b/third_party/luau/CLI/Flags.cpp index 4e261171..ee5918c9 100644 --- a/third_party/luau/CLI/Flags.cpp +++ b/third_party/luau/CLI/Flags.cpp @@ -54,8 +54,9 @@ void setLuauFlags(const char* list) else if (value == "false" || value == "False") setLuauFlag(key, false); else - fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), - key.data()); + fprintf( + stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), key.data() + ); } else { diff --git a/third_party/luau/CLI/Profiler.cpp b/third_party/luau/CLI/Profiler.cpp index d3ad4e99..3cf0aea2 100644 --- a/third_party/luau/CLI/Profiler.cpp +++ b/third_party/luau/CLI/Profiler.cpp @@ -131,8 +131,13 @@ void profilerDump(const char* path) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, - static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); + printf( + "Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", + path, + double(total) / 1e6, + static_cast(gProfiler.samples.load()), + static_cast(gProfiler.data.size()) + ); uint64_t totalgc = 0; for (uint64_t p : gProfiler.gc) diff --git a/third_party/luau/CLI/Reduce.cpp b/third_party/luau/CLI/Reduce.cpp index b7c78012..7f8c459c 100644 --- a/third_party/luau/CLI/Reduce.cpp +++ b/third_party/luau/CLI/Reduce.cpp @@ -15,7 +15,7 @@ #define VERBOSE 0 // 1 - print out commandline invocations. 2 - print out stdout -#ifdef _WIN32 +#if defined(_WIN32) && !defined(__MINGW32__) const auto popen = &_popen; const auto pclose = &_pclose; @@ -56,10 +56,9 @@ struct Reducer ParseResult parseResult; AstStatBlock* root; - std::string tempScriptName; + std::string scriptName; - std::string appName; - std::vector appArgs; + std::string command; std::string_view searchText; Reducer() @@ -99,10 +98,10 @@ struct Reducer } while (true); } - FILE* f = fopen(tempScriptName.c_str(), "w"); + FILE* f = fopen(scriptName.c_str(), "w"); if (!f) { - printf("Unable to open temp script to %s\n", tempScriptName.c_str()); + printf("Unable to open temp script to %s\n", scriptName.c_str()); exit(2); } @@ -113,7 +112,7 @@ struct Reducer if (written != source.size()) { printf("??? %zu %zu\n", written, source.size()); - printf("Unable to write to temp script %s\n", tempScriptName.c_str()); + printf("Unable to write to temp script %s\n", scriptName.c_str()); exit(3); } @@ -142,12 +141,18 @@ struct Reducer { writeTempScript(); - std::string command = appName + " " + escape(tempScriptName); - for (const auto& arg : appArgs) - command += " " + escape(arg); + std::string cmd = command; + while (true) + { + auto pos = cmd.find("{}"); + if (std::string::npos == pos) + break; + + cmd = cmd.substr(0, pos) + escape(scriptName) + cmd.substr(pos + 2); + } #if VERBOSE >= 1 - printf("running %s\n", command.c_str()); + printf("running %s\n", cmd.c_str()); #endif TestResult result = TestResult::NoBug; @@ -155,7 +160,7 @@ struct Reducer ++step; printf("Step %4d...\n", step); - FILE* p = popen(command.c_str(), "r"); + FILE* p = popen(cmd.c_str(), "r"); while (!feof(p)) { @@ -179,7 +184,8 @@ struct Reducer { std::vector result; - auto append = [&](AstStatBlock* block) { + auto append = [&](AstStatBlock* block) + { if (block) result.insert(result.end(), block->body.data, block->body.data + block->body.size); }; @@ -245,7 +251,8 @@ struct Reducer std::vector> result; - auto append = [&result](Span a, Span b) { + auto append = [&result](Span a, Span b) + { if (a.first == a.second && b.first == b.second) return; else @@ -424,30 +431,19 @@ struct Reducer } } - void run(const std::string scriptName, const std::string appName, const std::vector& appArgs, std::string_view source, - std::string_view searchText) + void run(const std::string scriptName, const std::string command, std::string_view source, std::string_view searchText) { - tempScriptName = scriptName; - if (tempScriptName.substr(tempScriptName.size() - 4) == ".lua") - { - tempScriptName.erase(tempScriptName.size() - 4); - tempScriptName += "-reduced.lua"; - } - else - { - this->tempScriptName = scriptName + "-reduced"; - } + this->scriptName = scriptName; #if 0 // Handy debugging trick: VS Code will update its view of the file in realtime as it is edited. - std::string wheee = "code " + tempScriptName; + std::string wheee = "code " + scriptName; system(wheee.c_str()); #endif - printf("Temp script: %s\n", tempScriptName.c_str()); + printf("Script: %s\n", scriptName.c_str()); - this->appName = appName; - this->appArgs = appArgs; + this->command = command; this->searchText = searchText; parseResult = Parser::parse(source.data(), source.size(), nameTable, allocator, parseOptions); @@ -470,13 +466,14 @@ struct Reducer writeTempScript(/* minify */ true); - printf("Done! Check %s\n", tempScriptName.c_str()); + printf("Done! Check %s\n", scriptName.c_str()); } }; [[noreturn]] void help(const std::vector& args) { - printf("Syntax: %s script application \"search text\" [arguments]\n", args[0].data()); + printf("Syntax: %s script command \"search text\"\n", args[0].data()); + printf(" Within command, use {} as a stand-in for the script being reduced\n"); exit(1); } @@ -484,7 +481,7 @@ int main(int argc, char** argv) { const std::vector args(argv, argv + argc); - if (args.size() < 4) + if (args.size() != 4) help(args); for (size_t i = 1; i < args.size(); ++i) @@ -496,7 +493,6 @@ int main(int argc, char** argv) const std::string scriptName = argv[1]; const std::string appName = argv[2]; const std::string searchText = argv[3]; - const std::vector appArgs(begin(args) + 4, end(args)); std::optional source = readFile(scriptName); @@ -507,5 +503,5 @@ int main(int argc, char** argv) } Reducer reducer; - reducer.run(scriptName, appName, appArgs, *source, searchText); + reducer.run(scriptName, appName, *source, searchText); } diff --git a/third_party/luau/CLI/Repl.cpp b/third_party/luau/CLI/Repl.cpp index 4303364c..b8e9d814 100644 --- a/third_party/luau/CLI/Repl.cpp +++ b/third_party/luau/CLI/Repl.cpp @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Repl.h" +#include "Luau/Common.h" #include "lua.h" #include "lualib.h" #include "Luau/CodeGen.h" #include "Luau/Compiler.h" -#include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" #include "Luau/TimeTrace.h" @@ -14,6 +14,7 @@ #include "FileUtils.h" #include "Flags.h" #include "Profiler.h" +#include "Require.h" #include "isocline.h" @@ -27,6 +28,10 @@ #include #endif +#ifdef __linux__ +#include +#endif + #ifdef CALLGRIND #include #endif @@ -36,30 +41,12 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) -enum class CliMode -{ - Unknown, - Repl, - Compile, - RunSourceFiles -}; - -enum class CompileFormat -{ - Text, - Binary, - Remarks, - Codegen, // Prints annotated native code including IR and assembly - CodegenAsm, // Prints annotated native code assembly - CodegenIr, // Prints annotated native code IR - CodegenVerbose, // Prints annotated native code including IR, assembly and outlined code - CodegenNull, - Null -}; constexpr int MaxTraversalLimit = 50; static bool codegen = false; +static int program_argc = 0; +char** program_argv = nullptr; // Ctrl-C handling static void sigintCallback(lua_State* L, int gc) @@ -101,6 +88,7 @@ static Luau::CompileOptions copts() Luau::CompileOptions result = {}; result.optimizationLevel = globalOptions.optimizationLevel; result.debugLevel = globalOptions.debugLevel; + result.typeInfoLevel = 1; result.coverageLevel = coverageActive() ? 2 : 0; return result; @@ -134,27 +122,13 @@ static int finishrequire(lua_State* L) static int lua_require(lua_State* L) { std::string name = luaL_checkstring(L, 1); - std::string chunkname = "=" + name; - luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name)); - // return the module from the cache - lua_getfield(L, -1, name.c_str()); - if (!lua_isnil(L, -1)) - { - // L stack: _MODULES result + if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) return finishrequire(L); - } - - lua_pop(L, 1); - - std::optional source = readFile(name + ".luau"); - if (!source) - { - source = readFile(name + ".lua"); // try .lua if .luau doesn't exist - if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error - } + else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound) + luaL_errorL(L, "error requiring module"); // module needs to run in a new thread, isolated from the rest // note: we create ML on main thread so that it doesn't inherit environment of L @@ -166,11 +140,14 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source, copts()); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) + std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts()); + if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { if (codegen) - Luau::CodeGen::compile(ML, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(ML, -1, nativeOptions); + } if (coverageActive()) coverageTrack(ML, -1); @@ -197,7 +174,7 @@ static int lua_require(lua_State* L) // there's now a return value on top of ML; L stack: _MODULES ML lua_xmove(ML, L, 1); lua_pushvalue(L, -1); - lua_setfield(L, -4, name.c_str()); + lua_setfield(L, -4, resolvedRequire.absolutePath.c_str()); // L stack: _MODULES ML result return finishrequire(L); @@ -277,8 +254,18 @@ void setupState(lua_State* L) luaL_sandbox(L); } +void setupArguments(lua_State* L, int argc, char** argv) +{ + lua_checkstack(L, argc); + + for (int i = 0; i < argc; ++i) + lua_pushstring(L, argv[i]); +} + std::string runCode(lua_State* L, const std::string& source) { + lua_checkstack(L, LUA_MINSTACK); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) @@ -292,9 +279,6 @@ std::string runCode(lua_State* L, const std::string& source) return error; } - if (codegen) - Luau::CodeGen::compile(L, -1); - lua_State* T = lua_newthread(L); lua_pushvalue(L, -2); @@ -404,8 +388,13 @@ static void safeGetTable(lua_State* L, int tableIndex) // completePartialMatches finds keys that match the specified 'prefix' // Note: the table/object to be searched must be on the top of the Lua stack -static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, - const AddCompletionCallback& addCompletionCallback) +static void completePartialMatches( + lua_State* L, + bool completeOnlyFunctions, + const std::string& editBuffer, + std::string_view prefix, + const AddCompletionCallback& addCompletionCallback +) { for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) { @@ -452,6 +441,8 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A std::string_view lookup = editBuffer; bool completeOnlyFunctions = false; + lua_checkstack(L, LUA_MINSTACK); + // Push the global variable table to begin the search lua_pushvalue(L, LUA_GLOBALSINDEX); @@ -497,9 +488,14 @@ static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer) { auto* L = reinterpret_cast(ic_completion_arg(cenv)); - getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { - ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); - }); + getCompletions( + L, + std::string(editBuffer), + [cenv](const std::string& completion, const std::string& display) + { + ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); + } + ); } static bool isMethodOrFunctionChar(const char* s, long len) @@ -625,12 +621,16 @@ static bool runFile(const char* name, lua_State* GL, bool repl) if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { if (codegen) - Luau::CodeGen::compile(L, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(L, -1, nativeOptions); + } if (coverageActive()) coverageTrack(L, -1); - status = lua_resume(L, NULL, 0); + setupArguments(L, program_argc, program_argv); + status = lua_resume(L, NULL, program_argc); } else { @@ -664,178 +664,11 @@ static bool runFile(const char* name, lua_State* GL, bool repl) return status == 0; } -static void report(const char* name, const Luau::Location& location, const char* type, const char* message) -{ - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); -} - -static void reportError(const char* name, const Luau::ParseError& error) -{ - report(name, error.getLocation(), "SyntaxError", error.what()); -} - -static void reportError(const char* name, const Luau::CompileError& error) -{ - report(name, error.getLocation(), "CompileError", error.what()); -} - -static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) -{ - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssembly(L, -1, options); - - fprintf(stderr, "Error loading bytecode %s\n", name); - return ""; -} - -static void annotateInstruction(void* context, std::string& text, int fid, int instpos) -{ - Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; - - bcb.annotateInstruction(text, fid, instpos); -} - -struct CompileStats -{ - size_t lines; - size_t bytecode; - size_t codegen; - - double readTime; - double miscTime; - double parseTime; - double compileTime; - double codegenTime; -}; - -static double recordDeltaTime(double& timer) -{ - double now = Luau::TimeTrace::getClock(); - double delta = now - timer; - timer = now; - return delta; -} - -static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) -{ - double currts = Luau::TimeTrace::getClock(); - - std::optional source = readFile(name); - if (!source) - { - fprintf(stderr, "Error opening %s\n", name); - return false; - } - - stats.readTime += recordDeltaTime(currts); - - // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) - // This function is much more complicated because it supports many output human-readable formats through internal interfaces - - try - { - Luau::BytecodeBuilder bcb; - - Luau::CodeGen::AssemblyOptions options; - options.outputBinary = format == CompileFormat::CodegenNull; - - if (!options.outputBinary) - { - options.includeAssembly = format != CompileFormat::CodegenIr; - options.includeIr = format != CompileFormat::CodegenAsm; - options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; - } - - options.annotator = annotateInstruction; - options.annotatorContext = &bcb; - - if (format == CompileFormat::Text) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Remarks) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || - format == CompileFormat::CodegenVerbose) - { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | - Luau::BytecodeBuilder::Dump_Remarks); - bcb.setDumpSource(*source); - } - - stats.miscTime += recordDeltaTime(currts); - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - stats.lines += result.lines; - stats.parseTime += recordDeltaTime(currts); - - Luau::compileOrThrow(bcb, result, names, copts()); - stats.bytecode += bcb.getBytecode().size(); - stats.compileTime += recordDeltaTime(currts); - - switch (format) - { - case CompileFormat::Text: - printf("%s", bcb.dumpEverything().c_str()); - break; - case CompileFormat::Remarks: - printf("%s", bcb.dumpSourceRemarks().c_str()); - break; - case CompileFormat::Binary: - fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); - break; - case CompileFormat::Codegen: - case CompileFormat::CodegenAsm: - case CompileFormat::CodegenIr: - case CompileFormat::CodegenVerbose: - printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); - break; - case CompileFormat::CodegenNull: - stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); - stats.codegenTime += recordDeltaTime(currts); - break; - case CompileFormat::Null: - break; - } - - return true; - } - catch (Luau::ParseErrors& e) - { - for (auto& error : e.getErrors()) - reportError(name, error); - return false; - } - catch (Luau::CompileError& e) - { - reportError(name, e); - return false; - } -} - static void displayHelp(const char* argv0) { - printf("Usage: %s [--mode] [options] [file list]\n", argv0); - printf("\n"); - printf("When mode and file list are omitted, an interactive REPL is started instead.\n"); + printf("Usage: %s [options] [file list] [-a] [arg list]\n", argv0); printf("\n"); - printf("Available modes:\n"); - printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); + printf("When file list is omitted, an interactive REPL is started instead.\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -846,6 +679,7 @@ static void displayHelp(const char* argv0) printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); printf(" --codegen: execute code using native code generation\n"); + printf(" --program-args,-a: declare start of arguments to be passed to the Luau program\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -860,66 +694,17 @@ int replMain(int argc, char** argv) setLuauFlagsDefault(); - CliMode mode = CliMode::Unknown; - CompileFormat compileFormat{}; +#ifdef _WIN32 + SetConsoleOutputCP(CP_UTF8); +#endif + int profile = 0; bool coverage = false; bool interactive = false; + bool codegenPerf = false; + int program_args = argc; - // Set the mode if the user has explicitly specified one. - int argStart = 1; - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) - { - argStart++; - mode = CliMode::Compile; - if (strcmp(argv[1], "--compile") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=binary") == 0) - { - compileFormat = CompileFormat::Binary; - } - else if (strcmp(argv[1], "--compile=text") == 0) - { - compileFormat = CompileFormat::Text; - } - else if (strcmp(argv[1], "--compile=remarks") == 0) - { - compileFormat = CompileFormat::Remarks; - } - else if (strcmp(argv[1], "--compile=codegen") == 0) - { - compileFormat = CompileFormat::Codegen; - } - else if (strcmp(argv[1], "--compile=codegenasm") == 0) - { - compileFormat = CompileFormat::CodegenAsm; - } - else if (strcmp(argv[1], "--compile=codegenir") == 0) - { - compileFormat = CompileFormat::CodegenIr; - } - else if (strcmp(argv[1], "--compile=codegenverbose") == 0) - { - compileFormat = CompileFormat::CodegenVerbose; - } - else if (strcmp(argv[1], "--compile=codegennull") == 0) - { - compileFormat = CompileFormat::CodegenNull; - } - else if (strcmp(argv[1], "--compile=null") == 0) - { - compileFormat = CompileFormat::Null; - } - else - { - fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); - return 1; - } - } - - for (int i = argStart; i < argc; i++) + for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { @@ -962,6 +747,11 @@ int replMain(int argc, char** argv) { codegen = true; } + else if (strcmp(argv[i], "--codegen-perf") == 0) + { + codegen = true; + codegenPerf = true; + } else if (strcmp(argv[i], "--coverage") == 0) { coverage = true; @@ -974,6 +764,11 @@ int replMain(int argc, char** argv) { setLuauFlags(argv[i] + 9); } + else if (strcmp(argv[i], "--program-args") == 0 || strcmp(argv[i], "-a") == 0) + { + program_args = i + 1; + break; + } else if (argv[i][0] == '-') { fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); @@ -982,6 +777,10 @@ int replMain(int argc, char** argv) } } + program_argc = argc - program_args; + program_argv = &argv[program_args]; + + #if !defined(LUAU_ENABLE_TIME_TRACE) if (FFlag::DebugLuauTimeTracing) { @@ -990,58 +789,39 @@ int replMain(int argc, char** argv) } #endif -#if !LUA_CUSTOM_EXECUTION - if (codegen) + if (codegenPerf) { - fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); - return 1; - } -#endif +#if __linux__ + char path[128]; + snprintf(path, sizeof(path), "/tmp/perf-%d.map", getpid()); - const std::vector files = getSourceFiles(argc, argv); - if (mode == CliMode::Unknown) - { - mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; - } + // note, there's no need to close the log explicitly as it will be closed when the process exits + FILE* codegenPerfLog = fopen(path, "w"); - if (mode != CliMode::Compile && codegen && !Luau::CodeGen::isSupported()) - { - fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); + Luau::CodeGen::setPerfLog( + codegenPerfLog, + [](void* context, uintptr_t addr, unsigned size, const char* symbol) + { + fprintf(static_cast(context), "%016lx %08x %s\n", long(addr), size, symbol); + } + ); +#else + fprintf(stderr, "--codegen-perf option is only supported on Linux\n"); return 1; - } - - switch (mode) - { - case CliMode::Compile: - { -#ifdef _WIN32 - if (compileFormat == CompileFormat::Binary) - _setmode(_fileno(stdout), _O_BINARY); #endif + } - CompileStats stats = {}; - int failed = 0; - - for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat, stats); + if (codegen && !Luau::CodeGen::isSupported()) + fprintf(stderr, "Warning: Native code generation is not supported in current configuration\n"); - if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), - int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); - else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", - int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), - stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, - stats.codegenTime); + const std::vector files = getSourceFiles(argc, argv); - return failed ? 1 : 0; - } - case CliMode::Repl: + if (files.empty()) { runRepl(); return 0; } - case CliMode::RunSourceFiles: + else { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1073,9 +853,4 @@ int replMain(int argc, char** argv) return failed ? 1 : 0; } - case CliMode::Unknown: - default: - LUAU_ASSERT(!"Unhandled cli mode."); - return 1; - } } diff --git a/third_party/luau/CLI/Require.cpp b/third_party/luau/CLI/Require.cpp new file mode 100644 index 00000000..5de78a4a --- /dev/null +++ b/third_party/luau/CLI/Require.cpp @@ -0,0 +1,295 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Require.h" + +#include "FileUtils.h" +#include "Luau/Common.h" + +#include +#include +#include + +RequireResolver::RequireResolver(lua_State* L, std::string path) + : pathToResolve(std::move(path)) + , L(L) +{ + lua_Debug ar; + lua_getinfo(L, 1, "s", &ar); + sourceChunkname = ar.source; + + if (!isRequireAllowed(sourceChunkname)) + luaL_errorL(L, "require is not supported in this context"); + + if (isAbsolutePath(pathToResolve)) + luaL_argerrorL(L, 1, "cannot require an absolute path"); + + std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/'); + + substituteAliasIfPresent(pathToResolve); +} + +[[nodiscard]] RequireResolver::ResolvedRequire RequireResolver::resolveRequire(lua_State* L, std::string path) +{ + RequireResolver resolver(L, std::move(path)); + ModuleStatus status = resolver.findModule(); + if (status != ModuleStatus::FileRead) + return ResolvedRequire{status}; + else + return ResolvedRequire{status, std::move(resolver.chunkname), std::move(resolver.absolutePath), std::move(resolver.sourceCode)}; +} + +RequireResolver::ModuleStatus RequireResolver::findModule() +{ + resolveAndStoreDefaultPaths(); + + // Put _MODULES table on stack for checking and saving to the cache + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + + RequireResolver::ModuleStatus moduleStatus = findModuleImpl(); + + if (moduleStatus != RequireResolver::ModuleStatus::NotFound) + return moduleStatus; + + if (!shouldSearchPathsArray()) + return moduleStatus; + + if (!isConfigFullyResolved) + parseNextConfig(); + + // Index-based iteration because std::iterator may be invalidated if config.paths is reallocated + for (size_t i = 0; i < config.paths.size(); ++i) + { + // "placeholder" acts as a requiring file in the relevant directory + std::optional absolutePathOpt = resolvePath(pathToResolve, joinPaths(config.paths[i], "placeholder")); + + if (!absolutePathOpt) + luaL_errorL(L, "error requiring module"); + + chunkname = *absolutePathOpt; + absolutePath = *absolutePathOpt; + + moduleStatus = findModuleImpl(); + + if (moduleStatus != RequireResolver::ModuleStatus::NotFound) + return moduleStatus; + + // Before finishing the loop, parse more config files if there are any + if (i == config.paths.size() - 1 && !isConfigFullyResolved) + parseNextConfig(); // could reallocate config.paths when paths are parsed and added + } + + return RequireResolver::ModuleStatus::NotFound; +} + +RequireResolver::ModuleStatus RequireResolver::findModuleImpl() +{ + static const std::array possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"}; + + size_t unsuffixedAbsolutePathSize = absolutePath.size(); + + for (const char* possibleSuffix : possibleSuffixes) + { + absolutePath += possibleSuffix; + + // Check cache for module + lua_getfield(L, -1, absolutePath.c_str()); + if (!lua_isnil(L, -1)) + { + return ModuleStatus::Cached; + } + lua_pop(L, 1); + + // Try to read the matching file + std::optional source = readFile(absolutePath); + if (source) + { + chunkname = "=" + chunkname + possibleSuffix; + sourceCode = *source; + return ModuleStatus::FileRead; + } + + absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix + } + + return ModuleStatus::NotFound; +} + +bool RequireResolver::isRequireAllowed(std::string_view sourceChunkname) +{ + LUAU_ASSERT(!sourceChunkname.empty()); + return (sourceChunkname[0] == '=' || sourceChunkname[0] == '@'); +} + +bool RequireResolver::shouldSearchPathsArray() +{ + return !isAbsolutePath(pathToResolve) && !isExplicitlyRelative(pathToResolve); +} + +void RequireResolver::resolveAndStoreDefaultPaths() +{ + if (!isAbsolutePath(pathToResolve)) + { + std::string chunknameContext = getRequiringContextRelative(); + std::optional absolutePathContext = getRequiringContextAbsolute(); + + if (!absolutePathContext) + luaL_errorL(L, "error requiring module"); + + // resolvePath automatically sanitizes/normalizes the paths + std::optional chunknameOpt = resolvePath(pathToResolve, chunknameContext); + std::optional absolutePathOpt = resolvePath(pathToResolve, *absolutePathContext); + + if (!chunknameOpt || !absolutePathOpt) + luaL_errorL(L, "error requiring module"); + + chunkname = std::move(*chunknameOpt); + absolutePath = std::move(*absolutePathOpt); + } + else + { + // Here we must explicitly sanitize, as the path is taken as is + std::optional sanitizedPath = normalizePath(pathToResolve); + if (!sanitizedPath) + luaL_errorL(L, "error requiring module"); + + chunkname = *sanitizedPath; + absolutePath = std::move(*sanitizedPath); + } +} + +std::optional RequireResolver::getRequiringContextAbsolute() +{ + std::string requiringFile; + if (isAbsolutePath(sourceChunkname.substr(1))) + { + // We already have an absolute path for the requiring file + requiringFile = sourceChunkname.substr(1); + } + else + { + // Requiring file's stored path is relative to the CWD, must make absolute + std::optional cwd = getCurrentWorkingDirectory(); + if (!cwd) + return std::nullopt; + + if (sourceChunkname.substr(1) == "stdin") + { + // Require statement is being executed from REPL input prompt + // The requiring context is the pseudo-file "stdin" in the CWD + requiringFile = joinPaths(*cwd, "stdin"); + } + else + { + // Require statement is being executed in a file, must resolve relative to CWD + std::optional requiringFileOpt = resolvePath(sourceChunkname.substr(1), joinPaths(*cwd, "stdin")); + if (!requiringFileOpt) + return std::nullopt; + + requiringFile = *requiringFileOpt; + } + } + std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); + return requiringFile; +} + +std::string RequireResolver::getRequiringContextRelative() +{ + std::string baseFilePath; + if (sourceChunkname.substr(1) != "stdin") + baseFilePath = sourceChunkname.substr(1); + + return baseFilePath; +} + +void RequireResolver::substituteAliasIfPresent(std::string& path) +{ + if (path.size() < 1 || path[0] != '@') + return; + std::string potentialAlias = path.substr(1, path.find_first_of("\\/")); + + // Not worth searching when potentialAlias cannot be an alias + if (!Luau::isValidAlias(potentialAlias)) + luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); + + std::optional alias = getAlias(potentialAlias); + if (alias) + { + path = *alias + path.substr(potentialAlias.size() + 1); + } + else + { + luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); + } +} + +std::optional RequireResolver::getAlias(std::string alias) +{ + std::transform( + alias.begin(), + alias.end(), + alias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + while (!config.aliases.count(alias) && !isConfigFullyResolved) + { + parseNextConfig(); + } + if (!config.aliases.count(alias) && isConfigFullyResolved) + return std::nullopt; // could not find alias + + return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); +} + +void RequireResolver::parseNextConfig() +{ + if (isConfigFullyResolved) + return; // no config files left to parse + + std::optional directory; + if (lastSearchedDir.empty()) + { + std::optional requiringFile = getRequiringContextAbsolute(); + if (!requiringFile) + luaL_errorL(L, "error requiring module"); + + directory = getParentPath(*requiringFile); + } + else + directory = getParentPath(lastSearchedDir); + + if (directory) + { + lastSearchedDir = *directory; + parseConfigInDirectory(*directory); + } + else + isConfigFullyResolved = true; +} + +void RequireResolver::parseConfigInDirectory(const std::string& directory) +{ + std::string configPath = joinPaths(directory, Luau::kConfigName); + + size_t numPaths = config.paths.size(); + + if (std::optional contents = readFile(configPath)) + { + std::optional error = Luau::parseConfig(*contents, config); + if (error) + luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); + } + + // Resolve any newly obtained relative paths in "paths" in relation to configPath + for (auto it = config.paths.begin() + numPaths; it != config.paths.end(); ++it) + { + if (!isAbsolutePath(*it)) + { + if (std::optional resolvedPath = resolvePath(*it, configPath)) + *it = std::move(*resolvedPath); + else + luaL_errorL(L, "error requiring module"); + } + } +} diff --git a/third_party/luau/CLI/Require.h b/third_party/luau/CLI/Require.h new file mode 100644 index 00000000..ae96834f --- /dev/null +++ b/third_party/luau/CLI/Require.h @@ -0,0 +1,62 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" +#include "lualib.h" + +#include "Luau/Config.h" + +#include +#include + +class RequireResolver +{ +public: + std::string chunkname; + std::string absolutePath; + std::string sourceCode; + + enum class ModuleStatus + { + Cached, + FileRead, + NotFound + }; + + struct ResolvedRequire + { + ModuleStatus status; + std::string chunkName; + std::string absolutePath; + std::string sourceCode; + }; + + [[nodiscard]] ResolvedRequire static resolveRequire(lua_State* L, std::string path); + +private: + std::string pathToResolve; + std::string_view sourceChunkname; + + RequireResolver(lua_State* L, std::string path); + + ModuleStatus findModule(); + lua_State* L; + Luau::Config config; + std::string lastSearchedDir; + bool isConfigFullyResolved = false; + + bool isRequireAllowed(std::string_view sourceChunkname); + bool shouldSearchPathsArray(); + + void resolveAndStoreDefaultPaths(); + ModuleStatus findModuleImpl(); + + std::optional getRequiringContextAbsolute(); + std::string getRequiringContextRelative(); + + void substituteAliasIfPresent(std::string& path); + std::optional getAlias(std::string alias); + + void parseNextConfig(); + void parseConfigInDirectory(const std::string& path); +}; diff --git a/third_party/luau/CodeGen/include/Luau/AddressA64.h b/third_party/luau/CodeGen/include/Luau/AddressA64.h index acb64e39..fbac3ec3 100644 --- a/third_party/luau/CodeGen/include/Luau/AddressA64.h +++ b/third_party/luau/CodeGen/include/Luau/AddressA64.h @@ -14,13 +14,10 @@ namespace A64 enum class AddressKindA64 : uint8_t { - imm, // reg + imm - reg, // reg + reg - - // TODO: - // reg + reg << shift - // reg + sext(reg) << shift - // reg + uext(reg) << shift + reg, // reg + reg + imm, // reg + imm + pre, // reg + imm, reg += imm + post, // reg, reg += imm }; struct AddressA64 @@ -29,13 +26,14 @@ struct AddressA64 // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB static constexpr size_t kMaxOffset = 1023; - constexpr AddressA64(RegisterA64 base, int off = 0) - : kind(AddressKindA64::imm) + constexpr AddressA64(RegisterA64 base, int off = 0, AddressKindA64 kind = AddressKindA64::imm) + : kind(kind) , base(base) , offset(xzr) , data(off) { - LUAU_ASSERT(base.kind == KindA64::x || base == sp); + CODEGEN_ASSERT(base.kind == KindA64::x || base == sp); + CODEGEN_ASSERT(kind != AddressKindA64::reg); } constexpr AddressA64(RegisterA64 base, RegisterA64 offset) @@ -44,8 +42,8 @@ struct AddressA64 , offset(offset) , data(0) { - LUAU_ASSERT(base.kind == KindA64::x); - LUAU_ASSERT(offset.kind == KindA64::x); + CODEGEN_ASSERT(base.kind == KindA64::x); + CODEGEN_ASSERT(offset.kind == KindA64::x); } AddressKindA64 kind; diff --git a/third_party/luau/CodeGen/include/Luau/AssemblyBuilderA64.h b/third_party/luau/CodeGen/include/Luau/AssemblyBuilderA64.h index 26be11c5..a4d857a4 100644 --- a/third_party/luau/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/third_party/luau/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -56,7 +56,7 @@ class AssemblyBuilderA64 void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void bic(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void tst(RegisterA64 src1, RegisterA64 src2, int shift = 0); - void mvn(RegisterA64 dst, RegisterA64 src); + void mvn_(RegisterA64 dst, RegisterA64 src); // Bitwise with immediate // Note: immediate must have a single contiguous sequence of 1 bits set of length 1..31 @@ -72,6 +72,7 @@ class AssemblyBuilderA64 void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void clz(RegisterA64 dst, RegisterA64 src); void rbit(RegisterA64 dst, RegisterA64 src); + void rev(RegisterA64 dst, RegisterA64 src); // Shifts with immediates // Note: immediate value must be in [0, 31] or [0, 63] range based on register type @@ -80,6 +81,12 @@ class AssemblyBuilderA64 void asr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); void ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + // Bitfields + void ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + // Load // Note: paired loads are currently omitted for simplicity void ldr(RegisterA64 dst, AddressA64 src); @@ -118,12 +125,12 @@ class AssemblyBuilderA64 // Address of code (label) void adr(RegisterA64 dst, Label& label); - // Floating-point scalar moves + // Floating-point scalar/vector moves // Note: constant must be compatible with immediate floating point moves (see isFmovSupported) void fmov(RegisterA64 dst, RegisterA64 src); void fmov(RegisterA64 dst, double src); - // Floating-point scalar math + // Floating-point scalar/vector math void fabs(RegisterA64 dst, RegisterA64 src); void fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -132,6 +139,11 @@ class AssemblyBuilderA64 void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + // Vector component manipulation + void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + void ins_4s(RegisterA64 dst, uint8_t dstIndex, RegisterA64 src, uint8_t srcIndex); + void dup_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + // Floating-point rounding and conversions void frinta(RegisterA64 dst, RegisterA64 src); void frintm(RegisterA64 dst, RegisterA64 src); @@ -151,6 +163,8 @@ class AssemblyBuilderA64 void fcmpz(RegisterA64 src); void fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + void udf(); + // Run final checks bool finalize(); @@ -163,7 +177,7 @@ class AssemblyBuilderA64 // Extracts code offset (in bytes) from label uint32_t getLabelOffset(const Label& label) { - LUAU_ASSERT(label.location != ~0u); + CODEGEN_ASSERT(label.location != ~0u); return label.location * 4; } @@ -171,6 +185,8 @@ class AssemblyBuilderA64 uint32_t getCodeSize() const; + unsigned getInstructionCount() const; + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -199,7 +215,7 @@ class AssemblyBuilderA64 void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); - void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint16_t opsize, int sizelog); void placeB(const char* name, Label& label, uint8_t op); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); @@ -212,7 +228,9 @@ class AssemblyBuilderA64 void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); - void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms); + void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms); + void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift); + void placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2); void place(uint32_t word); diff --git a/third_party/luau/CodeGen/include/Luau/AssemblyBuilderX64.h b/third_party/luau/CodeGen/include/Luau/AssemblyBuilderX64.h index e162cd3e..c52d95c5 100644 --- a/third_party/luau/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/third_party/luau/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/Label.h" #include "Luau/ConditionX64.h" #include "Luau/OperandX64.h" @@ -84,6 +85,7 @@ class AssemblyBuilderX64 void test(OperandX64 lhs, OperandX64 rhs); void lea(OperandX64 lhs, OperandX64 rhs); void setcc(ConditionX64 cond, OperandX64 op); + void cmov(ConditionX64 cond, RegisterX64 lhs, OperandX64 rhs); void push(OperandX64 op); void pop(OperandX64 op); @@ -97,10 +99,14 @@ class AssemblyBuilderX64 void call(Label& label); void call(OperandX64 op); + void lea(RegisterX64 lhs, Label& label); + void int3(); + void ud2(); void bsr(RegisterX64 dst, OperandX64 src); void bsf(RegisterX64 dst, OperandX64 src); + void bswap(RegisterX64 dst); // Code alignment void nop(uint32_t length = 1); @@ -113,13 +119,18 @@ class AssemblyBuilderX64 void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsubps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmulps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vdivps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vandps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandnpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vorps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vucomisd(OperandX64 src1, OperandX64 src2); @@ -127,6 +138,7 @@ class AssemblyBuilderX64 void vcvttsd2si(OperandX64 dst, OperandX64 src); void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact @@ -152,6 +164,8 @@ class AssemblyBuilderX64 void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); + void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); + void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); // Run final checks bool finalize(); @@ -165,14 +179,16 @@ class AssemblyBuilderX64 // Extracts code offset (in bytes) from label uint32_t getLabelOffset(const Label& label) { - LUAU_ASSERT(label.location != ~0u); + CODEGEN_ASSERT(label.location != ~0u); return label.location; } // Constant allocation (uses rip-relative addressing) + OperandX64 i32(int32_t value); OperandX64 i64(int64_t value); OperandX64 f32(float value); OperandX64 f64(double value); + OperandX64 u32x4(uint32_t x, uint32_t y, uint32_t z, uint32_t w); OperandX64 f32x4(float x, float y, float z, float w); OperandX64 f64x2(double x, double y); OperandX64 bytes(const void* ptr, size_t size, size_t align = 8); @@ -181,6 +197,8 @@ class AssemblyBuilderX64 uint32_t getCodeSize() const; + unsigned getInstructionCount() const; + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -194,8 +212,19 @@ class AssemblyBuilderX64 private: // Instruction archetypes - void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, - uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg); + void placeBinary( + const char* name, + OperandX64 lhs, + OperandX64 rhs, + uint8_t codeimm8, + uint8_t codeimm, + uint8_t codeimmImm8, + uint8_t code8rev, + uint8_t coderev, + uint8_t code8, + uint8_t code, + uint8_t opreg + ); void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg); void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); @@ -210,7 +239,16 @@ class AssemblyBuilderX64 void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix); void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); void placeAvx( - const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + const char* name, + OperandX64 dst, + OperandX64 src1, + OperandX64 src2, + uint8_t imm8, + uint8_t code, + bool setW, + uint8_t mode, + uint8_t prefix + ); // Instruction components void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes = 0); @@ -222,6 +260,7 @@ class AssemblyBuilderX64 void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); void placeImm8Or32(int32_t imm); void placeImm8(int32_t imm); + void placeImm16(int16_t imm); void placeImm32(int32_t imm); void placeImm64(int64_t imm); void placeLabel(Label& label); @@ -241,6 +280,7 @@ class AssemblyBuilderX64 LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3, OperandX64 op4); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterX64 reg, Label label); void log(OperandX64 op); const char* getSizeName(SizeX64 size) const; @@ -250,12 +290,17 @@ class AssemblyBuilderX64 std::vector