Skip to content

Commit

Permalink
refactor: update chainPaymentState
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Oct 25, 2024
1 parent c831374 commit f5835a8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
42 changes: 21 additions & 21 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type Config struct {
// payments information is valid.
type Meterer struct {
Config
// ChainState reads on-chain payment state periodically and cache it in memory
ChainState OnchainPayment
// ChainPaymentState reads on-chain payment state periodically and cache it in memory
ChainPaymentState OnchainPayment
// OffchainStore uses DynamoDB to track metering and used to validate requests
OffchainStore OffchainStore

Expand All @@ -39,8 +39,8 @@ func NewMeterer(
return &Meterer{
Config: config,

ChainState: paymentChainState,
OffchainStore: offchainStore,
ChainPaymentState: paymentChainState,
OffchainStore: offchainStore,

logger: logger.With("component", "Meterer"),
}
Expand All @@ -53,14 +53,14 @@ func (m *Meterer) Start(ctx context.Context) {
defer ticker.Stop()

// initial tick immediately upto Start
if err := m.ChainState.RefreshOnchainPaymentState(ctx, nil); err != nil {
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx, nil); err != nil {
m.logger.Error("Failed to make initial query to the on-chain state", "error", err)
}

for {
select {
case <-ticker.C:
if err := m.ChainState.RefreshOnchainPaymentState(ctx, nil); err != nil {
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx, nil); err != nil {
m.logger.Error("Failed to refresh on-chain state", "error", err)
}
case <-ctx.Done():
Expand All @@ -76,15 +76,15 @@ func (m *Meterer) MeterRequest(ctx context.Context, blob core.Blob, header core.
headerQuorums := blob.GetQuorumNumbers()
// Validate against the payment method
if header.CumulativePayment.Sign() == 0 {
reservation, err := m.ChainState.GetActiveReservationByAccount(ctx, header.AccountID)
reservation, err := m.ChainPaymentState.GetActiveReservationByAccount(ctx, header.AccountID)
if err != nil {
return fmt.Errorf("failed to get active reservation by account: %w", err)
}
if err := m.ServeReservationRequest(ctx, header, &reservation, blob.RequestHeader.BlobAuthHeader.Length, headerQuorums); err != nil {
return fmt.Errorf("invalid reservation: %w", err)
}
} else {
onDemandPayment, err := m.ChainState.GetOnDemandPaymentByAccount(ctx, header.AccountID)
onDemandPayment, err := m.ChainPaymentState.GetOnDemandPaymentByAccount(ctx, header.AccountID)
if err != nil {
return fmt.Errorf("failed to get on-demand payment by account: %w", err)
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func (m *Meterer) ValidateQuorum(headerQuorums []uint8, allowedQuorums []uint8)
// ValidateBinIndex checks if the provided bin index is valid
func (m *Meterer) ValidateBinIndex(header core.PaymentMetadata, reservation *core.ActiveReservation) bool {
now := uint64(time.Now().Unix())
reservationWindow := m.ChainState.GetReservationWindow()
reservationWindow := m.ChainPaymentState.GetReservationWindow()
currentBinIndex := GetBinIndex(now, reservationWindow)
// Valid bin indexes are either the current bin or the previous bin
if (header.BinIndex != currentBinIndex && header.BinIndex != (currentBinIndex-1)) || (GetBinIndex(reservation.StartTimestamp, reservationWindow) > header.BinIndex || header.BinIndex > GetBinIndex(reservation.EndTimestamp, reservationWindow)) {
Expand All @@ -160,7 +160,7 @@ func (m *Meterer) IncrementBinUsage(ctx context.Context, header core.PaymentMeta
// metered usage before updating the size already exceeded the limit
return fmt.Errorf("bin has already been filled")
}
if newUsage <= 2*usageLimit && header.BinIndex+2 <= GetBinIndex(reservation.EndTimestamp, m.ChainState.GetReservationWindow()) {
if newUsage <= 2*usageLimit && header.BinIndex+2 <= GetBinIndex(reservation.EndTimestamp, m.ChainPaymentState.GetReservationWindow()) {
_, err := m.OffchainStore.UpdateReservationBin(ctx, header.AccountID, uint64(header.BinIndex+2), newUsage-usageLimit)
if err != nil {
return err
Expand All @@ -180,7 +180,7 @@ func GetBinIndex(timestamp uint64, binInterval uint32) uint32 {
// On-demand requests doesn't have additional quorum settings and should only be
// allowed by ETH and EIGEN quorums
func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment, blobLength uint, headerQuorums []uint8) error {
quorumNumbers, err := m.ChainState.GetOnDemandQuorumNumbers(ctx)
quorumNumbers, err := m.ChainPaymentState.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return fmt.Errorf("failed to get on-demand quorum numbers: %w", err)
}
Expand Down Expand Up @@ -244,21 +244,21 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada

// PaymentCharged returns the chargeable price for a given data length
func (m *Meterer) PaymentCharged(dataLength uint) uint64 {
fmt.Println("PaymentCharged", dataLength, m.SymbolsCharged(dataLength), m.ChainState.GetPricePerSymbol())
return uint64(m.SymbolsCharged(dataLength)) * uint64(m.ChainState.GetPricePerSymbol())
fmt.Println("PaymentCharged", dataLength, m.SymbolsCharged(dataLength), m.ChainPaymentState.GetPricePerSymbol())
return uint64(m.SymbolsCharged(dataLength)) * uint64(m.ChainPaymentState.GetPricePerSymbol())
}

// SymbolsCharged returns the number of symbols charged for a given data length
// being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols.
func (m *Meterer) SymbolsCharged(dataLength uint) uint32 {
fmt.Println("SymbolsCharged", dataLength, m.ChainState.GetMinNumSymbols())
if dataLength <= uint(m.ChainState.GetMinNumSymbols()) {
fmt.Println("return ", m.ChainState.GetMinNumSymbols())
return m.ChainState.GetMinNumSymbols()
fmt.Println("SymbolsCharged", dataLength, m.ChainPaymentState.GetMinNumSymbols())
if dataLength <= uint(m.ChainPaymentState.GetMinNumSymbols()) {
fmt.Println("return ", m.ChainPaymentState.GetMinNumSymbols())
return m.ChainPaymentState.GetMinNumSymbols()
}
// Round up to the nearest multiple of MinNumSymbols
fmt.Println("return ", uint32(core.RoundUpDivide(uint(dataLength), uint(m.ChainState.GetMinNumSymbols())))*m.ChainState.GetMinNumSymbols())
return uint32(core.RoundUpDivide(uint(dataLength), uint(m.ChainState.GetMinNumSymbols()))) * m.ChainState.GetMinNumSymbols()
fmt.Println("return ", uint32(core.RoundUpDivide(uint(dataLength), uint(m.ChainPaymentState.GetMinNumSymbols())))*m.ChainPaymentState.GetMinNumSymbols())
return uint32(core.RoundUpDivide(uint(dataLength), uint(m.ChainPaymentState.GetMinNumSymbols()))) * m.ChainPaymentState.GetMinNumSymbols()
}

// ValidateBinIndex checks if the provided bin index is valid
Expand All @@ -280,13 +280,13 @@ func (m *Meterer) IncrementGlobalBinUsage(ctx context.Context, symbolsCharged ui
if err != nil {
return fmt.Errorf("failed to increment global bin usage: %w", err)
}
if newUsage > m.ChainState.GetGlobalSymbolsPerSecond() {
if newUsage > m.ChainPaymentState.GetGlobalSymbolsPerSecond() {
return fmt.Errorf("global bin usage overflows")
}
return nil
}

// GetReservationBinLimit returns the bin limit for a given reservation
func (m *Meterer) GetReservationBinLimit(reservation *core.ActiveReservation) uint64 {
return reservation.SymbolsPerSec * uint64(m.ChainState.GetReservationWindow())
return reservation.SymbolsPerSec * uint64(m.ChainPaymentState.GetReservationWindow())
}
8 changes: 4 additions & 4 deletions core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestMetererReservations(t *testing.T) {
paymentChainState.On("GetGlobalSymbolsPerSecond", testifymock.Anything).Return(uint64(1009), nil)
paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(3), nil)

binIndex := meterer.GetBinIndex(uint64(time.Now().Unix()), mt.ChainState.GetReservationWindow())
binIndex := meterer.GetBinIndex(uint64(time.Now().Unix()), mt.ChainPaymentState.GetReservationWindow())
quoromNumbers := []uint8{0, 1}

paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.MatchedBy(func(account string) bool {
Expand Down Expand Up @@ -295,7 +295,7 @@ func TestMetererOnDemand(t *testing.T) {
// test duplicated cumulative payments
dataLength := uint(100)
priceCharged := mt.PaymentCharged(dataLength)
assert.Equal(t, uint64(102*mt.ChainState.GetPricePerSymbol()), priceCharged)
assert.Equal(t, uint64(102*mt.ChainPaymentState.GetPricePerSymbol()), priceCharged)
blob, header = createMetererInput(binIndex, priceCharged, dataLength, quorumNumbers, accountID2)
err = mt.MeterRequest(ctx, *blob, *header)
assert.NoError(t, err)
Expand Down Expand Up @@ -401,7 +401,7 @@ func TestMeterer_paymentCharged(t *testing.T) {
paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(tt.minNumSymbols), nil)
t.Run(tt.name, func(t *testing.T) {
m := &meterer.Meterer{
ChainState: paymentChainState,
ChainPaymentState: paymentChainState,
}
result := m.PaymentCharged(tt.dataLength)
assert.Equal(t, tt.expected, result)
Expand Down Expand Up @@ -453,7 +453,7 @@ func TestMeterer_symbolsCharged(t *testing.T) {
paymentChainState.On("GetMinNumSymbols", testifymock.Anything).Return(uint32(tt.minNumSymbols), nil)
t.Run(tt.name, func(t *testing.T) {
m := &meterer.Meterer{
ChainState: paymentChainState,
ChainPaymentState: paymentChainState,
}
result := m.SymbolsCharged(tt.dataLength)
assert.Equal(t, tt.expected, result)
Expand Down

0 comments on commit f5835a8

Please sign in to comment.