Skip to content

Commit

Permalink
Add array type checks
Browse files Browse the repository at this point in the history
With various fixes

Closes npgsql#5137
  • Loading branch information
roji committed Sep 14, 2023
1 parent 349123e commit 4e854c0
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 79 deletions.
10 changes: 8 additions & 2 deletions src/Npgsql/Internal/Resolvers/AdoTypeInfoResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,21 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings)
if (Statics.LegacyTimestampBehavior)
{
mappings.AddStructArrayType<DateTime>(DataTypeNames.TimestampTz);
mappings.AddStructArrayType<DateTimeOffset>(DataTypeNames.TimestampTz);
}
else
{
mappings.AddResolverStructArrayType<DateTime>(DataTypeNames.TimestampTz);
mappings.AddResolverStructArrayType<DateTimeOffset>(DataTypeNames.TimestampTz);
}
mappings.AddStructArrayType<DateTimeOffset>(DataTypeNames.TimestampTz);
mappings.AddStructArrayType<long>(DataTypeNames.TimestampTz);

// Date
mappings.AddStructArrayType<DateTime>(DataTypeNames.Date);
mappings.AddStructArrayType<int>(DataTypeNames.Date);
#if NET6_0_OR_GREATER
mappings.AddStructArrayType<DateOnly>(DataTypeNames.Date);
#endif

// Time
mappings.AddStructArrayType<TimeSpan>(DataTypeNames.Time);
mappings.AddStructArrayType<long>(DataTypeNames.Time);
Expand Down
9 changes: 9 additions & 0 deletions src/Npgsql/Internal/Resolvers/ExtraConversionsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,27 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings)
mappings.AddStructArrayType<long>(DataTypeNames.Int2);
mappings.AddStructArrayType<byte>(DataTypeNames.Int2);
mappings.AddStructArrayType<sbyte>(DataTypeNames.Int2);
mappings.AddStructArrayType<float>(DataTypeNames.Int2);
mappings.AddStructArrayType<double>(DataTypeNames.Int2);
mappings.AddStructArrayType<decimal>(DataTypeNames.Int2);

// Int4
mappings.AddStructArrayType<short>(DataTypeNames.Int4);
mappings.AddStructArrayType<long>(DataTypeNames.Int4);
mappings.AddStructArrayType<byte>(DataTypeNames.Int4);
mappings.AddStructArrayType<sbyte>(DataTypeNames.Int4);
mappings.AddStructArrayType<float>(DataTypeNames.Int4);
mappings.AddStructArrayType<double>(DataTypeNames.Int4);
mappings.AddStructArrayType<decimal>(DataTypeNames.Int4);

// Int8
mappings.AddStructArrayType<short>(DataTypeNames.Int8);
mappings.AddStructArrayType<int>(DataTypeNames.Int8);
mappings.AddStructArrayType<byte>(DataTypeNames.Int8);
mappings.AddStructArrayType<sbyte>(DataTypeNames.Int8);
mappings.AddStructArrayType<float>(DataTypeNames.Int8);
mappings.AddStructArrayType<double>(DataTypeNames.Int8);
mappings.AddStructArrayType<decimal>(DataTypeNames.Int8);

// Float4
mappings.AddStructArrayType<double>(DataTypeNames.Float4);
Expand Down
6 changes: 3 additions & 3 deletions src/Npgsql/Internal/Resolvers/LTreeTypeInfoResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ static void AddInfos(TypeInfoMappingCollection mappings)
{
mappings.AddType<string>("ltree",
static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter<string>(LTreeVersion, new StringTextConverter(options.TextEncoding))),
isDefault: true);
MatchRequirement.DataTypeName);
mappings.AddType<string>("lquery",
static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter<string>(LTreeVersion, new StringTextConverter(options.TextEncoding))),
isDefault: true);
MatchRequirement.DataTypeName);
mappings.AddType<string>("ltxtquery",
static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter<string>(LTreeVersion, new StringTextConverter(options.TextEncoding))),
isDefault: true);
MatchRequirement.DataTypeName);
}

static void AddArrayInfos(TypeInfoMappingCollection mappings)
Expand Down
32 changes: 16 additions & 16 deletions src/Npgsql/Internal/Resolvers/RangeTypeInfoResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,6 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool sup
mappings.AddStructArrayType<NpgsqlRange<decimal>>(DataTypeNames.NumRange);
mappings.AddStructArrayType<NpgsqlRange<BigInteger>>(DataTypeNames.NumRange);

