Skip to content

Commit

Permalink
feat: add function to manually check rate limit (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
Charioteer committed Nov 10, 2024
1 parent 3b8431c commit e22b644
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 66 deletions.
152 changes: 101 additions & 51 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ async function fastifyRateLimit (fastify, settings) {

fastify.decorateRequest(pluginComponent.rateLimitRan, false)

if (!fastify.hasDecorator('createRateLimit')) {
fastify.decorate('createRateLimit', (options) => {
const args = createLimiterArgs(pluginComponent, globalParams, options)
return (req) => applyRateLimit(...args, req)
})
}

if (!fastify.hasDecorator('rateLimit')) {
fastify.decorate('rateLimit', (options) => {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return rateLimitRequestHandler(newPluginComponent, mergedRateLimitParams)
}

return rateLimitRequestHandler(pluginComponent, globalParams)
const args = createLimiterArgs(pluginComponent, globalParams, options)
return rateLimitRequestHandler(...args)
})
}

Expand Down Expand Up @@ -189,6 +190,17 @@ function mergeParams (...params) {
return result
}

function createLimiterArgs (pluginComponent, globalParams, options) {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return [newPluginComponent, mergedRateLimitParams]
}

return [pluginComponent, globalParams]
}

function addRouteRateHook (pluginComponent, params, routeOptions) {
const hook = params.hook
const hookHandler = rateLimitRequestHandler(pluginComponent, params)
Expand All @@ -201,8 +213,72 @@ function addRouteRateHook (pluginComponent, params, routeOptions) {
}
}

async function applyRateLimit (pluginComponent, params, req) {
const { store } = pluginComponent

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId

if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return {
isAllowed: true,
key
}
}
} else if (params.allowList.indexOf(key) !== -1) {
return {
isAllowed: true,
key
}
}
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}

return {
isAllowed: false,
key,
max,
timeWindow,
remaining: Math.max(0, max - current),
ttl,
ttlInSeconds,
isExceeded: current > max,
isBanned: params.ban !== -1 && current - max > params.ban
}
}

function rateLimitRequestHandler (pluginComponent, params) {
const { rateLimitRan, store } = pluginComponent
const { rateLimitRan } = pluginComponent

let timeWindowString
if (typeof params.timeWindow === 'number') {
Expand All @@ -216,51 +292,25 @@ function rateLimitRequestHandler (pluginComponent, params) {

req[rateLimitRan] = true

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId

if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return
}
} else if (params.allowList.indexOf(key) !== -1) {
return
}
const rateLimit = await applyRateLimit(pluginComponent, params, req)
if (rateLimit.isAllowed) {
return
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}
const {
key,
max,
timeWindow,
remaining,
ttl,
ttlInSeconds,
isExceeded,
isBanned
} = rateLimit

if (current <= max) {
if (!isExceeded) {
if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, max - current) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) }
if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) }

params.onExceeding(req, key)
Expand All @@ -283,7 +333,7 @@ function rateLimitRequestHandler (pluginComponent, params) {
after: timeWindowString ?? ms.format(timeWindow, true)
}

if (params.ban !== -1 && current - max > params.ban) {
if (isBanned) {
respCtx.statusCode = 403
respCtx.ban = true
params.onBanReach(req, key)
Expand Down
51 changes: 36 additions & 15 deletions types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ import {

declare module 'fastify' {
interface FastifyInstance<RawServer, RawRequest, RawReply, Logger, TypeProvider> {
createRateLimit(options?: fastifyRateLimit.CreateRateLimitOptions): (req: FastifyRequest) => Promise<
| {
isAllowed: true
key: string
}
| {
isAllowed: false
key: string
max: number
timeWindow: number
remaining: number
ttl: number
ttlInSeconds: number
isExceeded: boolean
isBanned: boolean
}
>

rateLimit<
RouteGeneric extends RouteGenericInterface = RouteGenericInterface,
ContextConfig = ContextConfigDefault,
Expand Down Expand Up @@ -89,13 +107,9 @@ declare namespace fastifyRateLimit {
'ratelimit-reset'?: boolean;
}

export type RateLimitHook =
| 'onRequest'
| 'preParsing'
| 'preValidation'
| 'preHandler'

export interface RateLimitOptions {
export interface CreateRateLimitOptions {
store?: FastifyRateLimitStoreCtor;
skipOnError?: boolean;
max?:
| number
| ((req: FastifyRequest, key: string) => number)
Expand All @@ -105,19 +119,26 @@ declare namespace fastifyRateLimit {
| string
| ((req: FastifyRequest, key: string) => number)
| ((req: FastifyRequest, key: string) => Promise<number>);
hook?: RateLimitHook;
cache?: number;
store?: FastifyRateLimitStoreCtor;
/**
* @deprecated Use `allowList` property
*/
* @deprecated Use `allowList` property
*/
whitelist?: string[] | ((req: FastifyRequest, key: string) => boolean);
allowList?: string[] | ((req: FastifyRequest, key: string) => boolean | Promise<boolean>);
continueExceeding?: boolean;
skipOnError?: boolean;
keyGenerator?: (req: FastifyRequest) => string | number | Promise<string | number>;
ban?: number;
}

export type RateLimitHook =
| 'onRequest'
| 'preParsing'
| 'preValidation'
| 'preHandler'

export interface RateLimitOptions extends CreateRateLimitOptions {
hook?: RateLimitHook;
cache?: number;
continueExceeding?: boolean;
onBanReach?: (req: FastifyRequest, key: string) => void;
keyGenerator?: (req: FastifyRequest) => string | number | Promise<string | number>;
groupId?: string;
errorResponseBuilder?: (
req: FastifyRequest,
Expand Down

0 comments on commit e22b644

Please sign in to comment.