diff --git a/README.md b/README.md index f2fc37d..85baa72 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,21 @@ _Tip: the current encryption key is already part of the decryption keys, no need Key rotation on existing fields (decrypt with old key and re-encrypt with the new one) is done by [data migrations](#migrations). +## Custom Prisma Client Location + +If you are generating your Prisma client to a custom location, you'll need to +tell the middleware where to look for the DMMF _(the internal AST generated by Prisma that we use to read those triple-slash comments)_: + +```ts +import { Prisma } from '../my/prisma/client' + +prismaClient.$use( + fieldEncryptionMiddleware({ + dmmf: Prisma.dmmf + }) +) +``` + **Roadmap:** - [x] Provide multiple decryption keys diff --git a/package.json b/package.json index f80962b..b1abfda 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,6 @@ "postbuild": "chmod +x ./dist/generator/main.js && cd node_modules/.bin && ln -sf ../../dist/generator/main.js ./prisma-field-encryption", "generate": "run-s generate:*", "generate:prisma": "prisma generate", - "generate:dmmf": "ts-node ./src/scripts/generateDMMF.ts", "test": "run-s test:**", "test:types": "tsc --noEmit", "test:unit": "jest --config jest.config.unit.json", @@ -41,7 +40,7 @@ "test:integration": "jest --config jest.config.integration.json --runInBand", "test:coverage:merge": "nyc merge ./coverage ./coverage/coverage-final.json", "test:coverage:report": "nyc report -t ./coverage --r html -r lcov -r clover", - "ci": "run-s generate build test", + "ci": "run-s build test", "prepare": "husky install", "premigrate": "run-s build generate", "migrate": "ts-node ./src/tests/migrate.ts" @@ -50,14 +49,15 @@ "@47ng/cloak": "^1.1.0-beta.2", "@prisma/generator-helper": "^3.13.0", "immer": "^9.0.12", - "object-path": "^0.11.8" + "object-path": "^0.11.8", + "zod": "^3.15.1" }, "peerDependencies": { "@prisma/client": "^3.8.0" }, "devDependencies": { "@commitlint/config-conventional": "^16.2.4", - "@prisma/client": "^3.13.0", + "@prisma/client": "3.13.0", "@prisma/sdk": "^3.13.0", "@types/jest": "^27.4.1", "@types/node": "^17.0.29", diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 1144f70..3b4689f 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -6,6 +6,7 @@ datasource db { generator client { provider = "prisma-client-js" previewFeatures = ["interactiveTransactions"] + output = "../src/tests/.generated/client" } // generator fieldEncryptionMigrations { diff --git a/src/dmmf.ts b/src/dmmf.ts index 205ee6f..252a149 100644 --- a/src/dmmf.ts +++ b/src/dmmf.ts @@ -1,6 +1,5 @@ -import { Prisma } from '@prisma/client' import { errors, warnings } from './errors' -import type { DMMF, FieldConfiguration } from './types' +import { DMMFDocument, dmmfDocumentParser, FieldConfiguration } from './types' export interface ConnectionDescriptor { modelName: string @@ -23,13 +22,8 @@ export type DMMFModels = Record // key: model name const supportedCursorTypes = ['Int', 'String'] -export function analyseDMMF(dmmf: DMMF = Prisma.dmmf): DMMFModels { - // todo: Make it robust against changes in the DMMF structure - // (can happen as it's an undocumented API) - // - Prisma.dmmf does not exist - // - Models are not located there, or empty -> warning - // - Model objects don't conform to what we need (parse with zod) - +export function analyseDMMF(input: DMMFDocument): DMMFModels { + const dmmf = dmmfDocumentParser.parse(input) const allModels = dmmf.datamodel.models return allModels.reduce((output, model) => { diff --git a/src/encryption.ts b/src/encryption.ts index dd25205..76dd2ff 100644 --- a/src/encryption.ts +++ b/src/encryption.ts @@ -59,8 +59,8 @@ const writeOperations = [ const whereClauseRegExp = /\.where\./ -export function encryptOnWrite( - params: MiddlewareParams, +export function encryptOnWrite( + params: MiddlewareParams, keys: KeysConfiguration, models: DMMFModels, operation: string @@ -71,42 +71,45 @@ export function encryptOnWrite( const encryptionErrors: string[] = [] - const mutatedParams = produce(params, (draft: Draft) => { - visitInputTargetFields( - draft, - models, - function encryptFieldValue({ - fieldConfig, - value: clearText, - path, - model, - field - }) { - if (!fieldConfig.encrypt) { - return - } - if (whereClauseRegExp.test(path)) { - console.warn(warnings.whereClause(operation, path)) - } - try { - const cipherText = encryptStringSync(clearText, keys.encryptionKey) - objectPath.set(draft.args, path, cipherText) - } catch (error) { - encryptionErrors.push( - errors.fieldEncryptionError(model, field, path, error) - ) + const mutatedParams = produce( + params, + (draft: Draft>) => { + visitInputTargetFields( + draft, + models, + function encryptFieldValue({ + fieldConfig, + value: clearText, + path, + model, + field + }) { + if (!fieldConfig.encrypt) { + return + } + if (whereClauseRegExp.test(path)) { + console.warn(warnings.whereClause(operation, path)) + } + try { + const cipherText = encryptStringSync(clearText, keys.encryptionKey) + objectPath.set(draft.args, path, cipherText) + } catch (error) { + encryptionErrors.push( + errors.fieldEncryptionError(model, field, path, error) + ) + } } - } - ) - }) + ) + } + ) if (encryptionErrors.length > 0) { throw new Error(errors.encryptionErrorReport(operation, encryptionErrors)) } return mutatedParams } -export function decryptOnRead( - params: MiddlewareParams, +export function decryptOnRead( + params: MiddlewareParams, result: any, keys: KeysConfiguration, models: DMMFModels, diff --git a/src/errors.ts b/src/errors.ts index 0d53af0..14791b1 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -1,4 +1,4 @@ -import type { Prisma } from '@prisma/client' +import type { DMMFField, DMMFModel } from './types' const header = '[prisma-field-encryption]' @@ -8,7 +8,7 @@ const prefixWarning = (input: string) => `${header} Warning: ${input}` export const errors = { // Setup errors noEncryptionKey: prefixError('no encryption key provided.'), - unsupportedFieldType: (model: Prisma.DMMF.Model, field: Prisma.DMMF.Field) => + unsupportedFieldType: (model: DMMFModel, field: DMMFField) => prefixError( `encryption enabled for field ${model.name}.${field.name} of unsupported type ${field.type}: only String fields can be encrypted.` ), diff --git a/src/generator/runtime/visitRecords.ts b/src/generator/runtime/visitRecords.ts index 5c33fcc..1cee881 100644 --- a/src/generator/runtime/visitRecords.ts +++ b/src/generator/runtime/visitRecords.ts @@ -1,26 +1,25 @@ -import type { PrismaClient } from '@prisma/client' import { defaultProgressReport, ProgressReportCallback } from './progressReport' -export type RecordVisitor = ( +export type RecordVisitor = ( client: PrismaClient, cursor: Cursor | undefined ) => Promise -export interface VisitRecordsArgs { +export interface VisitRecordsArgs { modelName: string client: PrismaClient getTotalCount: () => Promise - migrateRecord: RecordVisitor + migrateRecord: RecordVisitor reportProgress?: ProgressReportCallback } -export async function visitRecords({ +export async function visitRecords({ modelName, client, getTotalCount, migrateRecord, reportProgress = defaultProgressReport -}: VisitRecordsArgs) { +}: VisitRecordsArgs) { const totalCount = await getTotalCount() if (totalCount === 0) { return 0 diff --git a/src/index.ts b/src/index.ts index f4bb9e5..445a462 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,17 +2,20 @@ import { analyseDMMF } from './dmmf' import { configureKeys, decryptOnRead, encryptOnWrite } from './encryption' import type { Configuration, Middleware, MiddlewareParams } from './types' -export function fieldEncryptionMiddleware( - config: Configuration = {} -): Middleware { +export function fieldEncryptionMiddleware< + Models extends string = any, + Actions extends string = any +>(config: Configuration = {}): Middleware { // This will throw if the encryption key is missing // or if anything is invalid. const keys = configureKeys(config) - const models = analyseDMMF() + const models = analyseDMMF( + config.dmmf ?? require('@prisma/client').Prisma.dmmf + ) return async function fieldEncryptionMiddleware( - params: MiddlewareParams, - next: (params: MiddlewareParams) => Promise + params: MiddlewareParams, + next: (params: MiddlewareParams) => Promise ) { if (!params.model) { // Unsupported operation diff --git a/src/scripts/generateDMMF.ts b/src/scripts/generateDMMF.ts deleted file mode 100644 index 1d5d125..0000000 --- a/src/scripts/generateDMMF.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { Prisma } from '.prisma/client' -import fs from 'node:fs/promises' -import path from 'node:path' - -export async function generateDMMF() { - const outputPath = path.resolve(__dirname, '../../prisma/dmmf.json') - await fs.writeFile(outputPath, JSON.stringify(Prisma.dmmf, null, 2)) -} - -if (require.main === module) { - generateDMMF() -} diff --git a/src/tests/.generated/.gitignore b/src/tests/.generated/.gitignore new file mode 100644 index 0000000..684bec4 --- /dev/null +++ b/src/tests/.generated/.gitignore @@ -0,0 +1 @@ +client/ diff --git a/src/tests/prismaClient.ts b/src/tests/prismaClient.ts index d71f862..56e83cd 100644 --- a/src/tests/prismaClient.ts +++ b/src/tests/prismaClient.ts @@ -1,5 +1,5 @@ -import { PrismaClient } from '@prisma/client' import { fieldEncryptionMiddleware } from '../index' +import { Prisma, PrismaClient } from './.generated/client' export const TEST_ENCRYPTION_KEY = 'k1.aesgcm256.OsqVmAOZBB_WW3073q1wU4ag0ap0ETYAYMh041RuxuI=' @@ -33,7 +33,8 @@ client.$use(async (params, next) => { client.$use( fieldEncryptionMiddleware({ - encryptionKey: TEST_ENCRYPTION_KEY + encryptionKey: TEST_ENCRYPTION_KEY, + dmmf: Prisma.dmmf }) ) diff --git a/src/types.ts b/src/types.ts index b6a641d..d83ea2d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,16 +1,66 @@ -import { Prisma } from '@prisma/client' +/** + * Prisma types -- + * + * We're copying just what we need for local type safety + * without importing Prisma-generated types, as the location + * of the generated client can be unknown (when using custom + * or multiple client locations). + */ -// Prisma types -- +import { z } from 'zod' -export type MiddlewareParams = Prisma.MiddlewareParams -export type Middleware = Prisma.Middleware -export type DMMF = typeof Prisma.dmmf +/** + * Not ideal to use `any` on model & action, but Prisma's + * strong typing there actually prevents using the correct + * type without excessive generics wizardry. + */ +export type MiddlewareParams = { + model?: Models + action: Actions + args: any + dataPath: string[] + runInTransaction: boolean +} + +export type Middleware< + Models extends string, + Actions extends string, + Result = any +> = ( + params: MiddlewareParams, + next: (params: MiddlewareParams) => Promise +) => Promise + +const dmmfFieldParser = z.object({ + name: z.string(), + isList: z.boolean(), + isUnique: z.boolean(), + isId: z.boolean(), + type: z.string(), + documentation: z.string().optional() +}) + +const dmmfModelParser = z.object({ + name: z.string(), + fields: z.array(dmmfFieldParser) +}) + +export const dmmfDocumentParser = z.object({ + datamodel: z.object({ + models: z.array(dmmfModelParser) + }) +}) + +export type DMMFModel = z.TypeOf +export type DMMFField = z.TypeOf +export type DMMFDocument = z.TypeOf // Internal types -- export interface Configuration { encryptionKey?: string decryptionKeys?: string[] + dmmf?: DMMFDocument } export interface FieldConfiguration { diff --git a/src/visitor.test.ts b/src/visitor.test.ts index 1ebdc25..24faa81 100644 --- a/src/visitor.test.ts +++ b/src/visitor.test.ts @@ -36,7 +36,7 @@ describe('visitor', () => { test('visitInputTargetFields - simple example', async () => { const models = analyseDMMF(await dmmf) - const params: MiddlewareParams = { + const params: MiddlewareParams = { action: 'create', model: 'User', args: { @@ -63,7 +63,7 @@ describe('visitor', () => { test('visitInputTargetFields - nested create', async () => { const models = analyseDMMF(await dmmf) - const params: MiddlewareParams = { + const params: MiddlewareParams = { action: 'create', model: 'User', args: { diff --git a/src/visitor.ts b/src/visitor.ts index 7432b4c..1dd9281 100644 --- a/src/visitor.ts +++ b/src/visitor.ts @@ -60,8 +60,11 @@ const makeVisitor = (models: DMMFModels, visitor: TargetFieldVisitorFn) => return state } -export function visitInputTargetFields( - params: MiddlewareParams, +export function visitInputTargetFields< + Models extends string, + Actions extends string +>( + params: MiddlewareParams, models: DMMFModels, visitor: TargetFieldVisitorFn ) { @@ -70,8 +73,11 @@ export function visitInputTargetFields( }) } -export function visitOutputTargetFields( - params: MiddlewareParams, +export function visitOutputTargetFields< + Models extends string, + Actions extends string +>( + params: MiddlewareParams, result: any, models: DMMFModels, visitor: TargetFieldVisitorFn diff --git a/yarn.lock b/yarn.lock index 6e33364..8172064 100644 --- a/yarn.lock +++ b/yarn.lock @@ -745,7 +745,7 @@ mkdirp "^1.0.4" rimraf "^3.0.2" -"@prisma/client@^3.13.0": +"@prisma/client@3.13.0": version "3.13.0" resolved "https://registry.yarnpkg.com/@prisma/client/-/client-3.13.0.tgz#84511ebdf6ba75f77ca08495b9f73f22c4255654" integrity sha512-lnEA2tTyVbO5mS1ehmHJQKBDiKB8shaR6s3azwj3Azfi5XHIfnqmkolLCvUeFYnkDCNVzGXJpUgKwQt/UOOYVQ== @@ -5406,3 +5406,8 @@ zip-stream@^4.1.0: archiver-utils "^2.1.0" compress-commons "^4.1.0" readable-stream "^3.6.0" + +zod@^3.15.1: + version "3.15.1" + resolved "https://registry.yarnpkg.com/zod/-/zod-3.15.1.tgz#9e404cd8002ccffb03baa94cff2e1638ed49d82f" + integrity sha512-WAdjcoOxa4S9oc/u7fTbC3CC7uVqptLLU0LKqS8RDBOrCXp2t5avM8BUfgNVZJymGWAx6SEUYxWPPoYuQ5rgwQ==