Skip to content

Commit

Permalink
Merge pull request #1548 from xlsynth:cdleary/2024-08-18-enum-alias
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665048673
  • Loading branch information
copybara-github committed Aug 20, 2024
2 parents 5923743 + 0b1a69a commit 5013b53
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 28 deletions.
5 changes: 5 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ dslx_lang_test(
dslx_deps = [":constexpr_dslx"],
)

dslx_lang_test(
name = "import_enum_alias",
dslx_deps = [":mod_enum_importer_dslx"],
)

dslx_lang_test(name = "map")

dslx_lang_test(name = "multiplies")
Expand Down
21 changes: 21 additions & 0 deletions xls/dslx/tests/import_enum_alias.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import xls.dslx.tests.mod_enum_importer;

// A module that imports an enum alias and uses it.
fn main() -> mod_enum_importer::MyEnumAlias { mod_enum_importer::MyEnumAlias::FOO }

#[test]
fn test_main() { assert_eq(main(), mod_enum_importer::MyEnumAlias::FOO) }
8 changes: 4 additions & 4 deletions xls/dslx/tests/mod_enum_importer.x
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import xls.dslx.tests.mod_imported;

type MyEnum = mod_imported::MyEnum;
pub type MyEnumAlias = mod_imported::MyEnum;

fn main(x: u8) -> MyEnum { x as MyEnum }
fn main(x: u8) -> MyEnumAlias { x as MyEnumAlias }

#[test]
fn main_test() {
assert_eq(main(u8:42), MyEnum::FOO);
assert_eq(main(u8:64), MyEnum::BAR);
assert_eq(main(u8:42), MyEnumAlias::FOO);
assert_eq(main(u8:64), MyEnumAlias::BAR);
}
64 changes: 43 additions & 21 deletions xls/dslx/type_system/deduce_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,29 @@ using ColonRefSubjectT =

} // namespace

