diff --git a/lib/src/openApiToZod.ts b/lib/src/openApiToZod.ts index ae50e0d..5e00538 100644 --- a/lib/src/openApiToZod.ts +++ b/lib/src/openApiToZod.ts @@ -3,10 +3,10 @@ import { match } from "ts-pattern"; import type { CodeMetaData, ConversionTypeContext } from "./CodeMeta"; import { CodeMeta } from "./CodeMeta"; +import { inferRequiredSchema } from "./inferRequiredOnly"; import { isReferenceObject } from "./isReferenceObject"; import type { TemplateContext } from "./template-context"; import { escapeControlCharacters, isPrimitiveType, wrapWithQuotesIfNeeded } from "./utils"; -import { inferRequiredSchema } from "./inferRequiredOnly"; type ConversionArgs = { schema: SchemaObject | ReferenceObject; @@ -19,6 +19,7 @@ type ConversionArgs = { * @see https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#schemaObject * @see https://github.com/colinhacks/zod */ +let circularRef = ""; // eslint-disable-next-line sonarjs/cognitive-complexity export function getZodSchema({ schema: $schema, ctx, meta: inheritedMeta, options }: ConversionArgs): CodeMeta { if (!$schema) { @@ -43,6 +44,8 @@ export function getZodSchema({ schema: $schema, ctx, meta: inheritedMeta, option // circular(=recursive) reference if (refsPath.length > 1 && refsPath.includes(schemaName)) { + circularRef = code.ref!; + return code.assign(ctx.zodSchemaByName[code.ref!]!); } @@ -93,10 +96,17 @@ export function getZodSchema({ schema: $schema, ctx, meta: inheritedMeta, option if (schema.discriminator && !hasMultipleAllOf) { const propertyName = schema.discriminator.propertyName; + const discriminatedUnionOptions = schema.oneOf.map((prop) => + getZodSchema({ schema: prop, ctx, meta, options }) + ); + const isCircularDiscriminatedUnion = discriminatedUnionOptions.some((t) => t.ref === circularRef); + + if (isCircularDiscriminatedUnion) { + return code.assign(`z.union([${discriminatedUnionOptions.join(", ")}])`); + } + return code.assign(` - z.discriminatedUnion("${propertyName}", [${schema.oneOf - .map((prop) => getZodSchema({ schema: prop, ctx, meta, options })) - .join(", ")}]) + z.discriminatedUnion("${propertyName}", [${discriminatedUnionOptions.join(", ")}]) `); } @@ -138,6 +148,7 @@ export function getZodSchema({ schema: $schema, ctx, meta: inheritedMeta, option const type = getZodSchema({ schema: schema.allOf[0]!, ctx, meta, options }); return code.assign(type.toString()); } + const { patchRequiredSchemaInLoop, noRequiredOnlyAllof, composedRequiredSchema } = inferRequiredSchema(schema); const types = noRequiredOnlyAllof.map((prop) => { @@ -213,15 +224,11 @@ export function getZodSchema({ schema: $schema, ctx, meta: inheritedMeta, option if (schemaType === "array") { if (schema.items) { return code.assign( - `z.array(${ - getZodSchema({ schema: schema.items, ctx, meta, options }).toString() - }${ - getZodChain({ - schema: schema.items as SchemaObject, - meta: { ...meta, isRequired: true }, - options, - }) - })${readonly}` + `z.array(${getZodSchema({ schema: schema.items, ctx, meta, options }).toString()}${getZodChain({ + schema: schema.items as SchemaObject, + meta: { ...meta, isRequired: true }, + options, + })})${readonly}` ); } diff --git a/lib/src/template-context.ts b/lib/src/template-context.ts index 48a61ea..95f2d52 100644 --- a/lib/src/template-context.ts +++ b/lib/src/template-context.ts @@ -3,6 +3,7 @@ import { sortBy, sortListFromRefArray, sortObjKeysFromArray } from "pastable/ser import { ts } from "tanu"; import { match } from "ts-pattern"; +import type { CodeMetaData } from "./CodeMeta"; import { getOpenApiDependencyGraph } from "./getOpenApiDependencyGraph"; import type { EndpointDefinitionWithRefs } from "./getZodiosEndpointDefinitionList"; import { getZodiosEndpointDefinitionList } from "./getZodiosEndpointDefinitionList"; @@ -11,7 +12,6 @@ import { getTypescriptFromOpenApi } from "./openApiToTypescript"; import { getZodSchema } from "./openApiToZod"; import { topologicalSort } from "./topologicalSort"; import { asComponentSchema, normalizeString } from "./utils"; -import type { CodeMetaData } from "./CodeMeta"; const file = ts.createSourceFile("", "", ts.ScriptTarget.ESNext, true); const printer = ts.createPrinter({ newLine: ts.NewLineKind.LineFeed });