Skip to content

Commit

Permalink
feat: Allow custom client location
Browse files Browse the repository at this point in the history
This breaks the type & runtime dependency on a hardcoded
`@prisma/client` client output location, allowing
custom location (eg when having multiple clients).

This adds a layer of runtime safety to validate the DMMF,
which can now be passed in the configuration,
allowing all sorts of runtime hacks.

See #18 & #19.
  • Loading branch information
franky47 committed May 10, 2022
1 parent 27da08f commit 0860162
Show file tree
Hide file tree
Showing 15 changed files with 150 additions and 84 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,21 @@ _Tip: the current encryption key is already part of the decryption keys, no need
Key rotation on existing fields (decrypt with old key and re-encrypt with the
new one) is done by [data migrations](#migrations).

## Custom Prisma Client Location

If you are generating your Prisma client to a custom location, you'll need to
tell the middleware where to look for the DMMF _(the internal AST generated by Prisma that we use to read those triple-slash comments)_:

```ts
import { Prisma } from '../my/prisma/client'
prismaClient.$use(
fieldEncryptionMiddleware({
dmmf: Prisma.dmmf
})
)
```

**Roadmap:**

- [x] Provide multiple decryption keys
Expand Down
8 changes: 4 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
"postbuild": "chmod +x ./dist/generator/main.js && cd node_modules/.bin && ln -sf ../../dist/generator/main.js ./prisma-field-encryption",
"generate": "run-s generate:*",
"generate:prisma": "prisma generate",
"generate:dmmf": "ts-node ./src/scripts/generateDMMF.ts",
"test": "run-s test:**",
"test:types": "tsc --noEmit",
"test:unit": "jest --config jest.config.unit.json",
"pretest:integration": "cp -f ./prisma/db.test.sqlite ./prisma/db.integration.sqlite",
"test:integration": "jest --config jest.config.integration.json --runInBand",
"test:coverage:merge": "nyc merge ./coverage ./coverage/coverage-final.json",
"test:coverage:report": "nyc report -t ./coverage --r html -r lcov -r clover",
"ci": "run-s generate build test",
"ci": "run-s build test",
"prepare": "husky install",
"premigrate": "run-s build generate",
"migrate": "ts-node ./src/tests/migrate.ts"
Expand All @@ -50,14 +49,15 @@
"@47ng/cloak": "^1.1.0-beta.2",
"@prisma/generator-helper": "^3.13.0",
"immer": "^9.0.12",
"object-path": "^0.11.8"
"object-path": "^0.11.8",
"zod": "^3.15.1"
},
"peerDependencies": {
"@prisma/client": "^3.8.0"
},
"devDependencies": {
"@commitlint/config-conventional": "^16.2.4",
"@prisma/client": "^3.13.0",
"@prisma/client": "3.13.0",
"@prisma/sdk": "^3.13.0",
"@types/jest": "^27.4.1",
"@types/node": "^17.0.29",
Expand Down
1 change: 1 addition & 0 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ datasource db {
generator client {
provider = "prisma-client-js"
previewFeatures = ["interactiveTransactions"]
output = "../src/tests/.generated/client"
}

// generator fieldEncryptionMigrations {
Expand Down
12 changes: 3 additions & 9 deletions src/dmmf.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Prisma } from '@prisma/client'
import { errors, warnings } from './errors'
import type { DMMF, FieldConfiguration } from './types'
import { DMMFDocument, dmmfDocumentParser, FieldConfiguration } from './types'

export interface ConnectionDescriptor {
modelName: string
Expand All @@ -23,13 +22,8 @@ export type DMMFModels = Record<string, DMMFModelDescriptor> // key: model name

const supportedCursorTypes = ['Int', 'String']

export function analyseDMMF(dmmf: DMMF = Prisma.dmmf): DMMFModels {
// todo: Make it robust against changes in the DMMF structure
// (can happen as it's an undocumented API)
// - Prisma.dmmf does not exist
// - Models are not located there, or empty -> warning
// - Model objects don't conform to what we need (parse with zod)

export function analyseDMMF(input: DMMFDocument): DMMFModels {
const dmmf = dmmfDocumentParser.parse(input)
const allModels = dmmf.datamodel.models

return allModels.reduce<DMMFModels>((output, model) => {
Expand Down
65 changes: 34 additions & 31 deletions src/encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ const writeOperations = [

const whereClauseRegExp = /\.where\./

export function encryptOnWrite(
params: MiddlewareParams,
export function encryptOnWrite<Models extends string, Actions extends string>(
params: MiddlewareParams<Models, Actions>,
keys: KeysConfiguration,
models: DMMFModels,
operation: string
Expand All @@ -71,42 +71,45 @@ export function encryptOnWrite(

const encryptionErrors: string[] = []

const mutatedParams = produce(params, (draft: Draft<MiddlewareParams>) => {
visitInputTargetFields(
draft,
models,
function encryptFieldValue({
fieldConfig,
value: clearText,
path,
model,
field
}) {
if (!fieldConfig.encrypt) {
return
}
if (whereClauseRegExp.test(path)) {
console.warn(warnings.whereClause(operation, path))
}
try {
const cipherText = encryptStringSync(clearText, keys.encryptionKey)
objectPath.set(draft.args, path, cipherText)
} catch (error) {
encryptionErrors.push(
errors.fieldEncryptionError(model, field, path, error)
)
const mutatedParams = produce(
params,
(draft: Draft<MiddlewareParams<Models, Actions>>) => {
visitInputTargetFields(
draft,
models,
function encryptFieldValue({
fieldConfig,
value: clearText,
path,
model,
field
}) {
if (!fieldConfig.encrypt) {
return
}
if (whereClauseRegExp.test(path)) {
console.warn(warnings.whereClause(operation, path))
}
try {
const cipherText = encryptStringSync(clearText, keys.encryptionKey)
objectPath.set(draft.args, path, cipherText)
} catch (error) {
encryptionErrors.push(
errors.fieldEncryptionError(model, field, path, error)
)
}
}
}
)
})
)
}
)
if (encryptionErrors.length > 0) {
throw new Error(errors.encryptionErrorReport(operation, encryptionErrors))
}
return mutatedParams
}

export function decryptOnRead(
params: MiddlewareParams,
export function decryptOnRead<Models extends string, Actions extends string>(
params: MiddlewareParams<Models, Actions>,
result: any,
keys: KeysConfiguration,
models: DMMFModels,
Expand Down
4 changes: 2 additions & 2 deletions src/errors.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Prisma } from '@prisma/client'
import type { DMMFField, DMMFModel } from './types'

const header = '[prisma-field-encryption]'

Expand All @@ -8,7 +8,7 @@ const prefixWarning = (input: string) => `${header} Warning: ${input}`
export const errors = {
// Setup errors
noEncryptionKey: prefixError('no encryption key provided.'),
unsupportedFieldType: (model: Prisma.DMMF.Model, field: Prisma.DMMF.Field) =>
unsupportedFieldType: (model: DMMFModel, field: DMMFField) =>
prefixError(
`encryption enabled for field ${model.name}.${field.name} of unsupported type ${field.type}: only String fields can be encrypted.`
),
Expand Down
11 changes: 5 additions & 6 deletions src/generator/runtime/visitRecords.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import type { PrismaClient } from '@prisma/client'
import { defaultProgressReport, ProgressReportCallback } from './progressReport'

export type RecordVisitor<Cursor> = (
export type RecordVisitor<PrismaClient, Cursor> = (
client: PrismaClient,
cursor: Cursor | undefined
) => Promise<Cursor | undefined>

export interface VisitRecordsArgs<Cursor> {
export interface VisitRecordsArgs<PrismaClient, Cursor> {
modelName: string
client: PrismaClient
getTotalCount: () => Promise<number>
migrateRecord: RecordVisitor<Cursor>
migrateRecord: RecordVisitor<PrismaClient, Cursor>
reportProgress?: ProgressReportCallback
}

export async function visitRecords<Cursor>({
export async function visitRecords<PrismaClient, Cursor>({
modelName,
client,
getTotalCount,
migrateRecord,
reportProgress = defaultProgressReport
}: VisitRecordsArgs<Cursor>) {
}: VisitRecordsArgs<PrismaClient, Cursor>) {
const totalCount = await getTotalCount()
if (totalCount === 0) {
return 0
Expand Down
15 changes: 9 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ import { analyseDMMF } from './dmmf'
import { configureKeys, decryptOnRead, encryptOnWrite } from './encryption'
import type { Configuration, Middleware, MiddlewareParams } from './types'

export function fieldEncryptionMiddleware(
config: Configuration = {}
): Middleware {
export function fieldEncryptionMiddleware<
Models extends string = any,
Actions extends string = any
>(config: Configuration = {}): Middleware<Models, Actions> {
// This will throw if the encryption key is missing
// or if anything is invalid.
const keys = configureKeys(config)
const models = analyseDMMF()
const models = analyseDMMF(
config.dmmf ?? require('@prisma/client').Prisma.dmmf
)

return async function fieldEncryptionMiddleware(
params: MiddlewareParams,
next: (params: MiddlewareParams) => Promise<any>
params: MiddlewareParams<Models, Actions>,
next: (params: MiddlewareParams<Models, Actions>) => Promise<any>
) {
if (!params.model) {
// Unsupported operation
Expand Down
12 changes: 0 additions & 12 deletions src/scripts/generateDMMF.ts

This file was deleted.

1 change: 1 addition & 0 deletions src/tests/.generated/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
client/
5 changes: 3 additions & 2 deletions src/tests/prismaClient.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { PrismaClient } from '@prisma/client'
import { fieldEncryptionMiddleware } from '../index'
import { Prisma, PrismaClient } from './.generated/client'

export const TEST_ENCRYPTION_KEY =
'k1.aesgcm256.OsqVmAOZBB_WW3073q1wU4ag0ap0ETYAYMh041RuxuI='
Expand Down Expand Up @@ -33,7 +33,8 @@ client.$use(async (params, next) => {

client.$use(
fieldEncryptionMiddleware({
encryptionKey: TEST_ENCRYPTION_KEY
encryptionKey: TEST_ENCRYPTION_KEY,
dmmf: Prisma.dmmf
})
)

Expand Down
60 changes: 55 additions & 5 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,66 @@
import { Prisma } from '@prisma/client'
/**
* Prisma types --
*
* We're copying just what we need for local type safety
* without importing Prisma-generated types, as the location
* of the generated client can be unknown (when using custom
* or multiple client locations).
*/

// Prisma types --
import { z } from 'zod'

export type MiddlewareParams = Prisma.MiddlewareParams
export type Middleware = Prisma.Middleware
export type DMMF = typeof Prisma.dmmf
/**
* Not ideal to use `any` on model & action, but Prisma's
* strong typing there actually prevents using the correct
* type without excessive generics wizardry.
*/
export type MiddlewareParams<Models extends string, Actions extends string> = {
model?: Models
action: Actions
args: any
dataPath: string[]
runInTransaction: boolean
}

export type Middleware<
Models extends string,
Actions extends string,
Result = any
> = (
params: MiddlewareParams<Models, Actions>,
next: (params: MiddlewareParams<Models, Actions>) => Promise<Result>
) => Promise<Result>

const dmmfFieldParser = z.object({
name: z.string(),
isList: z.boolean(),
isUnique: z.boolean(),
isId: z.boolean(),
type: z.string(),
documentation: z.string().optional()
})

const dmmfModelParser = z.object({
name: z.string(),
fields: z.array(dmmfFieldParser)
})

export const dmmfDocumentParser = z.object({
datamodel: z.object({
models: z.array(dmmfModelParser)
})
})

export type DMMFModel = z.TypeOf<typeof dmmfModelParser>
export type DMMFField = z.TypeOf<typeof dmmfFieldParser>
export type DMMFDocument = z.TypeOf<typeof dmmfDocumentParser>

// Internal types --

export interface Configuration {
encryptionKey?: string
decryptionKeys?: string[]
dmmf?: DMMFDocument
}

export interface FieldConfiguration {
Expand Down
4 changes: 2 additions & 2 deletions src/visitor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe('visitor', () => {

test('visitInputTargetFields - simple example', async () => {
const models = analyseDMMF(await dmmf)
const params: MiddlewareParams = {
const params: MiddlewareParams<any, any> = {
action: 'create',
model: 'User',
args: {
Expand All @@ -63,7 +63,7 @@ describe('visitor', () => {

test('visitInputTargetFields - nested create', async () => {
const models = analyseDMMF(await dmmf)
const params: MiddlewareParams = {
const params: MiddlewareParams<any, any> = {
action: 'create',
model: 'User',
args: {
Expand Down
14 changes: 10 additions & 4 deletions src/visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ const makeVisitor = (models: DMMFModels, visitor: TargetFieldVisitorFn) =>
return state
}

export function visitInputTargetFields(
params: MiddlewareParams,
export function visitInputTargetFields<
Models extends string,
Actions extends string
>(
params: MiddlewareParams<Models, Actions>,
models: DMMFModels,
visitor: TargetFieldVisitorFn
) {
Expand All @@ -70,8 +73,11 @@ export function visitInputTargetFields(
})
}

export function visitOutputTargetFields(
params: MiddlewareParams,
export function visitOutputTargetFields<
Models extends string,
Actions extends string
>(
params: MiddlewareParams<Models, Actions>,
result: any,
models: DMMFModels,
visitor: TargetFieldVisitorFn
Expand Down
Loading

0 comments on commit 0860162

Please sign in to comment.