Skip to content

Commit

Permalink
support more match fields in rule
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Oct 20, 2024
1 parent f2e13c6 commit c04e06b
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 26 deletions.
16 changes: 10 additions & 6 deletions app/router/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ type DomainSetConfig struct {
}

type RuleConfig struct {
Reverse bool `yaml:"reverse"`
Domain string `yaml:"domain"`
Reject uint16 `yaml:"reject"`
Forward string `yaml:"forward"`
Reverse bool `yaml:"reverse"`
Domain string `yaml:"domain"`
Server string `yaml:"server"`
ServerName string `yaml:"server_name"`
Path string `yaml:"path"`
ClientIP []string `yaml:"client_ip"`
Reject uint16 `yaml:"reject"`
Forward string `yaml:"forward"`
}

type AddonsConfig struct{}
Expand All @@ -135,8 +139,8 @@ type CacheConfig struct {
}

type ECSConfig struct {
Enabled bool `yaml:"enabled"`
Forward bool `yaml:"forward"`
Enabled bool `yaml:"enabled"`
Forward bool `yaml:"forward"`
IpZone []string `yaml:"ip_zone"`
ZoneOverwrite []string `yaml:"zone_overwrite"`
}
Expand Down
15 changes: 4 additions & 11 deletions app/router/router_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,17 @@ func (r *Router) BuiltInHandler(ctx context.Context, q *QueryCtx) {
// Match rules
var matchedRule *rule
for _, rule := range r.rules {
if rule.matcher != nil {
matched := rule.matcher.Match(q.Question.Name)
if rule.reverse {
matched = !matched
}
if !matched {
continue
}
if rule.match(q) {
matchedRule = rule
break
}
matchedRule = rule
break
}

if matchedRule == nil {
SetEmptyRespMQ(q, dnsmsg.RCodeRefused)
return
}
if rejectRCode := matchedRule.reject; rejectRCode > 0 {
if rejectRCode := matchedRule.cfg.Reject; rejectRCode > 0 {
SetEmptyRespMQ(q, dnsmsg.RCode(rejectRCode))
return
}
Expand Down
129 changes: 120 additions & 9 deletions app/router/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,31 @@ package router

import (
"fmt"
"net/netip"
"strings"

"github.com/IrineSistiana/mosproxy/internal/netlist"
)

type rule struct {
reverse bool
matcher *DomainSet
reject uint16
upstream Upstream // maybe nil
cfg RuleConfig
domainSet *DomainSet // maybe nil
clientIp *netlist.List[struct{}] // maybe nil
upstream Upstream // maybe nil
}

func (r *Router) loadRule(cfg RuleConfig) (*rule, error) {
ru := new(rule)
ru := &rule{
cfg: cfg,
}
if len(cfg.Domain) > 0 {
m := r.domainSets[cfg.Domain]
if m == nil {
return nil, fmt.Errorf("cannot find domain set tag [%s]", cfg.Domain)
}
ru.matcher = m
ru.reverse = cfg.Reverse
ru.domainSet = m
}

ru.reject = cfg.Reject

if len(cfg.Forward) > 0 {
uw := r.upstreams[cfg.Forward]
if uw != nil {
Expand All @@ -37,5 +40,113 @@ func (r *Router) loadRule(cfg RuleConfig) (*rule, error) {
}
}
}

// Copied from https://github.com/go4org/netipx/blob/fdeea329fbbac19fb83c9cfda32c4fcac39bbaab/netipx.go#L188
lastIp := func(p netip.Prefix) netip.Addr {
if !p.IsValid() {
return netip.Addr{}
}
a16 := p.Addr().As16()
var off uint8
var bits uint8 = 128
if p.Addr().Is4() {
off = 12
bits = 32
}
for b := uint8(p.Bits()); b < bits; b++ {
byteNum, bitInByte := b/8, 7-(b%8)
a16[off+byteNum] |= 1 << uint(bitInByte)
}
if p.Addr().Is4() {
return netip.AddrFrom16(a16).Unmap()
} else {
return netip.AddrFrom16(a16) // doesn't unmap
}
}

if len(cfg.ClientIP) > 0 {
lb := netlist.NewBuilder[struct{}](0)
for _, s := range cfg.ClientIP {
var (
start, end netip.Addr
err error
)
s1, s2, ok := strings.Cut(s, "-")
if ok { // range s1-s2
start, err = netip.ParseAddr(s1)
if err != nil {
return nil, fmt.Errorf("invalid start addr [%s], %w", s1, err)
}
end, err = netip.ParseAddr(s2)
if err != nil {
return nil, fmt.Errorf("invalid end addr [%s], %w", s2, err)
}
} else if strings.ContainsRune(s, '/') { //cidr
p, err := netip.ParsePrefix(s)
if err != nil {
return nil, fmt.Errorf("invalid cidr addr [%s], %w", s, err)
}
start = p.Masked().Addr()
end = lastIp(p)
} else { // single ip
addr, err := netip.ParseAddr(s)
if err != nil {
return nil, fmt.Errorf("invalid addr [%s], %w", s, err)
}
start = addr
end = addr
}

ok = lb.Add(start, end, struct{}{})
if !ok {
return nil, fmt.Errorf("invalid addr range [%s-%s]", start, end)
}
}

l, err := lb.Build()
if err != nil {
return nil, fmt.Errorf("failed to build client addr index, %w", err)
}
ru.clientIp = l
}
return ru, nil
}

func (ru *rule) match(q *QueryCtx) bool {
ok := ru._match(q)
if ru.cfg.Reverse {
return !ok
}
return ok
}

func (ru *rule) _match(q *QueryCtx) bool {
if ru.domainSet != nil {
ok := ru.domainSet.Match(q.Question.Name)
if !ok {
return false
}
}

if len(ru.cfg.Server) > 0 && ru.cfg.Server != q.ServerTag {
return false
}

if len(ru.cfg.ServerName) > 0 && ru.cfg.ServerName != string(q.ServerName) {
return false
}

if p := ru.cfg.Path; len(p) > 0 {
if strings.HasSuffix(p, "/") { // match url prefix
ok := len(q.Path) >= len(q.Path) && string(q.Path[0:len(p)]) == p
if !ok {
return false
}
} else {
if p != string(q.Path) {
return false
}
}
}
return true
}
4 changes: 4 additions & 0 deletions internal/netlist/netlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func Test_Netlist(t *testing.T) {
add("0.0.0.0", "0.0.0.4", 1, b)
add("0.0.0.5", "0.0.0.6", 2, b)
add("0.0.0.7", "0.0.0.8", 3, b)
add("0.0.0.255", "0.0.0.255", 4, b)
l, err = b.Build()
r.NoError(err)
r.NotNil(l)
Expand All @@ -52,6 +53,9 @@ func Test_Netlist(t *testing.T) {
v, ok = l.Lookup(ipf("0.0.0.8")) // matched
r.True(ok)
r.Equal(3, v)
v, ok = l.Lookup(ipf("0.0.0.255")) // matched
r.True(ok)
r.Equal(4, v)
v, ok = l.Lookup(ipf("0.0.0.10")) // not matched
r.Equal(0, v)
r.False(ok)
Expand Down

0 comments on commit c04e06b

Please sign in to comment.