Skip to content

Commit

Permalink
Revert "Revert "MAB persist decisions across runs (#1449)"" (#1472)
Browse files Browse the repository at this point in the history
* Revert "Revert "MAB persist decisions across runs (#1449)""

This reverts commit e74617a.

* fix: rename the err variable used when updating rewards, and add a test scenario to make sure an error is returned when the upstream fails

* fix: removing assert.NotNil from error and check only if there was an error as expected
  • Loading branch information
WendelHime authored Dec 17, 2024
1 parent 71c7c9e commit 40b474f
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 35 deletions.
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net"
"net/http"
"path/filepath"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -735,6 +736,7 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig)
)
}
},
BanditDir: filepath.Join(configDir, "bandit"),
})
return dialers, dialer, nil
}
Expand Down
208 changes: 193 additions & 15 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@ package dialer

import (
"context"
"encoding/csv"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"strconv"
"sync"
"sync/atomic"
"time"

Expand All @@ -12,9 +19,16 @@ import (

// BanditDialer is responsible for continually choosing the optimized dialer.
type BanditDialer struct {
dialers []ProxyDialer
bandit *bandit.EpsilonGreedy
opts *Options
dialers []ProxyDialer
bandit bandit.Bandit
opts *Options
banditRewardsMutex *sync.Mutex
}

// banditMetrics is the per-dialer bandit state that is written to and read
// back from the rewards CSV file.
type banditMetrics struct {
// Reward is the bandit's reward value for the dialer's arm.
Reward float64
// Count is the bandit's count for the dialer's arm.
Count int
// UpdatedAt is the Unix time, in seconds, at which these metrics were saved.
UpdatedAt int64
}

// NewBandit creates a new bandit given the available dialers and options with
Expand All @@ -29,21 +43,48 @@ func NewBandit(opts *Options) (Dialer, error) {

dialers := opts.Dialers
log.Debugf("Creating bandit with %d dialers", len(dialers))
b, err := bandit.NewEpsilonGreedy(0.1, nil, nil)

var b bandit.Bandit
var err error
dialer := &BanditDialer{
dialers: dialers,
opts: opts,
banditRewardsMutex: &sync.Mutex{},
}

dialerWeights, err := dialer.loadLastBanditRewards()
if err != nil {
log.Errorf("unable to load bandit weights: %v", err)
}
if dialerWeights != nil {
log.Debugf("Loading bandit weights from %q", opts.BanditDir)
counts := make([]int, len(dialers))
rewards := make([]float64, len(dialers))
for arm, dialer := range dialers {
if metrics, ok := dialerWeights[dialer.Name()]; ok {
rewards[arm] = metrics.Reward
counts[arm] = metrics.Count
}
}
b, err = bandit.NewEpsilonGreedy(0.1, counts, rewards)
if err != nil {
log.Errorf("unable to create weighted bandit: %w", err)
return nil, err
}
dialer.bandit = b
return dialer, nil
}

b, err = bandit.NewEpsilonGreedy(0.1, nil, nil)
if err != nil {
log.Errorf("unable to create bandit: %v", err)
return nil, err
}

if err := b.Init(len(dialers)); err != nil {
log.Errorf("unable to initialize bandit: %v", err)
return nil, err
}
dialer := &BanditDialer{
dialers: dialers,
bandit: b,
opts: opts,
}
dialer.bandit = b

return dialer, nil
}
Expand All @@ -70,17 +111,17 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (

if !failedUpstream {
log.Errorf("Dialer %v failed in %v seconds: %v", d.Name(), time.Since(start).Seconds(), err)
if err := bd.bandit.Update(chosenArm, 0); err != nil {
log.Errorf("unable to update bandit: %v", err)
if errUpdatingBanditReward := bd.bandit.Update(chosenArm, 0); errUpdatingBanditReward != nil {
log.Errorf("unable to update bandit: %v", errUpdatingBanditReward)
}
} else {
log.Debugf("Dialer %v failed upstream...", d.Name())
// This can happen, for example, if the upstream server is down, or
// if the DNS resolves to localhost, for example. It is also possible
// that the proxy is blacklisted by upstream sites for some reason,
// so we have to choose some reasonable value.
if err := bd.bandit.Update(chosenArm, 0.00005); err != nil {
log.Errorf("unable to update bandit: %v", err)
if errUpdatingBanditReward := bd.bandit.Update(chosenArm, 0.00005); errUpdatingBanditReward != nil {
log.Errorf("unable to update bandit: %v", errUpdatingBanditReward)
}
}
return nil, err
Expand All @@ -97,15 +138,152 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
time.AfterFunc(secondsForSample*time.Second, func() {
speed := normalizeReceiveSpeed(dataRecv.Load())
//log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed)
if err := bd.bandit.Update(chosenArm, speed); err != nil {
if err = bd.bandit.Update(chosenArm, speed); err != nil {
log.Errorf("unable to update bandit: %v", err)
}
})

time.AfterFunc(30*time.Second, func() {
log.Debugf("saving bandit rewards")
metrics := make(map[string]banditMetrics)
rewards := bd.bandit.GetRewards()
counts := bd.bandit.GetCounts()
for i, d := range bd.dialers {
metrics[d.Name()] = banditMetrics{
Reward: rewards[i],
Count: counts[i],
UpdatedAt: time.Now().UTC().Unix(),
}
}

err = bd.updateBanditRewards(metrics)
if err != nil {
log.Errorf("unable to save bandit weights: %v", err)
}
})

bd.opts.OnSuccess(d)
return dt, err
}

// Column indices of a record in the rewards CSV file, followed by the maximum
// age of a persisted record.
const (
// dialerNameCSVHeader is the index of the dialer name column.
dialerNameCSVHeader = iota
// rewardCSVHeader is the index of the reward column.
rewardCSVHeader
// countCSVHeader is the index of the count column.
countCSVHeader
// updatedAtCSVHeader is the index of the "updated at" (Unix seconds) column.
updatedAtCSVHeader

// unusedBanditDialerIgnoredAfter is the age past which a persisted record
// is skipped when the rewards file is loaded.
unusedBanditDialerIgnoredAfter = 7 * 24 * time.Hour
)

// loadLastBanditRewards reads the bandit metrics persisted in
// <BanditDir>/rewards.csv and returns them keyed by dialer name, so that the
// bandit can be initialized with the metrics from a previous run.
//
// Records whose "updated at" timestamp is older than
// unusedBanditDialerIgnoredAfter are skipped. An error is returned when the
// bandit directory is not set, when the file cannot be opened (that error is
// returned unwrapped so callers can test it with os.IsNotExist), or when the
// file contents cannot be parsed.
func (o *BanditDialer) loadLastBanditRewards() (map[string]banditMetrics, error) {
	o.banditRewardsMutex.Lock()
	defer o.banditRewardsMutex.Unlock()
	if o.opts.BanditDir == "" {
		return nil, log.Error("bandit directory is not set")
	}

	file := filepath.Join(o.opts.BanditDir, "rewards.csv")
	data, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	// Release the file handle on every return path; this function runs on
	// each load and on each save, so a leaked handle accumulates quickly.
	defer data.Close()

	reader := csv.NewReader(data)
	// Skip the header, but read it so the csv reader knows the expected
	// number of columns for every following record.
	header, err := reader.Read()
	if err != nil {
		return nil, log.Errorf("unable to skip headers from bandit rewards csv: %w", err)
	}
	// The reader rejects records whose field count differs from the header's,
	// so validating the header once makes the column indexing below safe.
	if len(header) <= updatedAtCSVHeader {
		return nil, log.Errorf("bandit rewards csv has %d columns, expected at least %d", len(header), updatedAtCSVHeader+1)
	}

	metrics := make(map[string]banditMetrics)
	for {
		line, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, log.Errorf("unable to read line from bandit rewards csv: %w", err)
		}

		// load updatedAt unix time and check if it's older than 7 days
		updatedAt, err := strconv.ParseInt(line[updatedAtCSVHeader], 10, 64)
		if err != nil {
			return nil, log.Errorf("unable to parse updated at from %s: %w", line[0], err)
		}
		if time.Since(time.Unix(updatedAt, 0)) > unusedBanditDialerIgnoredAfter {
			log.Debugf("Ignoring bandit dialer %s as it's older than 7 days", line[0])
			continue
		}
		reward, err := strconv.ParseFloat(line[rewardCSVHeader], 64)
		if err != nil {
			return nil, log.Errorf("unable to parse reward from %s: %w", line[0], err)
		}
		count, err := strconv.Atoi(line[countCSVHeader])
		if err != nil {
			return nil, log.Errorf("unable to parse count from %s: %w", line[0], err)
		}

		metrics[line[dialerNameCSVHeader]] = banditMetrics{
			Reward:    reward,
			Count:     count,
			UpdatedAt: updatedAt,
		}
	}
	return metrics, nil
}

// updateBanditRewards merges newRewards into the metrics already persisted in
// <BanditDir>/rewards.csv and rewrites the file. Entries in newRewards replace
// persisted entries with the same dialer name; persisted entries for other
// dialers are kept.
//
// It returns an error when the bandit directory is not set or cannot be
// created, when the existing file cannot be loaded for any reason other than
// not existing, or when the file cannot be fully written and closed.
func (o *BanditDialer) updateBanditRewards(newRewards map[string]banditMetrics) error {
	// Validate the configuration before touching the filesystem.
	if o.opts.BanditDir == "" {
		return log.Error("bandit directory is not set")
	}
	if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
		return log.Errorf("unable to create bandit directory: %v", err)
	}

	// NOTE(review): loadLastBanditRewards takes banditRewardsMutex itself, so
	// the lock cannot be held across the load and the write below; two
	// concurrent saves may therefore interleave between these two steps.
	previousRewards, err := o.loadLastBanditRewards()
	if err != nil && !os.IsNotExist(err) {
		return log.Errorf("couldn't load previous bandit rewards: %w", err)
	}
	o.banditRewardsMutex.Lock()
	defer o.banditRewardsMutex.Unlock()

	// if there's previous rewards, we must overwrite current values
	if previousRewards != nil {
		for dialer, metrics := range newRewards {
			previousRewards[dialer] = metrics
		}
	} else {
		previousRewards = newRewards
	}

	records := make([][]string, 0, len(previousRewards)+1)
	records = append(records, []string{"dialer", "reward", "count", "updated at"})
	for dialerName, metric := range previousRewards {
		records = append(records, []string{
			dialerName,
			// Shortest decimal form that parses back to the same float64;
			// a fixed "%f" would truncate small rewards to six decimals.
			strconv.FormatFloat(metric.Reward, 'f', -1, 64),
			fmt.Sprintf("%d", metric.Count),
			fmt.Sprintf("%d", metric.UpdatedAt),
		})
	}

	file := filepath.Join(o.opts.BanditDir, "rewards.csv")
	f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return log.Errorf("unable to open bandit rewards file: %v", err)
	}

	// WriteAll flushes the buffered writer and reports any write error, so a
	// failed save is not mistaken for a successful one.
	if err = csv.NewWriter(f).WriteAll(records); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return log.Errorf("unable to write bandit rewards to file: %v", err)
	}
	if err = f.Close(); err != nil {
		return log.Errorf("unable to close bandit rewards file: %v", err)
	}

	return nil
}

func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) {
// Loop through the number of dialers we have and select the one that is best
// for the given domain.
Expand Down
Loading

0 comments on commit 40b474f

Please sign in to comment.