diff --git a/client/client.go b/client/client.go index 2bf582d81..9118cc021 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "path/filepath" "regexp" "strconv" "strings" @@ -735,6 +736,7 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig) ) } }, + BanditDir: filepath.Join(configDir, "bandit"), }) return dialers, dialer, nil } diff --git a/dialer/bandit.go b/dialer/bandit.go index 5f826c963..2381bbec8 100644 --- a/dialer/bandit.go +++ b/dialer/bandit.go @@ -2,8 +2,15 @@ package dialer import ( "context" + "encoding/csv" + "fmt" + "io" "math/rand" "net" + "os" + "path/filepath" + "strconv" + "sync" "sync/atomic" "time" @@ -12,9 +19,16 @@ import ( // BanditDialer is responsible for continually choosing the optimized dialer. type BanditDialer struct { - dialers []ProxyDialer - bandit *bandit.EpsilonGreedy - opts *Options + dialers []ProxyDialer + bandit bandit.Bandit + opts *Options + banditRewardsMutex *sync.Mutex +} + +type banditMetrics struct { + Reward float64 + Count int + UpdatedAt int64 } // NewBandit creates a new bandit given the available dialers and options with @@ -29,21 +43,48 @@ func NewBandit(opts *Options) (Dialer, error) { dialers := opts.Dialers log.Debugf("Creating bandit with %d dialers", len(dialers)) - b, err := bandit.NewEpsilonGreedy(0.1, nil, nil) + + var b bandit.Bandit + var err error + dialer := &BanditDialer{ + dialers: dialers, + opts: opts, + banditRewardsMutex: &sync.Mutex{}, + } + + dialerWeights, err := dialer.loadLastBanditRewards() + if err != nil { + log.Errorf("unable to load bandit weights: %v", err) + } + if dialerWeights != nil { + log.Debugf("Loading bandit weights from %q", opts.BanditDir) + counts := make([]int, len(dialers)) + rewards := make([]float64, len(dialers)) + for arm, dialer := range dialers { + if metrics, ok := dialerWeights[dialer.Name()]; ok { + rewards[arm] = metrics.Reward + counts[arm] = metrics.Count + } + } + b, err = bandit.NewEpsilonGreedy(0.1, counts, rewards) + if err != nil { + log.Errorf("unable to create weighted bandit: %w", err) + return nil, err + } + dialer.bandit = b + return dialer, nil + } + + b, err = bandit.NewEpsilonGreedy(0.1, nil, nil) if err != nil { log.Errorf("unable to create bandit: %v", err) return nil, err } - if err := b.Init(len(dialers)); err != nil { log.Errorf("unable to initialize bandit: %v", err) return nil, err } - dialer := &BanditDialer{ - dialers: dialers, - bandit: b, - opts: opts, - } + dialer.bandit = b return dialer, nil } @@ -70,8 +111,8 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( if !failedUpstream { log.Errorf("Dialer %v failed in %v seconds: %v", d.Name(), time.Since(start).Seconds(), err) - if err := bd.bandit.Update(chosenArm, 0); err != nil { - log.Errorf("unable to update bandit: %v", err) + if errUpdatingBanditReward := bd.bandit.Update(chosenArm, 0); errUpdatingBanditReward != nil { + log.Errorf("unable to update bandit: %v", errUpdatingBanditReward) } } else { log.Debugf("Dialer %v failed upstream...", d.Name()) @@ -79,8 +120,8 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( // if the DNS resolves to localhost, for example. It is also possible // that the proxy is blacklisted by upstream sites for some reason, // so we have to choose some reasonable value. - if err := bd.bandit.Update(chosenArm, 0.00005); err != nil { - log.Errorf("unable to update bandit: %v", err) + if errUpdatingBanditReward := bd.bandit.Update(chosenArm, 0.00005); errUpdatingBanditReward != nil { + log.Errorf("unable to update bandit: %v", errUpdatingBanditReward) } } return nil, err @@ -97,15 +138,152 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( time.AfterFunc(secondsForSample*time.Second, func() { speed := normalizeReceiveSpeed(dataRecv.Load()) //log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed) - if err := bd.bandit.Update(chosenArm, speed); err != nil { + if err = bd.bandit.Update(chosenArm, speed); err != nil { log.Errorf("unable to update bandit: %v", err) } }) + time.AfterFunc(30*time.Second, func() { + log.Debugf("saving bandit rewards") + metrics := make(map[string]banditMetrics) + rewards := bd.bandit.GetRewards() + counts := bd.bandit.GetCounts() + for i, d := range bd.dialers { + metrics[d.Name()] = banditMetrics{ + Reward: rewards[i], + Count: counts[i], + UpdatedAt: time.Now().UTC().Unix(), + } + } + + err = bd.updateBanditRewards(metrics) + if err != nil { + log.Errorf("unable to save bandit weights: %v", err) + } + }) + bd.opts.OnSuccess(d) return dt, err } +const ( + dialerNameCSVHeader = iota + rewardCSVHeader + countCSVHeader + updatedAtCSVHeader + + unusedBanditDialerIgnoredAfter = 7 * 24 * time.Hour +) + +// loadLastBanditRewards is a function that returns the last bandit rewards +// for each dialer. If this is set, the bandit will be initialized with the +// last metrics. +func (o *BanditDialer) loadLastBanditRewards() (map[string]banditMetrics, error) { + o.banditRewardsMutex.Lock() + defer o.banditRewardsMutex.Unlock() + if o.opts.BanditDir == "" { + return nil, log.Error("bandit directory is not set") + } + + file := filepath.Join(o.opts.BanditDir, "rewards.csv") + data, err := os.Open(file) + if err != nil { + return nil, err + } + + reader := csv.NewReader(data) + // Skip the header, but read it so the csv reader know the expected number of columns + _, err = reader.Read() + if err != nil { + return nil, log.Errorf("unable to skip headers from bandit rewards csv: %w", err) + } + metrics := make(map[string]banditMetrics) + for { + line, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, log.Errorf("unable to read line from bandit rewards csv: %w", err) + } + + // load updatedAt unix time and check if it's older than 7 days + updatedAt, err := strconv.ParseInt(line[updatedAtCSVHeader], 10, 64) + if err != nil { + return nil, log.Errorf("unable to parse updated at from %s: %w", line[0], err) + } + if time.Since(time.Unix(updatedAt, 0)) > unusedBanditDialerIgnoredAfter { + log.Debugf("Ignoring bandit dialer %s as it's older than 7 days", line[0]) + continue + } + reward, err := strconv.ParseFloat(line[rewardCSVHeader], 64) + if err != nil { + return nil, log.Errorf("unable to parse reward from %s: %w", line[0], err) + } + count, err := strconv.Atoi(line[countCSVHeader]) + if err != nil { + return nil, log.Errorf("unable to parse count from %s: %w", line[0], err) + } + + metrics[line[dialerNameCSVHeader]] = banditMetrics{ + Reward: reward, + Count: count, + UpdatedAt: updatedAt, + } + } + return metrics, nil +} + +func (o *BanditDialer) updateBanditRewards(newRewards map[string]banditMetrics) error { + if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil { + return log.Errorf("unable to create bandit directory: %v", err) + } + + previousRewards, err := o.loadLastBanditRewards() + if err != nil && !os.IsNotExist(err) { + return log.Errorf("couldn't load previous bandit rewards: %w", err) + } + o.banditRewardsMutex.Lock() + defer o.banditRewardsMutex.Unlock() + + // if there's previous rewards, we must overwrite current values + if previousRewards != nil { + for dialer, metrics := range newRewards { + previousRewards[dialer] = metrics + } + } else { + previousRewards = newRewards + } + + if o.opts.BanditDir == "" { + return log.Error("bandit directory is not set") + } + + file := filepath.Join(o.opts.BanditDir, "rewards.csv") + + headers := []string{"dialer", "reward", "count", "updated at"} + f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return log.Errorf("unable to open bandit rewards file: %v", err) + } + defer f.Close() + + w := csv.NewWriter(f) + defer w.Flush() + + if err = w.Write(headers); err != nil { + return log.Errorf("unable to write headers to bandit rewards file: %v", err) + } + + for dialerName, metric := range previousRewards { + if err = w.Write([]string{dialerName, fmt.Sprintf("%f", metric.Reward), fmt.Sprintf("%d", metric.Count), fmt.Sprintf("%d", metric.UpdatedAt)}); err != nil { + return log.Errorf("unable to write bandit rewards to file: %v", err) + } + } + + return nil +} + func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) { // Loop through the number of dialers we have and select the one that is best // for the given domain. diff --git a/dialer/bandit_test.go b/dialer/bandit_test.go index 867fb19ce..35c3d7265 100644 --- a/dialer/bandit_test.go +++ b/dialer/bandit_test.go @@ -2,13 +2,19 @@ package dialer import ( "context" + "fmt" "io" "math/rand" "net" + "os" + "path/filepath" "reflect" + "strings" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -79,39 +85,78 @@ func TestBanditDialer_chooseDialerForDomain(t *testing.T) { } func TestNewBandit(t *testing.T) { + oldDialer := newTcpConnDialer() + oldDialerMetric := banditMetrics{ + Reward: 0.7, + Count: 10, + UpdatedAt: time.Now().UTC().Unix(), + } tests := []struct { - name string - opts *Options - want *BanditDialer - wantErr bool + name string + opts *Options + assert func(t *testing.T, got Dialer, err error, dir string) + setup func() string }{ { name: "should fail if there are no dialers", opts: &Options{ Dialers: nil, }, - want: nil, - wantErr: true, + assert: func(t *testing.T, got Dialer, err error, _ string) { + assert.Nil(t, got) + assert.Error(t, err) + }, }, { name: "should return a BanditDialer if there's only one dialer", opts: &Options{ Dialers: []ProxyDialer{newTcpConnDialer()}, }, - want: &BanditDialer{}, - wantErr: false, + assert: func(t *testing.T, got Dialer, err error, _ string) { + assert.NotNil(t, got) + assert.NoError(t, err) + assert.IsType(t, &BanditDialer{}, got) + }, + }, + { + name: "should load the last bandit rewards if they exist", + opts: &Options{ + Dialers: []ProxyDialer{oldDialer, newTcpConnDialer()}, + }, + setup: func() string { + tempDir, err := os.MkdirTemp("", "client_test") + require.NoError(t, err) + + // create rewards.csv + err = os.WriteFile(filepath.Join(tempDir, "rewards.csv"), []byte(fmt.Sprintf("dialer,reward,count,updated at\n%s,%f,%d,%d\n", oldDialer.Name(), oldDialerMetric.Reward, oldDialerMetric.Count, oldDialerMetric.UpdatedAt)), 0644) + require.NoError(t, err) + return tempDir + }, + assert: func(t *testing.T, got Dialer, err error, dir string) { + assert.NotNil(t, got) + assert.NoError(t, err) + assert.IsType(t, &BanditDialer{}, got) + rewards := got.(*BanditDialer).bandit.GetRewards() + counts := got.(*BanditDialer).bandit.GetCounts() + // checking if the rewards are loaded correctly + assert.Equal(t, oldDialerMetric.Reward, rewards[0]) + assert.Equal(t, oldDialerMetric.Count, counts[0]) + // since there's no data for the second dialer, it should be 0 + assert.Equal(t, float64(0), rewards[1]) + assert.Equal(t, 0, counts[1]) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewBandit(tt.opts) - if (err != nil) != tt.wantErr { - t.Errorf("NewBandit() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.want != nil && !reflect.TypeOf(got).AssignableTo(reflect.TypeOf(tt.want)) { - t.Errorf("BanditDialer.DialContext() = %v, want %v", got, tt.want) + dir := "" + if tt.setup != nil { + dir = tt.setup() + defer os.RemoveAll(dir) + tt.opts.BanditDir = dir } + got, err := NewBandit(tt.opts) + tt.assert(t, got, err, dir) }) } } @@ -148,6 +193,14 @@ func TestBanditDialer_DialContext(t *testing.T) { want: expectedConn, wantErr: false, }, + { + name: "should return an error if failed upstream", + opts: &Options{ + Dialers: []ProxyDialer{newFailingTcpConnDialer()}, + }, + want: nil, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -161,9 +214,8 @@ func TestBanditDialer_DialContext(t *testing.T) { } got, err := o.DialContext(context.Background(), "tcp", "localhost:8080") - if (err != nil) != tt.wantErr { - t.Errorf("BanditDialer.DialContext() error = %v, wantErr %v", err, tt.wantErr) - return + if tt.wantErr { + assert.Error(t, err) } if tt.want == nil && got != nil { t.Errorf("BanditDialer.DialContext() = %v, want %v", got, tt.want) @@ -281,11 +333,112 @@ func Test_differentArm(t *testing.T) { } } +func TestUpdateBanditRewards(t *testing.T) { + var tests = []struct { + name string + given map[string]banditMetrics + assert func(t *testing.T, dir string, err error) + }{ + { + name: "it should update rewards file", + given: map[string]banditMetrics{ + "test-dialer": { + Reward: 1.0, + Count: 1, + }, + }, + assert: func(t *testing.T, dir string, err error) { + assert.NoError(t, err) + f, err := os.Open(filepath.Join(dir, "rewards.csv")) + require.NoError(t, err) + defer f.Close() + b, err := io.ReadAll(f) + require.NoError(t, err) + + lines := strings.Split(string(b), "\n") + // check if headers are there + assert.Equal(t, lines[0], "dialer,reward,count,updated at") + // check if the data is there + cols := strings.Split(lines[1], ",") + assert.Equal(t, cols[dialerNameCSVHeader], "test-dialer") + assert.Equal(t, cols[rewardCSVHeader], "1.000000") + assert.Equal(t, cols[countCSVHeader], "1") + assert.NotEmpty(t, cols[updatedAtCSVHeader]) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "client_test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + banditDialer := &BanditDialer{ + opts: &Options{ + BanditDir: tempDir, + }, + banditRewardsMutex: new(sync.Mutex), + } + err = banditDialer.updateBanditRewards(tt.given) + tt.assert(t, tempDir, err) + }) + } +} + +func TestLoadLastBanditRewards(t *testing.T) { + now := time.Now().UTC().Unix() + var tests = []struct { + name string + given string + assert func(t *testing.T, metrics map[string]banditMetrics, err error) + }{ + { + name: "it should load the rewards", + given: fmt.Sprintf("dialer,reward,count,updated at\ntest-dialer,1.000000,1,%d\n", now), + assert: func(t *testing.T, metrics map[string]banditMetrics, err error) { + assert.NoError(t, err) + assert.Contains(t, metrics, "test-dialer") + assert.Equal(t, 1.0, metrics["test-dialer"].Reward) + assert.Equal(t, 1, metrics["test-dialer"].Count) + assert.Equal(t, now, metrics["test-dialer"].UpdatedAt) + }, + }, + { + name: "it should ignore dialers with updated at greater than 7 days", + given: fmt.Sprintf("dialer,reward,count,updated at\ntest-dialer,1.000000,1,%d\n", now-60*60*24*8), + assert: func(t *testing.T, metrics map[string]banditMetrics, err error) { + assert.NoError(t, err) + assert.Empty(t, metrics) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "bandit_test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + err = os.WriteFile(filepath.Join(tempDir, "rewards.csv"), []byte(tt.given), 0644) + require.NoError(t, err) + + banditDialer := &BanditDialer{ + opts: &Options{ + BanditDir: tempDir, + }, + banditRewardsMutex: new(sync.Mutex), + } + metrics, err := banditDialer.loadLastBanditRewards() + tt.assert(t, metrics, err) + }) + } +} + func newTcpConnDialer() ProxyDialer { client, server := net.Pipe() return &tcpConnDialer{ client: client, server: server, + name: uuid.New().String(), } } @@ -299,6 +452,7 @@ type tcpConnDialer struct { shouldFail bool client net.Conn server net.Conn + name string } func (*tcpConnDialer) Ready() <-chan error { @@ -397,8 +551,8 @@ func (*tcpConnDialer) MarkFailure() { } // Name implements Dialer. -func (*tcpConnDialer) Name() string { - return "tcpConnDialer" +func (t *tcpConnDialer) Name() string { + return t.name } // NumPreconnected implements Dialer. diff --git a/dialer/dialer.go b/dialer/dialer.go index 877d99738..2d4774246 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -82,6 +82,9 @@ type Options struct { // OnSuccess is the callback that is called by dialer after a successful dial. OnSuccess func(ProxyDialer) + + // BanditDir is the directory where the bandit will store its data + BanditDir string } // Clone creates a deep copy of the Options object @@ -93,6 +96,7 @@ func (o *Options) Clone() *Options { Dialers: o.Dialers, OnError: o.OnError, OnSuccess: o.OnSuccess, + BanditDir: o.BanditDir, } }