// Resolves a `TypeAlias` AST node to a `ColonRef` subject -- this requires us
// to traverse through aliases transitively to find a subject.
//
// Has to be an enum or builtin-type name, given the context we're in: looking
// for _values_ hanging off, e.g. in service of a `::` ref.
//
// Note: the returned AST node may not be from the same module that the
// original `type_alias` was from.
static absl::StatusOr<
std::variant<EnumDef*, BuiltinNameDef*, ArrayTypeAnnotation*>>
ResolveTypeAliasToDirectColonRefSubject(ImportData* import_data,
const TypeInfo* type_info,
TypeAlias* type_def) {
VLOG(5) << "ResolveTypeDefToDirectColonRefSubject; type_def: `"
<< type_def->ToString() << "`";

TypeDefinition td = type_def;
while (std::holds_alternative<TypeAlias*>(td)) {
TypeAlias* next_type_alias = std::get<TypeAlias*>(td);
VLOG(5) << "TypeAlias: `" << next_type_alias->ToString() << "`";
TypeAlias* type_alias) {
VLOG(5) << "ResolveTypeDefToDirectColonRefSubject; type_alias: `"
<< type_alias->ToString() << "`";

// Resolve through all the transitive aliases.
TypeDefinition current_type_definition = type_alias;
while (std::holds_alternative<TypeAlias*>(current_type_definition)) {
TypeAlias* next_type_alias = std::get<TypeAlias*>(current_type_definition);
VLOG(5) << " TypeAlias: `" << next_type_alias->ToString() << "`";
TypeAnnotation& type = next_type_alias->type_annotation();
VLOG(5) << "TypeAnnotation: `" << type.ToString() << "`";
VLOG(5) << " TypeAnnotation: `" << type.ToString() << "`";

if (auto* bti = dynamic_cast<BuiltinTypeAnnotation*>(&type);
bti != nullptr) {
Expand All @@ -89,37 +96,49 @@ ResolveTypeAliasToDirectColonRefSubject(ImportData* import_data,
// support parametric TypeDefs.
XLS_RET_CHECK(type_ref_type != nullptr)
<< type.ToString() << " :: " << type.GetNodeTypeName();
VLOG(5) << "TypeRefTypeAnnotation: `" << type_ref_type->ToString() << "`";
VLOG(5) << " TypeRefTypeAnnotation: `" << type_ref_type->ToString() << "`";

td = type_ref_type->type_ref()->type_definition();
current_type_definition = type_ref_type->type_ref()->type_definition();
}

if (std::holds_alternative<ColonRef*>(td)) {
ColonRef* colon_ref = std::get<ColonRef*>(td);
XLS_ASSIGN_OR_RETURN(auto subject, ResolveColonRefSubjectForTypeChecking(
import_data, type_info, colon_ref));
VLOG(5) << absl::StreamFormat(
"ResolveTypeDefToDirectColonRefSubject; arrived at type definition: `%s`",
ToAstNode(current_type_definition)->ToString());

if (std::holds_alternative<ColonRef*>(current_type_definition)) {
ColonRef* colon_ref = std::get<ColonRef*>(current_type_definition);
type_info = import_data->GetRootTypeInfo(colon_ref->owner()).value();
XLS_ASSIGN_OR_RETURN(ColonRefSubjectT subject,
ResolveColonRefSubjectForTypeChecking(
import_data, type_info, colon_ref));
XLS_RET_CHECK(std::holds_alternative<Module*>(subject));
Module* module = std::get<Module*>(subject);
XLS_ASSIGN_OR_RETURN(td, module->GetTypeDefinition(colon_ref->attr()));

if (std::holds_alternative<TypeAlias*>(td)) {
// Grab the type definition being referred to by the `ColonRef` -- this is
// what we now have to traverse to (or we may have arrived).
XLS_ASSIGN_OR_RETURN(current_type_definition,
module->GetTypeDefinition(colon_ref->attr()));

if (std::holds_alternative<TypeAlias*>(current_type_definition)) {
TypeAlias* new_alias = std::get<TypeAlias*>(current_type_definition);
XLS_RET_CHECK_EQ(new_alias->owner(), module);
// We need to get the right type info for the enum's containing module. We
// can get the top-level module since [currently?] enums can't be
// parameterized.
type_info = import_data->GetRootTypeInfo(module).value();
return ResolveTypeAliasToDirectColonRefSubject(import_data, type_info,
std::get<TypeAlias*>(td));
new_alias);
}
}

if (!std::holds_alternative<EnumDef*>(td)) {
if (!std::holds_alternative<EnumDef*>(current_type_definition)) {
return absl::InternalError(
"ResolveTypeDefToDirectColonRefSubject() can only be called when the "
"TypeAlias "
"directory or indirectly refers to an EnumDef.");
}

return std::get<EnumDef*>(td);
return std::get<EnumDef*>(current_type_definition);
}

absl::Status TryEnsureFitsInType(const Number& number,
Expand Down Expand Up @@ -321,7 +340,10 @@ static absl::StatusOr<ColonRefSubjectT> ResolveColonRefNameRefSubject(
absl::StatusOr<ColonRefSubjectT> ResolveColonRefSubjectForTypeChecking(
ImportData* import_data, const TypeInfo* type_info,
const ColonRef* colon_ref) {
VLOG(5) << "ResolveColonRefSubject for " << colon_ref->ToString();
XLS_RET_CHECK_EQ(colon_ref->owner(), type_info->module());

VLOG(5) << absl::StreamFormat("ResolveColonRefSubject for `%s`",
colon_ref->ToString());

// If the subject is a name reference we use a helper routine.
if (std::holds_alternative<NameRef*>(colon_ref->subject())) {
Expand Down
6 changes: 3 additions & 3 deletions xls/dslx/type_system/type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,9 @@ void TypeInfo::AddImport(Import* import, Module* module, TypeInfo* type_info) {
}

std::optional<const ImportedInfo*> TypeInfo::GetImported(Import* import) const {
CHECK_EQ(import->owner(), module_)
<< "Import node from: " << import->owner()->name() << " vs TypeInfo for "
<< module_->name();
CHECK_EQ(import->owner(), module_) << absl::StreamFormat(
"Import node is owned by: `%s` vs this TypeInfo is for `%s`",
import->owner()->name(), module_->name());
auto* self = GetRoot();
auto it = self->imports_.find(import);
if (it == self->imports_.end()) {
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/type_system/type_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ class TypeInfo {
// Note that added type information and such will generally be owned by the
// import cache.
void AddImport(Import* import, Module* module, TypeInfo* type_info);

// Returns information on the imported module (its module AST node and
// top-level type information).
std::optional<const ImportedInfo*> GetImported(Import* import) const;
absl::StatusOr<const ImportedInfo*> GetImportedOrError(Import* import) const;
const absl::flat_hash_map<Import*, ImportedInfo>& imports() const {
Expand Down

0 comments on commit 5013b53

Please sign in to comment.