// daterange
mappings.AddStructArrayType<NpgsqlRange<DateTime>>(DataTypeNames.DateRange);
mappings.AddStructArrayType<NpgsqlRange<int>>(DataTypeNames.DateRange);
#if NET6_0_OR_GREATER
mappings.AddStructArrayType<NpgsqlRange<DateOnly>>(DataTypeNames.DateRange);
#endif

// tsrange
if (Statics.LegacyTimestampBehavior)
mappings.AddStructArrayType<NpgsqlRange<DateTime>>(DataTypeNames.TsRange);
Expand All @@ -284,6 +277,13 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool sup
}
mappings.AddStructArrayType<NpgsqlRange<long>>(DataTypeNames.TsTzRange);

// daterange
mappings.AddStructArrayType<NpgsqlRange<DateTime>>(DataTypeNames.DateRange);
mappings.AddStructArrayType<NpgsqlRange<int>>(DataTypeNames.DateRange);
#if NET6_0_OR_GREATER
mappings.AddStructArrayType<NpgsqlRange<DateOnly>>(DataTypeNames.DateRange);
#endif

if (supportsMultiRange)
{
// int4multirange
Expand All @@ -298,14 +298,6 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool sup
mappings.AddArrayType<NpgsqlRange<decimal>[]>(DataTypeNames.NumMultirange);
mappings.AddArrayType<List<NpgsqlRange<decimal>>>(DataTypeNames.NumMultirange);

// datemultirange
mappings.AddArrayType<NpgsqlRange<DateTime>[]>(DataTypeNames.DateMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateTime>>>(DataTypeNames.DateMultirange);
#if NET6_0_OR_GREATER
mappings.AddArrayType<NpgsqlRange<DateOnly>[]>(DataTypeNames.DateMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateOnly>>>(DataTypeNames.DateMultirange);
#endif

// tsmultirange
if (Statics.LegacyTimestampBehavior)
{
Expand All @@ -325,7 +317,7 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool sup
{
mappings.AddArrayType<NpgsqlRange<DateTime>[]>(DataTypeNames.TsTzMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateTime>>>(DataTypeNames.TsTzMultirange);
mappings.AddArrayType<NpgsqlRange<DateTime>[]>(DataTypeNames.TsTzMultirange);
mappings.AddArrayType<NpgsqlRange<DateTimeOffset>[]>(DataTypeNames.TsTzMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateTimeOffset>>>(DataTypeNames.TsTzMultirange);
}
else
Expand All @@ -337,6 +329,14 @@ protected static void AddArrayInfos(TypeInfoMappingCollection mappings, bool sup
}
mappings.AddArrayType<NpgsqlRange<long>[]>(DataTypeNames.TsTzMultirange);
mappings.AddArrayType<List<NpgsqlRange<long>>>(DataTypeNames.TsTzMultirange);

// datemultirange
mappings.AddArrayType<NpgsqlRange<DateTime>[]>(DataTypeNames.DateMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateTime>>>(DataTypeNames.DateMultirange);
#if NET6_0_OR_GREATER
mappings.AddArrayType<NpgsqlRange<DateOnly>[]>(DataTypeNames.DateMultirange);
mappings.AddArrayType<List<NpgsqlRange<DateOnly>>>(DataTypeNames.DateMultirange);
#endif
}
}

Expand Down
18 changes: 12 additions & 6 deletions src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -813,11 +813,14 @@ public static DbType ToDbType(this NpgsqlDbType npgsqlDbType)
NpgsqlDbType.Geometry => "geometry",
NpgsqlDbType.Geography => "geography",

// Unknown cannot be composed
NpgsqlDbType.Unknown => "unknown",
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Unknown)
&& (npgsqlDbType.HasFlag(NpgsqlDbType.Array) || npgsqlDbType.HasFlag(NpgsqlDbType.Range) ||
npgsqlDbType.HasFlag(NpgsqlDbType.Multirange))

// Unknown cannot be composed
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown
=> "unknown",
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown
=> "unknown",
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown
=> "unknown",

_ => npgsqlDbType.HasFlag(NpgsqlDbType.Array)
Expand Down Expand Up @@ -913,9 +916,12 @@ internal static string ToUnqualifiedDataTypeNameOrThrow(this NpgsqlDbType npgsql
NpgsqlDbType.Unknown => DataTypeNames.Unknown,

// Unknown cannot be composed
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Unknown)
&& (npgsqlDbType.HasFlag(NpgsqlDbType.Array) || npgsqlDbType.HasFlag(NpgsqlDbType.Range) || npgsqlDbType.HasFlag(NpgsqlDbType.Multirange))
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Array) && (npgsqlDbType & ~NpgsqlDbType.Array) == NpgsqlDbType.Unknown
=> DataTypeNames.Unknown,
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Range) && (npgsqlDbType & ~NpgsqlDbType.Range) == NpgsqlDbType.Unknown
=> DataTypeNames.Unknown,
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) && (npgsqlDbType & ~NpgsqlDbType.Multirange) == NpgsqlDbType.Unknown
=> DataTypeNames.Unknown,

// If both multirange and array are set we first remove array, so array is added to the outermost datatypename.
_ when npgsqlDbType.HasFlag(NpgsqlDbType.Array)
Expand Down
132 changes: 113 additions & 19 deletions test/Npgsql.Tests/Support/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ public async Task<T> AssertType<T>(
bool isDefaultForWriting = true,
bool? isDefault = null,
bool isNpgsqlDbTypeInferredFromClrType = true,
Func<T, T, bool>? comparer = null)
Func<T, T, bool>? comparer = null,
bool skipArrayCheck = false)
{
await using var connection = await OpenConnectionAsync();
return await AssertType(
connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, isDefaultForWriting,
isDefault, isNpgsqlDbTypeInferredFromClrType, comparer);
isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck);
}

