Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Aug 8, 2024
1 parent 9c660ad commit 80e941f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 58 deletions.
99 changes: 53 additions & 46 deletions mongo/integration/mtest/mongotest.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ type WriteConcernErrorData struct {
ErrInfo bson.Raw `bson:"errInfo,omitempty"`
}

type failPoint struct {
name string
client *mongo.Client
}

// T is a wrapper around testing.T.
type T struct {
// connsCheckedOut is the net number of connections checked out during test execution.
Expand All @@ -103,7 +108,7 @@ type T struct {
createdColls []*Collection // collections created in this test
proxyDialer *proxyDialer
dbName, collName string
failPointNames []string
failPoints []failPoint
minServerVersion string
maxServerVersion string
validTopologies []TopologyKind
Expand All @@ -128,15 +133,16 @@ type T struct {
succeeded []*event.CommandSucceededEvent
failed []*event.CommandFailedEvent

Client *mongo.Client
fpClient *mongo.Client
DB *mongo.Database
Coll *mongo.Collection
Client *mongo.Client
fpClients map[*mongo.Client]bool
DB *mongo.Database
Coll *mongo.Collection
}

func newT(wrapped *testing.T, opts ...*Options) *T {
t := &T{
T: wrapped,
T: wrapped,
fpClients: make(map[*mongo.Client]bool),
}
for _, opt := range opts {
for _, optFn := range opt.optFuncs {
Expand Down Expand Up @@ -176,16 +182,10 @@ func New(wrapped *testing.T, opts ...*Options) *T {

t := newT(wrapped, opts...)

fpOpt := t.clientOpts
if fpOpt != nil {
fpOpt.AutoEncryptionOptions = nil
}
t.fpClient = t.createTestClient(fpOpt)

// only create a client if it needs to be shared in sub-tests
// otherwise, a new client will be created for each subtest
if t.shareClient != nil && *t.shareClient {
t.Client = t.createTestClient(t.clientOpts)
t.createTestClient()
}

wrapped.Cleanup(t.cleanup)
Expand All @@ -209,6 +209,12 @@ func (t *T) cleanup() {
// always disconnect the client regardless of clientType because Client.Disconnect will work against
// all deployments
_ = t.Client.Disconnect(context.Background())
for client, v := range t.fpClients {
if v {
client.Disconnect(context.Background())
}
}
t.fpClients = nil
}

// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
Expand All @@ -231,21 +237,14 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
sub.AddMockResponses(sub.mockResponses...)
}

if sub.fpClient == nil {
clientOpts := sub.clientOpts
if clientOpts != nil {
clientOpts.AutoEncryptionOptions = nil
}
sub.fpClient = sub.createTestClient(clientOpts)
}
// for shareClient, inherit the client from the parent
if sub.shareClient != nil && *sub.shareClient && sub.clientType == t.clientType {
sub.Client = t.Client
}
// only create a client if not already set
if sub.Client == nil {
if sub.createClient == nil || *sub.createClient {
sub.Client = sub.createTestClient(sub.clientOpts)
sub.createTestClient()
}
}
// create a collection for this test
Expand All @@ -270,7 +269,9 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
}
// only disconnect client if it's not being shared
if sub.shareClient == nil || !*sub.shareClient {
_ = sub.Client.Disconnect(context.Background())
if _, ok := sub.fpClients[sub.Client]; !ok {
_ = sub.Client.Disconnect(context.Background())
}
}
assert.Equal(sub, 0, sessions, "%v sessions checked out", sessions)
assert.Equal(sub, 0, conns, "%v connections checked out", conns)
Expand Down Expand Up @@ -419,8 +420,10 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
t.clientOpts = opts
}

_ = t.Client.Disconnect(context.Background())
t.Client = t.createTestClient(t.clientOpts)
if _, ok := t.fpClients[t.Client]; !ok {
_ = t.Client.Disconnect(context.Background())
}
t.createTestClient()
t.DB = t.Client.Database(t.dbName)
t.Coll = t.DB.Collection(t.collName, t.collOpts)

Expand Down Expand Up @@ -576,7 +579,8 @@ func (t *T) SetFailPoint(fp FailPoint) {
if err := SetFailPoint(fp, t.Client); err != nil {
t.Fatal(err)
}
t.failPointNames = append(t.failPointNames, fp.ConfigureFailPoint)
t.fpClients[t.Client] = true
t.failPoints = append(t.failPoints, failPoint{fp.ConfigureFailPoint, t.Client})
}

// SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This
Expand All @@ -588,30 +592,35 @@ func (t *T) SetFailPointFromDocument(fp bson.Raw) {
t.Fatal(err)
}

t.fpClients[t.Client] = true
name := fp.Index(0).Value().StringValue()
t.failPointNames = append(t.failPointNames, name)
t.failPoints = append(t.failPoints, failPoint{name, t.Client})
}

// TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes.
// This function does not create a fail point on the server.
func (t *T) TrackFailPoint(fpName string) {
t.failPointNames = append(t.failPointNames, fpName)
func (t *T) TrackFailPoint(fpName string, client *mongo.Client) {
t.fpClients[client] = true
t.failPoints = append(t.failPoints, failPoint{fpName, client})
}

// ClearFailPoints disables all previously set failpoints for this test.
func (t *T) ClearFailPoints() {
db := t.Client.Database("admin")
for _, fp := range t.failPointNames {
for _, fp := range t.failPoints {
cmd := bson.D{
{"configureFailPoint", fp},
{"configureFailPoint", fp.name},
{"mode", "off"},
}
err := db.RunCommand(context.Background(), cmd).Err()
err := fp.client.Database("admin").RunCommand(context.Background(), cmd).Err()
if err != nil {
t.Fatalf("error clearing fail point %s: %v", fp, err)
t.Fatalf("error clearing fail point %s: %v", fp.name, err)
}
if fp.client != t.Client {
_ = fp.client.Disconnect(context.Background())
t.fpClients[fp.client] = false
}
}
t.failPointNames = t.failPointNames[:0]
t.failPoints = t.failPoints[:0]
}

// CloneDatabase modifies the default database for this test to match the given options.
Expand Down Expand Up @@ -639,7 +648,8 @@ func sanitizeCollectionName(db string, coll string) string {
return coll
}

func (t *T) createTestClient(clientOpts *options.ClientOptions) *mongo.Client {
func (t *T) createTestClient() {
clientOpts := t.clientOpts
if clientOpts == nil {
// default opts
clientOpts = options.Client().SetWriteConcern(MajorityWc).SetReadPreference(PrimaryRp)
Expand Down Expand Up @@ -697,20 +707,17 @@ func (t *T) createTestClient(clientOpts *options.ClientOptions) *mongo.Client {
})
}

var client *mongo.Client
var err error
var uriOpts *options.ClientOptions
switch t.clientType {
case Pinned:
// pin to first mongos
pinnedHostList := []string{testContext.connString.Hosts[0]}
uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList)
client, err = mongo.NewClient(uriOpts, clientOpts)
uriOpts = options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList)
case Mock:
// clear pool monitor to avoid configuration error
clientOpts.PoolMonitor = nil
t.mockDeployment = newMockDeployment()
clientOpts.Deployment = t.mockDeployment
client, err = mongo.NewClient(clientOpts)
case Proxy:
t.proxyDialer = newProxyDialer()
clientOpts.SetDialer(t.proxyDialer)
Expand All @@ -720,23 +727,23 @@ func (t *T) createTestClient(clientOpts *options.ClientOptions) *mongo.Client {
case Default:
// Use a different set of options to specify the URI because clientOpts may already have a URI or host seedlist
// specified.
var uriOpts *options.ClientOptions
if clientOpts.Deployment == nil {
// Only specify URI if the deployment is not set to avoid setting topology/server options along with the
// deployment.
uriOpts = options.Client().ApplyURI(testContext.connString.Original)
}

// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
client, err = mongo.NewClient(uriOpts, clientOpts)
}
t.clientOpts = options.MergeClientOptions(uriOpts, clientOpts)

var err error
// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
t.Client, err = mongo.NewClient(t.clientOpts)
if err != nil {
t.Fatalf("error creating client: %v", err)
}
if err := client.Connect(context.Background()); err != nil {
if err := t.Client.Connect(context.Background()); err != nil {
t.Fatalf("error connecting client: %v", err)
}
return client
}

func (t *T) createTestCollection() {
Expand Down
17 changes: 5 additions & 12 deletions mongo/integration/unified_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ const (
gridFSFiles = "fs.files"
gridFSChunks = "fs.chunks"
spec1403SkipReason = "servers less than 4.2 do not have mongocryptd; see SPEC-1403"
godriver2123SkipReason = "failpoints and timeouts together cause failures; see GODRIVER-2123"
godriver2413SkipReason = "encryptedFields argument is not supported on Collection.Drop; see GODRIVER-2413"
)

Expand All @@ -52,15 +51,10 @@ var (
// SPEC-1403: This test checks to see if the correct error is thrown when auto encrypting with a server < 4.2.
// Currently, the test will fail because a server < 4.2 wouldn't have mongocryptd, so Client construction
// would fail with a mongocryptd spawn error.
"operation fails with maxWireVersion < 8": spec1403SkipReason,
// GODRIVER-2123: The two tests below use a failpoint and a socket or server selection timeout.
// The timeout causes the eventual clearing of the failpoint in the test runner to fail with an
// i/o timeout.
"Ignore network timeout error on find": godriver2123SkipReason,
"Network error on minPoolSize background creation": godriver2123SkipReason,
"CreateCollection from encryptedFields.": godriver2413SkipReason,
"DropCollection from encryptedFields": godriver2413SkipReason,
"DropCollection from remote encryptedFields": godriver2413SkipReason,
"operation fails with maxWireVersion < 8": spec1403SkipReason,
"CreateCollection from encryptedFields.": godriver2413SkipReason,
"DropCollection from encryptedFields": godriver2413SkipReason,
"DropCollection from remote encryptedFields": godriver2413SkipReason,
}
)

Expand Down Expand Up @@ -476,12 +470,11 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
if err != nil {
return fmt.Errorf("Connect error for targeted client: %w", err)
}
defer func() { _ = client.Disconnect(context.Background()) }()

if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error setting targeted fail point: %w", err)
}
mt.TrackFailPoint(fp.ConfigureFailPoint)
mt.TrackFailPoint(fp.ConfigureFailPoint, client)
case "configureFailPoint":
fp, err := op.Arguments.LookupErr("failPoint")
if err != nil {
Expand Down

0 comments on commit 80e941f

Please sign in to comment.