Skip to content

Commit

Permalink
feat(taiko-client): improve ProofBuffer (#18627)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidtaikocha authored Dec 23, 2024
1 parent 45603f7 commit c386589
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 67 deletions.
4 changes: 2 additions & 2 deletions packages/taiko-client/cmd/flags/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ var (
Category: proverCategory,
EnvVars: []string{"PROVER_ZKVM_BATCH_SIZE"},
}
ForceProveInterval = &cli.DurationFlag{
ForceBatchProvingInterval = &cli.DurationFlag{
Name: "prover.forceBatchProvingInterval",
Usage: "Time interval to prove blocks even the number of pending proof do not exceed prover.batchSize, " +
"this flag only works post Ontake fork",
Expand Down Expand Up @@ -254,5 +254,5 @@ var ProverFlags = MergeFlags(CommonFlags, []cli.Flag{
RaikoZKVMHostEndpoint,
SGXBatchSize,
ZKVMBatchSize,
ForceProveInterval,
ForceBatchProvingInterval,
}, TxmgrFlags)
8 changes: 4 additions & 4 deletions packages/taiko-client/prover/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Config struct {
PrivateTxmgrConfigs *txmgr.CLIConfig
SGXProofBufferSize uint64
ZKVMProofBufferSize uint64
ForceProveInterval time.Duration
ForceBatchProvingInterval time.Duration
}

// NewConfigFromCliContext creates a new config instance from command line flags.
Expand Down Expand Up @@ -186,8 +186,8 @@ func NewConfigFromCliContext(c *cli.Context) (*Config, error) {
l1ProverPrivKey,
c,
),
SGXProofBufferSize: c.Uint64(flags.SGXBatchSize.Name),
ZKVMProofBufferSize: c.Uint64(flags.ZKVMBatchSize.Name),
ForceProveInterval: c.Duration(flags.ForceProveInterval.Name),
SGXProofBufferSize: c.Uint64(flags.SGXBatchSize.Name),
ZKVMProofBufferSize: c.Uint64(flags.ZKVMBatchSize.Name),
ForceBatchProvingInterval: c.Duration(flags.ForceBatchProvingInterval.Name),
}, nil
}
1 change: 1 addition & 0 deletions packages/taiko-client/prover/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ func (p *Prover) initProofSubmitters(
p.IsGuardianProver(),
p.cfg.GuardianProofSubmissionDelay,
bufferSize,
p.cfg.ForceBatchProvingInterval,
); err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions packages/taiko-client/prover/proof_submitter/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Submitter interface {
Producer() proofProducer.ProofProducer
Tier() uint16
BufferSize() uint64
AggregationEnabled() bool
}

// Contester is the interface for contesting proofs of the L2 blocks.
Expand Down
44 changes: 34 additions & 10 deletions packages/taiko-client/prover/proof_submitter/proof_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package submitter
import (
"errors"
"sync"
"time"

producer "github.com/taikoxyz/taiko-mono/packages/taiko-client/prover/proof_producer"
)
Expand All @@ -14,16 +15,19 @@ var (

// ProofBuffer caches all single proof with a fixed size.
type ProofBuffer struct {
MaxLength uint64
buffer []*producer.ProofWithHeader
mutex sync.RWMutex
MaxLength uint64
buffer []*producer.ProofWithHeader
lastUpdatedAt time.Time
isAggregating bool
mutex sync.RWMutex
}

// NewProofBuffer creates a new ProofBuffer instance.
func NewProofBuffer(maxLength uint64) *ProofBuffer {
return &ProofBuffer{
buffer: make([]*producer.ProofWithHeader, 0, maxLength),
MaxLength: maxLength,
buffer: make([]*producer.ProofWithHeader, 0, maxLength),
lastUpdatedAt: time.Now(),
MaxLength: maxLength,
}
}

Expand All @@ -37,6 +41,7 @@ func (pb *ProofBuffer) Write(item *producer.ProofWithHeader) (int, error) {
}

pb.buffer = append(pb.buffer, item)
pb.lastUpdatedAt = time.Now()
return len(pb.buffer), nil
}

Expand Down Expand Up @@ -65,11 +70,14 @@ func (pb *ProofBuffer) Len() int {
return len(pb.buffer)
}

// Clear clears all buffer.
func (pb *ProofBuffer) Clear() {
pb.mutex.Lock()
defer pb.mutex.Unlock()
pb.buffer = pb.buffer[:0]
// LastUpdatedAt returns the last updated time of the buffer.
func (pb *ProofBuffer) LastUpdatedAt() time.Time {
return pb.lastUpdatedAt
}

// LastUpdatedAt returns the last updated time of the buffer.
func (pb *ProofBuffer) UpdateLastUpdatedAt() {
pb.lastUpdatedAt = time.Now()
}

// ClearItems clears items that has given block ids in the buffer.
Expand All @@ -94,5 +102,21 @@ func (pb *ProofBuffer) ClearItems(blockIDs ...uint64) int {
}

pb.buffer = newBuffer
pb.isAggregating = false
return clearedCount
}

// MarkAggregating marks the proofs in this buffer are aggregating.
func (pb *ProofBuffer) MarkAggregating() {
pb.isAggregating = true
}

// IsAggregating returns if the proofs in this buffer are aggregating.
func (pb *ProofBuffer) IsAggregating() bool {
return pb.isAggregating
}

// Enabled returns if the buffer is enabled.
func (pb *ProofBuffer) Enabled() bool {
return pb.MaxLength > 1
}
84 changes: 53 additions & 31 deletions packages/taiko-client/prover/proof_submitter/proof_submitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ type ProofSubmitter struct {
isGuardian bool
submissionDelay time.Duration
// Batch proof related
proofBuffer *ProofBuffer
proofBuffer *ProofBuffer
forceBatchProvingInterval time.Duration
}

// NewProofSubmitter creates a new ProofSubmitter instance.
Expand All @@ -74,29 +75,31 @@ func NewProofSubmitter(
isGuardian bool,
submissionDelay time.Duration,
proofBufferSize uint64,
forceBatchProvingInterval time.Duration,
) (*ProofSubmitter, error) {
anchorValidator, err := validator.New(taikoL2Address, rpcClient.L2.ChainID, rpcClient)
if err != nil {
return nil, err
}

return &ProofSubmitter{
rpc: rpcClient,
proofProducer: proofProducer,
resultCh: resultCh,
batchResultCh: batchResultCh,
aggregationNotify: aggregationNotify,
anchorValidator: anchorValidator,
txBuilder: builder,
sender: transaction.NewSender(rpcClient, txmgr, privateTxmgr, proverSetAddress, gasLimit),
proverAddress: txmgr.From(),
proverSetAddress: proverSetAddress,
taikoL2Address: taikoL2Address,
graffiti: rpc.StringToBytes32(graffiti),
tiers: tiers,
isGuardian: isGuardian,
submissionDelay: submissionDelay,
proofBuffer: NewProofBuffer(proofBufferSize),
rpc: rpcClient,
proofProducer: proofProducer,
resultCh: resultCh,
batchResultCh: batchResultCh,
aggregationNotify: aggregationNotify,
anchorValidator: anchorValidator,
txBuilder: builder,
sender: transaction.NewSender(rpcClient, txmgr, privateTxmgr, proverSetAddress, gasLimit),
proverAddress: txmgr.From(),
proverSetAddress: proverSetAddress,
taikoL2Address: taikoL2Address,
graffiti: rpc.StringToBytes32(graffiti),
tiers: tiers,
isGuardian: isGuardian,
submissionDelay: submissionDelay,
proofBuffer: NewProofBuffer(proofBufferSize),
forceBatchProvingInterval: forceBatchProvingInterval,
}, nil
}

Expand Down Expand Up @@ -143,7 +146,7 @@ func (s *ProofSubmitter) RequestProof(ctx context.Context, meta metadata.TaikoBl
Graffiti: common.Bytes2Hex(s.graffiti[:]),
GasUsed: header.GasUsed,
ParentGasUsed: parent.GasUsed(),
Compressed: s.proofBuffer.MaxLength > 1,
Compressed: s.proofBuffer.Enabled(),
}

// If the prover set address is provided, we use that address as the prover on chain.
Expand All @@ -159,9 +162,9 @@ func (s *ProofSubmitter) RequestProof(ctx context.Context, meta metadata.TaikoBl
log.Error("Failed to request proof, context is canceled", "blockID", opts.BlockID, "error", ctx.Err())
return nil
}
// Check if the proof buffer is full
if s.proofBuffer.MaxLength > 1 && s.proofBuffer.MaxLength == uint64(s.proofBuffer.Len()) {
log.Debug("Buffer is full now", "blockID", meta.GetBlockID())
// Check if the proof buffer is full.
if s.proofBuffer.Enabled() && uint64(s.proofBuffer.Len()) >= s.proofBuffer.MaxLength {
log.Warn("Proof buffer is full now", "blockID", meta.GetBlockID())
return errBufferOverflow
}
// Check if there is a need to generate proof
Expand Down Expand Up @@ -198,21 +201,30 @@ func (s *ProofSubmitter) RequestProof(ctx context.Context, meta metadata.TaikoBl
}
return fmt.Errorf("failed to request proof (id: %d): %w", meta.GetBlockID(), err)
}
if meta.IsOntakeBlock() && s.proofBuffer.MaxLength > 1 {
if meta.IsOntakeBlock() && s.proofBuffer.Enabled() {
bufferSize, err := s.proofBuffer.Write(result)
if err != nil {
return fmt.Errorf("failed to add proof into buffer (id: %d)(current buffer size: %d): %w",
return fmt.Errorf(
"failed to add proof into buffer (id: %d) (current buffer size: %d): %w",
meta.GetBlockID(),
bufferSize,
err,
)
}
log.Debug("Succeed to generate proof",
log.Info(
"Proof generated",
"blockID", meta.GetBlockID(),
"bufferSize", bufferSize,
"maxBufferSize", s.proofBuffer.MaxLength,
"bufferIsAggregating", s.proofBuffer.IsAggregating(),
"bufferLastUpdatedAt", s.proofBuffer.lastUpdatedAt,
)
if s.proofBuffer.MaxLength == uint64(bufferSize) {
// Check if we need to aggregate proofs.
if !s.proofBuffer.IsAggregating() &&
(uint64(bufferSize) >= s.proofBuffer.MaxLength ||
time.Since(s.proofBuffer.lastUpdatedAt) > s.forceBatchProvingInterval) {
s.aggregationNotify <- s.Tier()
s.proofBuffer.MarkAggregating()
}
} else {
s.resultCh <- result
Expand Down Expand Up @@ -344,7 +356,8 @@ func (s *ProofSubmitter) BatchSubmitProofs(ctx context.Context, batchProof *proo
)
var (
invalidBlockIDs []uint64
latestProvenBlockID = big.NewInt(0)
latestProvenBlockID = common.Big0
uint64BlockIDs []uint64
)
if len(batchProof.Proofs) == 0 {
return proofProducer.ErrInvalidLength
Expand All @@ -369,27 +382,29 @@ func (s *ProofSubmitter) BatchSubmitProofs(ctx context.Context, batchProof *proo
return err
}
for i, proof := range batchProof.Proofs {
uint64BlockIDs = append(uint64BlockIDs, proof.BlockID.Uint64())
// Check if this proof is still needed to be submitted.
ok, err := s.sender.ValidateProof(ctx, proof, new(big.Int).SetUint64(stateVars.B.LastVerifiedBlockId))
if err != nil {
return err
}
if !ok {
log.Error("a valid proof for block is already submitted", "blockId", proof.BlockID)
log.Error("A valid proof for block is already submitted", "blockId", proof.BlockID)
invalidBlockIDs = append(invalidBlockIDs, proof.BlockID.Uint64())
continue
}

if proofStatus[i].IsSubmitted && !proofStatus[i].Invalid {
log.Error("a valid proof for block is already submitted", "blockId", proof.BlockID)
log.Error("A valid proof for block is already submitted", "blockId", proof.BlockID)
invalidBlockIDs = append(invalidBlockIDs, proof.BlockID.Uint64())
continue
}

// Get the corresponding L2 block.
block, err := s.rpc.L2.BlockByHash(ctx, proof.Header.Hash())
if err != nil {
log.Error("failed to get L2 block with given hash",
log.Error(
"Failed to get L2 block with given hash",
"hash", proof.Header.Hash(),
"error", err,
)
Expand All @@ -415,7 +430,7 @@ func (s *ProofSubmitter) BatchSubmitProofs(ctx context.Context, batchProof *proo
}

if len(invalidBlockIDs) > 0 {
log.Warn("Detected invalid proofs", "blockIds", invalidBlockIDs)
log.Warn("Invalid proofs in batch", "blockIds", invalidBlockIDs)
s.proofBuffer.ClearItems(invalidBlockIDs...)
return ErrInvalidProof
}
Expand All @@ -435,7 +450,9 @@ func (s *ProofSubmitter) BatchSubmitProofs(ctx context.Context, batchProof *proo

metrics.ProverSentProofCounter.Add(float64(len(batchProof.BlockIDs)))
metrics.ProverLatestProvenBlockIDGauge.Set(float64(latestProvenBlockID.Uint64()))
s.proofBuffer.Clear()
s.proofBuffer.ClearItems(uint64BlockIDs...)
// Each time we submit a batch proof, we should update the LastUpdatedAt() of the buffer.
s.proofBuffer.UpdateLastUpdatedAt()

return nil
}
Expand Down Expand Up @@ -511,3 +528,8 @@ func (s *ProofSubmitter) Tier() uint16 {
func (s *ProofSubmitter) BufferSize() uint64 {
return s.proofBuffer.MaxLength
}

// AggregationEnabled returns whether the proof submitter's aggregation feature is enabled.
func (s *ProofSubmitter) AggregationEnabled() bool {
return s.proofBuffer.Enabled()
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func (s *ProofSubmitterTestSuite) SetupTest() {
false,
0*time.Second,
0,
30*time.Minute,
)
s.Nil(err)
s.contester = NewProofContester(
Expand Down Expand Up @@ -199,6 +200,7 @@ func (s *ProofSubmitterTestSuite) TestGetRandomBumpedSubmissionDelay() {
false,
time.Duration(0),
0,
30*time.Minute,
)
s.Nil(err)

Expand All @@ -223,6 +225,7 @@ func (s *ProofSubmitterTestSuite) TestGetRandomBumpedSubmissionDelay() {
false,
1*time.Hour,
0,
30*time.Minute,
)
s.Nil(err)
delay, err = submitter2.getRandomBumpedSubmissionDelay(time.Now())
Expand Down
Loading

0 comments on commit c386589

Please sign in to comment.