public async Task<T> AssertType<T>(
Expand All @@ -60,12 +61,13 @@ public async Task<T> AssertType<T>(
bool isDefaultForWriting = true,
bool? isDefault = null,
bool isNpgsqlDbTypeInferredFromClrType = true,
Func<T, T, bool>? comparer = null)
Func<T, T, bool>? comparer = null,
bool skipArrayCheck = false)
{
await using var connection = await dataSource.OpenConnectionAsync();

return await AssertType(connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading,
isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer);
isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck);
}

public async Task<T> AssertType<T>(
Expand All @@ -80,26 +82,27 @@ public async Task<T> AssertType<T>(
bool isDefaultForWriting = true,
bool? isDefault = null,
bool isNpgsqlDbTypeInferredFromClrType = true,
Func<T, T, bool>? comparer = null)
Func<T, T, bool>? comparer = null,
bool skipArrayCheck = false)
{
if (isDefault is not null)
isDefaultForReading = isDefaultForWriting = isDefault.Value;

await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType);
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer);
await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck);
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer, fieldType: null, skipArrayCheck);
}

public async Task<T> AssertTypeRead<T>(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true)
public async Task<T> AssertTypeRead<T>(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true, bool skipArrayCheck = false)
{
await using var connection = await OpenConnectionAsync();
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault);
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer: null, fieldType: null, skipArrayCheck);
}

public async Task<T> AssertTypeRead<T>(NpgsqlDataSource dataSource, string sqlLiteral, string pgTypeName, T expected,
bool isDefault = true, Func<T, T, bool>? comparer = null, Type? fieldType = null)
bool isDefault = true, Func<T, T, bool>? comparer = null, Type? fieldType = null, bool skipArrayCheck = false)
{
await using var connection = await dataSource.OpenConnectionAsync();
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer, fieldType);
return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer, fieldType, skipArrayCheck);
}

public async Task AssertTypeWrite<T>(
Expand All @@ -111,12 +114,13 @@ public async Task AssertTypeWrite<T>(
DbType? dbType = null,
DbType? inferredDbType = null,
bool isDefault = true,
bool isNpgsqlDbTypeInferredFromClrType = true)
bool isNpgsqlDbTypeInferredFromClrType = true,
bool skipArrayCheck = false)
{
await using var connection = await dataSource.OpenConnectionAsync();

await AssertTypeWrite(connection, () => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault,
isNpgsqlDbTypeInferredFromClrType);
isNpgsqlDbTypeInferredFromClrType, skipArrayCheck);
}

public Task AssertTypeWrite<T>(
Expand All @@ -127,9 +131,10 @@ public Task AssertTypeWrite<T>(
DbType? dbType = null,
DbType? inferredDbType = null,
bool isDefault = true,
bool isNpgsqlDbTypeInferredFromClrType = true)
bool isNpgsqlDbTypeInferredFromClrType = true,
bool skipArrayCheck = false)
=> AssertTypeWrite(() => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault,
isNpgsqlDbTypeInferredFromClrType);
isNpgsqlDbTypeInferredFromClrType, skipArrayCheck);

