diff --git a/lib/factory_bot/definition.rb b/lib/factory_bot/definition.rb index 0df62d6d..668e637b 100644 --- a/lib/factory_bot/definition.rb +++ b/lib/factory_bot/definition.rb @@ -2,6 +2,7 @@ module FactoryBot # @api private class Definition attr_reader :defined_traits, :declarations, :name, :registered_enums + attr_accessor :klass def initialize(name, base_traits = []) @name = name @@ -52,6 +53,7 @@ def compile(klass = nil) declarations.attributes defined_traits.each do |defined_trait| + defined_trait.klass ||= klass base_traits.each { |bt| bt.define_trait defined_trait } additional_traits.each { |at| at.define_trait defined_trait } end @@ -62,7 +64,7 @@ def compile(klass = nil) name: name, attributes: declarations.attributes, traits: defined_traits, - class: klass + class: klass || self.klass } end end diff --git a/lib/factory_bot/trait.rb b/lib/factory_bot/trait.rb index 710783b3..d45a99c8 100644 --- a/lib/factory_bot/trait.rb +++ b/lib/factory_bot/trait.rb @@ -15,7 +15,7 @@ def initialize(name, &block) end delegate :add_callback, :declare_attribute, :to_create, :define_trait, :constructor, - :callbacks, :attributes, to: :@definition + :callbacks, :attributes, :klass, :klass=, to: :@definition def names [@name] diff --git a/spec/acceptance/activesupport_instrumentation_spec.rb b/spec/acceptance/activesupport_instrumentation_spec.rb index 5bc13edf..39e5dbed 100644 --- a/spec/acceptance/activesupport_instrumentation_spec.rb +++ b/spec/acceptance/activesupport_instrumentation_spec.rb @@ -111,12 +111,17 @@ def subscribed(callback, *args) callback = ->(_name, _start, _finish, _id, payload) { tracked_payloads << payload } ActiveSupport::Notifications.subscribed(callback, "factory_bot.compile_factory") do - FactoryBot.build(:user) + FactoryBot.build(:user, :special) end - payload = tracked_payloads.detect { |payload| payload[:name] == :user } - expect(payload[:class]).to eq(User) - expect(payload[:attributes].map(&:name)).to eq([:email, :name]) - expect(payload[:traits].map(&:name)).to eq(["special"]) + user_payload = tracked_payloads.detect { |payload| payload[:name] == :user } + expect(user_payload[:class]).to eq(User) + expect(user_payload[:attributes].map(&:name)).to eq([:email, :name]) + expect(user_payload[:traits].map(&:name)).to eq(["special"]) + + special_payload = tracked_payloads.detect { |payload| payload[:name] == "special" } + expect(special_payload[:class]).to eq(User) + expect(special_payload[:attributes].map(&:name)).to eq([:name]) + expect(special_payload[:traits].map(&:name)).to eq(["special"]) end end