From c5d0f889c00c69f2b02874d4d59d73efa73f2ffb Mon Sep 17 00:00:00 2001 From: Niall Fitzpatrick <18366490+Niallfitzy1@users.noreply.github.com> Date: Wed, 6 Sep 2023 20:00:30 +0100 Subject: [PATCH] chore: rely on record id & lookup queries when missing to more reliably update existing records (#3) --- _test/.env_template => .env_template | 0 .gitignore | 3 +- _test/integration_test.go | 134 ---------------- client.go | 147 +++++++++++++++++ go.mod | 2 + go.sum | 2 + integration_test.go | 232 +++++++++++++++++++++++++++ models.go | 68 ++++++++ provider.go | 225 +++++++------------------- types.go | 73 --------- 10 files changed, 514 insertions(+), 372 deletions(-) rename _test/.env_template => .env_template (100%) delete mode 100644 _test/integration_test.go create mode 100644 client.go create mode 100644 integration_test.go create mode 100644 models.go delete mode 100644 types.go diff --git a/_test/.env_template b/.env_template similarity index 100% rename from _test/.env_template rename to .env_template diff --git a/.gitignore b/.gitignore index 793204d..d98b1d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ _gitignore/ .env -.vscode/ \ No newline at end of file +.vscode/ +.idea diff --git a/_test/integration_test.go b/_test/integration_test.go deleted file mode 100644 index c3ba935..0000000 --- a/_test/integration_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log" - "os" - "time" - - "github.com/joho/godotenv" - "github.com/libdns/libdns" - "github.com/libdns/porkbun" -) - -func main() { - envErr := godotenv.Load() - if envErr != nil { - log.Fatal("Error loading .env file", envErr) - } - - apikey := os.Getenv("PORKBUN_API_KEY") - secretapikey := os.Getenv("PORKBUN_SECRET_API_KEY") - zone := os.Getenv("ZONE") - - if apikey == "" || secretapikey == "" || zone == "" { - fmt.Println("All variables must be set in '.env' file") - return - } - - provider := porkbun.Provider{ - APIKey: apikey, - APISecretKey: secretapikey, - } - - //Check Authorization - _, err := provider.CheckCredentials(context.TODO()) - - if err != nil { - log.Fatalf("Credential check failed: %s\n", err.Error()) - } - - //Get records - initialRecords, err := provider.GetRecords(context.TODO(), zone) - if err != nil { - log.Fatalf("Failed to fetch records: %s\n", err.Error()) - } - - log.Println("Records fetched:") - for _, record := range initialRecords { - fmt.Printf("%s (.%s): %s, %s\n", record.Name, zone, record.Value, record.Type) - } - - testValue := "test-value" - updatedTestValue := "updated-test-value" - ttl := time.Duration(600 * time.Second) - recordType := "TXT" - testFullName := "libdns_test_record." + zone - - //Create record - appendedRecords, err := provider.AppendRecords(context.TODO(), zone, []libdns.Record{ - { - Type: recordType, - Name: testFullName, - TTL: ttl, - Value: testValue, - }, - }) - - if err != nil { - log.Fatalf("ERROR: %s\n", err.Error()) - } - - //Get records - postCreatedRecords, err := provider.GetRecords(context.TODO(), zone) - if err != nil { - log.Fatalf("Failed to fetch records: %s\n", err.Error()) - } - - if len(postCreatedRecords) != len(initialRecords)+1 { - log.Fatalln("Additional record not created") - } - - fmt.Printf("Created record: \n%v\n", appendedRecords[0]) - - // Update record - updatedRecords, err := provider.SetRecords(context.TODO(), zone, []libdns.Record{ - { - Type: recordType, - Name: testFullName, - TTL: ttl, - Value: updatedTestValue, - }, - }) - - if err != nil { - log.Fatalf("ERROR: %s\n", err.Error()) - } - fmt.Printf("Updated record: \n%v\n", updatedRecords[0]) - - //Get records - updatedRecords, err = provider.GetRecords(context.TODO(), zone) - if err != nil { - log.Fatalf("Failed to fetch records: %s\n", err.Error()) - } - - if len(updatedRecords) != len(initialRecords)+1 { - log.Fatalln("Additional record created instead of updating existing") - } - - // Delete record - deleteRecords, err := provider.DeleteRecords(context.TODO(), zone, []libdns.Record{ - { - Type: recordType, - Name: testFullName, - }, - }) - - if err != nil { - log.Fatalln("ERROR: %s\n", err.Error()) - } - - //Get records - updatedRecords, err = provider.GetRecords(context.TODO(), zone) - if err != nil { - log.Fatalf("Failed to fetch records: %s\n", err.Error()) - } - - if len(updatedRecords) != len(initialRecords) { - log.Fatalln("Additional record not cleaned up") - } - - fmt.Printf("Deleted record: \n%v\n", deleteRecords[0]) - -} diff --git a/client.go b/client.go new file mode 100644 index 0000000..939dd05 --- /dev/null +++ b/client.go @@ -0,0 +1,147 @@ +package porkbun + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/libdns/libdns" + "io" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const ApiBase = "https://porkbun.com/api/json/v3" + +// LibdnsZoneToPorkbunDomain Strips the trailing dot from a Zone +func LibdnsZoneToPorkbunDomain(zone string) string { + return strings.TrimSuffix(zone, ".") +} + +// CheckCredentials allows verifying credentials work in test scripts +func (p *Provider) CheckCredentials(_ context.Context) (string, error) { + credentialJson, err := json.Marshal(p.getCredentials()) + if err != nil { + return "", err + } + + response, err := MakeApiRequest("/ping", bytes.NewReader(credentialJson), pkbnPingResponse{}) + + if err != nil { + return "", err + } + + if response.Status != "SUCCESS" { + return "", err + } + + return response.YourIP, nil +} + +func (p *Provider) getCredentials() ApiCredentials { + return ApiCredentials{p.APIKey, p.APISecretKey} +} + +func (p *Provider) getMatchingRecord(r libdns.Record, zone string) ([]libdns.Record, error) { + var recs []libdns.Record + trimmedZone := LibdnsZoneToPorkbunDomain(zone) + + credentialJson, err := json.Marshal(p.getCredentials()) + if err != nil { + return recs, err + } + endpoint := fmt.Sprintf("/dns/retrieveByNameType/%s/%s/%s", trimmedZone, r.Type, r.Name) + response, err := MakeApiRequest(endpoint, bytes.NewReader(credentialJson), pkbnRecordsResponse{}) + + if err != nil { + return recs, err + } + + recs = make([]libdns.Record, 0, len(response.Records)) + for _, rec := range response.Records { + recs = append(recs, rec.toLibdnsRecord(zone)) + } + return recs, nil +} + +// UpdateRecords adds records to the zone. It returns the records that were added. +func (p *Provider) updateRecords(_ context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + credentials := p.getCredentials() + trimmedZone := LibdnsZoneToPorkbunDomain(zone) + + var createdRecords []libdns.Record + + for _, record := range records { + if record.TTL/time.Second < 600 { + record.TTL = 600 * time.Second + } + ttlInSeconds := int(record.TTL / time.Second) + trimmedName := libdns.RelativeName(record.Name, zone) + + reqBody := pkbnRecordPayload{&credentials, record.Value, trimmedName, strconv.Itoa(ttlInSeconds), record.Type} + reqJson, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + response, err := MakeApiRequest(fmt.Sprintf("/dns/edit/%s/%s", trimmedZone, record.ID), bytes.NewReader(reqJson), pkbnResponseStatus{}) + if err != nil { + return nil, err + } + + if response.Status != "SUCCESS" { + return nil, err + } + createdRecords = append(createdRecords, record) + } + + return createdRecords, nil +} + +func MakeApiRequest[T any](endpoint string, body io.Reader, responseType T) (T, error) { + client := http.Client{} + + fullUrl := ApiBase + endpoint + u, err := url.Parse(fullUrl) + if err != nil { + return responseType, err + } + + req, err := http.NewRequest("POST", u.String(), body) + if err != nil { + return responseType, err + } + resp, err := client.Do(req) + if err != nil { + return responseType, err + } + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + log.Fatal("Couldn't close body") + } + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + err = errors.New("Invalid http response status, " + string(bodyBytes)) + return responseType, err + } + + result, err := io.ReadAll(resp.Body) + if err != nil { + return responseType, err + } + + err = json.Unmarshal(result, &responseType) + + if err != nil { + return responseType, err + } + + return responseType, nil +} diff --git a/go.mod b/go.mod index e04ba12..69b3038 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module github.com/libdns/porkbun go 1.20 require github.com/libdns/libdns v0.2.1 + +require github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index ba9d0cf..7f1e021 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/libdns/libdns v0.2.1 h1:Wu59T7wSHRgtA0cfxC+n1c/e+O3upJGWytknkmFEDis= github.com/libdns/libdns v0.2.1/go.mod h1:yQCXzk1lEZmmCPa857bnk4TsOiqYasqpyOEeSObbb40= diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..db53093 --- /dev/null +++ b/integration_test.go @@ -0,0 +1,232 @@ +package porkbun + +import ( + "context" + "github.com/joho/godotenv" + "github.com/libdns/libdns" + "log" + "os" + "testing" + "time" +) + +var records []libdns.Record +var testRecord libdns.Record + +func getInitialRecords(t *testing.T, provider Provider, zone string) []libdns.Record { + if len(records) == 0 { + fetchedRecords, err := provider.GetRecords(context.TODO(), zone) + if err != nil { + t.Error(err) + } + records = fetchedRecords + } + + return records +} + +func createOrGetTestRecord(t *testing.T, provider Provider, zone string) libdns.Record { + if testRecord.ID == "" { + testValue := "test-value" + ttl := time.Duration(600 * time.Second) + recordType := "TXT" + testFullName := "libdns_test_record" + + //Create record + appendedRecords, err := provider.AppendRecords(context.TODO(), zone, []libdns.Record{ + { + Type: recordType, + Name: testFullName, + TTL: ttl, + Value: testValue, + }, + }) + + if err != nil { + t.Error(err) + } + + if len(appendedRecords) != 1 { + t.Errorf("Incorrect amount of records %d created", len(appendedRecords)) + } + + testRecord = appendedRecords[0] + } + + return testRecord +} + +func getProvider(t *testing.T) (Provider, string) { + envErr := godotenv.Load() + if envErr != nil { + t.Error(envErr) + } + + apikey := os.Getenv("PORKBUN_API_KEY") + secretapikey := os.Getenv("PORKBUN_SECRET_API_KEY") + zone := os.Getenv("ZONE") + + if apikey == "" || secretapikey == "" || zone == "" { + t.Errorf("All variables must be set in '.env' file") + } + + provider := Provider{ + APIKey: apikey, + APISecretKey: secretapikey, + } + return provider, zone +} + +func TestProvider_CheckCredentials(t *testing.T) { + provider, _ := getProvider(t) + + //Check Authorization + _, err := provider.CheckCredentials(context.TODO()) + + if err != nil { + t.Error(err) + } +} + +func TestProvider_GetRecords(t *testing.T) { + provider, zone := getProvider(t) + + //Get records + initialRecords := getInitialRecords(t, provider, zone) + + log.Println("Records fetched:") + for _, record := range initialRecords { + t.Logf("%s %s (.%s): %s, %s\n", record.ID, record.Name, zone, record.Value, record.Type) + } +} + +func TestProvider_AppendRecords(t *testing.T) { + provider, zone := getProvider(t) + + //Get records + initialRecords := getInitialRecords(t, provider, zone) + + createdRecord := createOrGetTestRecord(t, provider, zone) + //Get records + postCreatedRecords, err := provider.GetRecords(context.TODO(), zone) + if err != nil { + t.Error(err) + } + + if len(postCreatedRecords) != len(initialRecords)+1 { + t.Errorf("Additional record not created") + } + + t.Logf("Created record: \n%v\n", createdRecord.ID) +} + +func TestProvider_UpdateRecordsById(t *testing.T) { + provider, zone := getProvider(t) + + //Get records + initialRecords := getInitialRecords(t, provider, zone) + + ttl := time.Duration(600 * time.Second) + recordType := "TXT" + testFullName := "libdns_test_record" + + //Create record + createdRecord := createOrGetTestRecord(t, provider, zone) + + updatedTestValue := "updated-test-value" + // Update record + updatedRecords, err := provider.SetRecords(context.TODO(), zone, []libdns.Record{ + { + ID: createdRecord.ID, + Type: recordType, + Name: testFullName, + TTL: ttl, + Value: updatedTestValue, + }, + }) + + if err != nil { + t.Error(err) + } + + if len(updatedRecords) != 1 { + t.Logf("Incorrect amount of records changed") + } + + t.Logf("Updated record: \n%v\n", updatedRecords[0]) + + //Get records + postUpdatedRecords, err := provider.GetRecords(context.TODO(), zone) + if err != nil { + t.Error(err) + } + + if len(postUpdatedRecords) != len(initialRecords)+1 { + t.Errorf("Additional record created instead of updating existing. Started with: %d, now has: %d", len(initialRecords), len(postUpdatedRecords)) + } +} + +func TestProvider_UpdateRecordsByLookup(t *testing.T) { + provider, zone := getProvider(t) + + //Get records + initialRecords := getInitialRecords(t, provider, zone) + + ttl := time.Duration(600 * time.Second) + recordType := "TXT" + testFullName := "libdns_test_record" + + //Create record + _ = createOrGetTestRecord(t, provider, zone) + + updatedTestValue := "updated-test-value-by-lookup" + // Update record + updatedRecords, err := provider.SetRecords(context.TODO(), zone, []libdns.Record{ + { + Type: recordType, + Name: testFullName, + TTL: ttl, + Value: updatedTestValue, + }, + }) + + if err != nil { + t.Error(err) + } + + if len(updatedRecords) != 1 { + t.Logf("Incorrect amount of records changed") + } + + t.Logf("Updated record: \n%v\n", updatedRecords[0]) + + //Get records + postUpdatedRecords, err := provider.GetRecords(context.TODO(), zone) + if err != nil { + t.Error(err) + } + + if len(postUpdatedRecords) != len(initialRecords)+1 { + t.Errorf("Additional record created instead of updating existing. Started with: %d, now has: %d", len(initialRecords), len(postUpdatedRecords)) + } +} + +func TestProvider_DeleteRecords(t *testing.T) { + provider, zone := getProvider(t) + + //Create record + createdRecord := createOrGetTestRecord(t, provider, zone) + + // Delete record + deleteRecords, err := provider.DeleteRecords(context.TODO(), zone, []libdns.Record{createdRecord}) + + if err != nil { + t.Error(err) + } + + if len(deleteRecords) != 1 { + t.Errorf("Deleted incorrect amount of records %d", len(deleteRecords)) + } + + t.Logf("Deleted record: \n%v\n", deleteRecords[0]) +} diff --git a/models.go b/models.go new file mode 100644 index 0000000..4e400bc --- /dev/null +++ b/models.go @@ -0,0 +1,68 @@ +package porkbun + +import ( + "fmt" + "github.com/libdns/libdns" + "strconv" + "time" +) + +type pkbnRecord struct { + Content string `json:"content"` + ID string `json:"id"` + Name string `json:"name"` + Notes string `json:"notes"` + Prio string `json:"prio"` + TTL string `json:"ttl"` + Type string `json:"type"` +} + +type pkbnRecordsResponse struct { + Records []pkbnRecord `json:"records"` + Status string `json:"status"` +} + +type ApiCredentials struct { + Apikey string `json:"apikey"` + Secretapikey string `json:"secretapikey"` +} + +type pkbnResponseStatus struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` +} +type pkbnPingResponse struct { + pkbnResponseStatus + YourIP string `json:"yourIp"` +} + +type pkbnCreateResponse struct { + pkbnResponseStatus + // TODO contact support endpoint isn't returning the ID despite it being in their docs. + // ID string `json:"id"` +} + +func (record pkbnRecord) toLibdnsRecord(zone string) libdns.Record { + ttl, _ := time.ParseDuration(record.TTL + "s") + priority, _ := strconv.Atoi(record.Prio) + return libdns.Record{ + ID: record.ID, + Name: libdns.RelativeName(record.Name, LibdnsZoneToPorkbunDomain(zone)), + Priority: priority, + TTL: ttl, + Type: record.Type, + Value: record.Content, + } +} + +func (a pkbnResponseStatus) Error() string { + return fmt.Sprintf("%s: %s", a.Status, a.Message) +} + +type pkbnRecordPayload struct { + *ApiCredentials + Content string `json:"content"` + Name string `json:"name"` + TTL string `json:"ttl"` + Type string `json:"type"` +} diff --git a/provider.go b/provider.go index 8098e9d..65feb15 100644 --- a/provider.go +++ b/provider.go @@ -8,12 +8,7 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" - "net/http" - "net/url" "strconv" - "strings" "time" "github.com/libdns/libdns" @@ -25,51 +20,15 @@ type Provider struct { APISecretKey string `json:"api_secret_key,omitempty"` } -func (p *Provider) getApiHost() string { - return "https://porkbun.com/api/json/v3/" -} - -func (p *Provider) getRecordCoordinates(record libdns.Record) string { - return fmt.Sprintf("%s-%s", record.Name, record.Type) -} - -func (p *Provider) getCredentials() ApiCredentials { - return ApiCredentials{p.APIKey, p.APISecretKey} -} - -// Strips the trailing dot from a Zone -func trimZone(zone string) string { - return strings.TrimSuffix(zone, ".") -} - -func (p *Provider) CheckCredentials(_ context.Context) (string, error) { - credentialJson, err := json.Marshal(p.getCredentials()) - if err != nil { - return "", err - } - - response, err := makeHttpRequest[PingResponse](p, "ping", bytes.NewReader(credentialJson), PingResponse{}) - - if err != nil { - return "", err - } - - if response.Status != "SUCCESS" { - return "", err - } - - return response.YourIP, nil -} - // GetRecords lists all the records in the zone. func (p *Provider) GetRecords(_ context.Context, zone string) ([]libdns.Record, error) { - trimmedZone := trimZone(zone) + trimmedZone := LibdnsZoneToPorkbunDomain(zone) credentialJson, err := json.Marshal(p.getCredentials()) if err != nil { return nil, err } - response, err := makeHttpRequest[ApiRecordsResponse](p, "dns/retrieve/"+trimmedZone, bytes.NewReader(credentialJson), ApiRecordsResponse{}) + response, err := MakeApiRequest("/dns/retrieve/"+trimmedZone, bytes.NewReader(credentialJson), pkbnRecordsResponse{}) if err != nil { return nil, err @@ -79,31 +38,17 @@ func (p *Provider) GetRecords(_ context.Context, zone string) ([]libdns.Record, return nil, errors.New(fmt.Sprintf("Invalid response status %s", response.Status)) } - var records []libdns.Record - for _, record := range response.Records { - ttl, err := time.ParseDuration(record.TTL + "s") - if err != nil { - return nil, err - } - priority, _ := strconv.Atoi(record.Prio) - formatted := libdns.Record{ - ID: record.ID, - Name: record.Name + ".", - Priority: priority, - TTL: ttl, - Type: record.Type, - Value: record.Content, - } - records = append(records, formatted) + recs := make([]libdns.Record, 0, len(response.Records)) + for _, rec := range response.Records { + recs = append(recs, rec.toLibdnsRecord(zone)) } - - return records, nil + return recs, nil } // AppendRecords adds records to the zone. It returns the records that were added. func (p *Provider) AppendRecords(_ context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { credentials := p.getCredentials() - trimmedZone := trimZone(zone) + trimmedZone := LibdnsZoneToPorkbunDomain(zone) var createdRecords []libdns.Record @@ -114,53 +59,26 @@ func (p *Provider) AppendRecords(_ context.Context, zone string, records []libdn ttlInSeconds := int(record.TTL / time.Second) trimmedName := libdns.RelativeName(record.Name, zone) - reqBody := RecordCreateRequest{&credentials, record.Value, trimmedName, strconv.Itoa(ttlInSeconds), record.Type} + reqBody := pkbnRecordPayload{&credentials, record.Value, trimmedName, strconv.Itoa(ttlInSeconds), record.Type} reqJson, err := json.Marshal(reqBody) if err != nil { - return nil, err + return createdRecords, err } - response, err := makeHttpRequest(p, fmt.Sprintf("dns/create/%s", trimmedZone), bytes.NewReader(reqJson), ResponseStatus{}) + response, err := MakeApiRequest(fmt.Sprintf("/dns/create/%s", trimmedZone), bytes.NewReader(reqJson), pkbnCreateResponse{}) if err != nil { - print(err) - return nil, err + return createdRecords, err } if response.Status != "SUCCESS" { - return nil, errors.New(fmt.Sprintf("Invalid response status %s", response.Status)) - } - createdRecords = append(createdRecords, record) - } - - return createdRecords, nil -} - -// UpdateRecords adds records to the zone. It returns the records that were added. -func (p *Provider) UpdateRecords(_ context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - credentials := p.getCredentials() - trimmedZone := trimZone(zone) - - var createdRecords []libdns.Record - - for _, record := range records { - if record.TTL/time.Second < 600 { - record.TTL = 600 * time.Second - } - ttlInSeconds := int(record.TTL / time.Second) - trimmedName := libdns.RelativeName(record.Name, zone) - reqBody := RecordUpdateRequest{&credentials, record.Value, strconv.Itoa(ttlInSeconds)} - reqJson, err := json.Marshal(reqBody) - if err != nil { - return nil, err - } - response, err := makeHttpRequest(p, fmt.Sprintf("dns/editByNameType/%s/%s/%s", trimmedZone, record.Type, trimmedName), bytes.NewReader(reqJson), ResponseStatus{}) - if err != nil { - return nil, err + return createdRecords, errors.New(fmt.Sprintf("Invalid response status %s", response.Status)) } - if response.Status != "SUCCESS" { - return nil, err + // TODO contact support endpoint isn't returning the ID despite it being in their docs. Fetch as a workaround + created, err := p.getMatchingRecord(record, zone) + if err == nil && len(created) == 1 { + record.ID = created[0].ID } createdRecords = append(createdRecords, record) } @@ -171,105 +89,84 @@ func (p *Provider) UpdateRecords(_ context.Context, zone string, records []libdn // SetRecords sets the records in the zone, either by updating existing records or creating new ones. // It returns the updated records. func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - existingRecords, err := p.GetRecords(ctx, zone) - if err != nil { - return nil, err - } - - existingCoordinates := NewSet() - for _, r := range existingRecords { - existingCoordinates.Add(p.getRecordCoordinates(r)) - } - var updates []libdns.Record var creates []libdns.Record + var results []libdns.Record for _, r := range records { - if existingCoordinates.Contains(p.getRecordCoordinates(r)) { + if r.ID == "" { + // Try fetch record in case we are just missing the ID + matches, err := p.getMatchingRecord(r, zone) + if err != nil { + return nil, err + } + + if len(matches) == 0 { + creates = append(creates, r) + continue + } + + if len(matches) > 1 { + return nil, fmt.Errorf("unexpectedly found more than 1 record for %v", r) + } + + r.ID = matches[0].ID updates = append(updates, r) } else { - creates = append(creates, r) + updates = append(updates, r) } } - _, err = p.AppendRecords(ctx, zone, creates) + created, err := p.AppendRecords(ctx, zone, creates) if err != nil { return nil, err } - _, err = p.UpdateRecords(ctx, zone, updates) + updated, err := p.updateRecords(ctx, zone, updates) if err != nil { return nil, err } - return records, nil + results = append(results, created...) + results = append(results, updated...) + return results, nil } // DeleteRecords deletes the records from the zone. It returns the records that were deleted. func (p *Provider) DeleteRecords(_ context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { credentials := p.getCredentials() - trimmedZone := trimZone(zone) + trimmedZone := LibdnsZoneToPorkbunDomain(zone) var deletedRecords []libdns.Record for _, record := range records { - reqJson, err := json.Marshal(credentials) - if err != nil { - return nil, err + var queuedDeletes []libdns.Record + if record.ID == "" { + // Try fetch record in case we are just missing the ID + matches, err := p.getMatchingRecord(record, zone) + if err != nil { + return deletedRecords, err + } + for _, rec := range matches { + queuedDeletes = append(queuedDeletes, rec) + } + } else { + queuedDeletes = append(queuedDeletes, record) } - trimmedName := libdns.RelativeName(record.Name, zone) - _, err = makeHttpRequest(p, fmt.Sprintf("dns/deleteByNameType/%s/%s/%s", trimmedZone, record.Type, trimmedName), bytes.NewReader(reqJson), ResponseStatus{}) + reqJson, err := json.Marshal(credentials) if err != nil { return nil, err } - deletedRecords = append(deletedRecords, record) - } - - return deletedRecords, nil -} - -func makeHttpRequest[T any](p *Provider, endpoint string, body io.Reader, responseType T) (T, error) { - client := http.Client{} - fullUrl := p.getApiHost() + endpoint - u, err := url.Parse(fullUrl) - if err != nil { - return responseType, err - } - println(u.String()) - - req, err := http.NewRequest("POST", u.String(), body) - if err != nil { - return responseType, err - } - resp, err := client.Do(req) - if err != nil { - return responseType, err - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - log.Fatal("Couldn't close body") + for _, recordToDelete := range queuedDeletes { + _, err = MakeApiRequest(fmt.Sprintf("/dns/delete/%s/%s", trimmedZone, recordToDelete.ID), bytes.NewReader(reqJson), pkbnResponseStatus{}) + if err != nil { + return deletedRecords, err + } + deletedRecords = append(deletedRecords, recordToDelete) } - }(resp.Body) - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - err = errors.New("Invalid http response status, " + string(bodyBytes)) - return responseType, err - } - - result, err := io.ReadAll(resp.Body) - if err != nil { - return responseType, err } - err = json.Unmarshal(result, &responseType) - - if err != nil { - return responseType, err - } - - return responseType, nil + return deletedRecords, nil } // Interface guards diff --git a/types.go b/types.go deleted file mode 100644 index 8d5ce59..0000000 --- a/types.go +++ /dev/null @@ -1,73 +0,0 @@ -package porkbun - -import "fmt" - -type PorkbunRecord struct { - Content string `json:"content"` - ID string `json:"id"` - Name string `json:"name"` - Notes string `json:"notes"` - Prio string `json:"prio"` - TTL string `json:"ttl"` - Type string `json:"type"` -} - -type ApiRecordsResponse struct { - Records []PorkbunRecord `json:"records"` - Status string `json:"status"` -} - -type ApiCredentials struct { - Apikey string `json:"apikey"` - Secretapikey string `json:"secretapikey"` -} - -type ResponseStatus struct { - Status string `json:"status"` - Message string `json:"message,omitempty"` -} -type PingResponse struct { - ResponseStatus - YourIP string `json:"yourIp"` -} - -func (a ResponseStatus) Error() string { - return fmt.Sprintf("%s: %s", a.Status, a.Message) -} - -type RecordCreateRequest struct { - *ApiCredentials - Content string `json:"content"` - Name string `json:"name"` - TTL string `json:"ttl"` - Type string `json:"type"` -} - -type RecordUpdateRequest struct { - *ApiCredentials - Content string `json:"content"` - TTL string `json:"ttl"` -} - -type Set struct { - m map[string]bool -} - -func NewSet() *Set { - s := &Set{} - s.m = make(map[string]bool) - return s -} - -func (s *Set) Add(value string) { - s.m[value] = true -} - -func (s *Set) Remove(value string) { - delete(s.m, value) -} - -func (s *Set) Contains(value string) bool { - _, c := s.m[value] - return c -}