public async Task AssertTypeWrite<T>(
Func<T> valueFactory,
Expand All @@ -139,10 +144,11 @@ public async Task AssertTypeWrite<T>(
DbType? dbType = null,
DbType? inferredDbType = null,
bool isDefault = true,
bool isNpgsqlDbTypeInferredFromClrType = true)
bool isNpgsqlDbTypeInferredFromClrType = true,
bool skipArrayCheck = false)
{
await using var connection = await OpenConnectionAsync();
await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType);
await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck);
}

internal static async Task<T> AssertTypeRead<T>(
Expand All @@ -151,7 +157,35 @@ internal static async Task<T> AssertTypeRead<T>(
string pgTypeName,
T expected,
bool isDefault = true,
Func<T, T, bool>? comparer = null, Type? fieldType = null)
Func<T, T, bool>? comparer = null,
Type? fieldType = null,
bool skipArrayCheck = false)
{
var result = await AssertTypeReadCore(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer);

// Check the corresponding array type as well
if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal))
{
await AssertTypeReadCore(
connection,
ArrayLiteral(sqlLiteral),
pgTypeName + "[]",
new[] { expected, expected },
isDefault,
comparer is null ? null : (array1, array2) => comparer(array1[0], array2[0]) && comparer(array1[1], array2[1]));
}

return result;
}

internal static async Task<T> AssertTypeReadCore<T>(
NpgsqlConnection connection,
string sqlLiteral,
string pgTypeName,
T expected,
bool isDefault = true,
Func<T, T, bool>? comparer = null,
Type? fieldType = null)
{
if (sqlLiteral.Contains('\''))
sqlLiteral = sqlLiteral.Replace("'", "''");
Expand Down Expand Up @@ -186,6 +220,38 @@ internal static async Task<T> AssertTypeRead<T>(
}

internal static async Task AssertTypeWrite<T>(
NpgsqlConnection connection,
Func<T> valueFactory,
string expectedSqlLiteral,
string pgTypeName,
NpgsqlDbType? npgsqlDbType,
DbType? dbType = null,
DbType? inferredDbType = null,
bool isDefault = true,
bool isNpgsqlDbTypeInferredFromClrType = true,
bool skipArrayCheck = false)
{
await AssertTypeWriteCore(
connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault,
isNpgsqlDbTypeInferredFromClrType);

// Check the corresponding array type as well
if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal))
{
await AssertTypeWriteCore(
connection,
() => new[] { valueFactory(), valueFactory() },
ArrayLiteral(expectedSqlLiteral),
pgTypeName + "[]",
npgsqlDbType | NpgsqlDbType.Array,
dbType: null,
inferredDbType: null,
isDefault,
isNpgsqlDbTypeInferredFromClrType);
}
}

internal static async Task AssertTypeWriteCore<T>(
NpgsqlConnection connection,
Func<T> valueFactory,
string expectedSqlLiteral,
Expand All @@ -205,7 +271,10 @@ internal static async Task AssertTypeWrite<T>(

// Strip any facet information (length/precision/scale)
var parenIndex = pgTypeName.IndexOf('(');
var pgTypeNameWithoutFacets = parenIndex > -1 ? pgTypeName[..parenIndex] : pgTypeName;
// var pgTypeNameWithoutFacets = parenIndex > -1 ? pgTypeName[..parenIndex] : pgTypeName;
var pgTypeNameWithoutFacets = parenIndex > -1
? pgTypeName[..parenIndex] + pgTypeName[(pgTypeName.IndexOf(')') + 1)..]
: pgTypeName;

// We test the following scenarios (between 2 and 5 in total):
// 1. With NpgsqlDbType explicitly set
Expand Down Expand Up @@ -363,6 +432,31 @@ public bool Equals(T? x, T? y)
public int GetHashCode(T obj) => throw new NotSupportedException();
}

// For array quoting rules, see array_out in https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c
static string ArrayLiteral(string elementLiteral)
{
switch (elementLiteral)
{
case "":
elementLiteral = "\"\"";
break;
case "NULL":
elementLiteral = "\"NULL\"";
break;
default:
// Escape quotes and backslashes, quote for special chars
elementLiteral = elementLiteral.Replace("\\", "\\\\").Replace("\"", "\\\"");
if (elementLiteral.Any(c => c is '{' or '}' or ',' or '"' or '\\' || char.IsWhiteSpace(c)))
{
elementLiteral = '"' + elementLiteral + '"';
}

break;
}

return $"{{{elementLiteral},{elementLiteral}}}";
}

#endregion Type testing

#region Utilities for use by tests
Expand Down
Loading

0 comments on commit 4e854c0

Please sign in to comment.