Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix calling procedures with output parameters by their four-part syntax #2349

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public class SQLServerPreparedStatement extends SQLServerStatement implements IS

private boolean isCallEscapeSyntax;

private boolean isFourPartSyntax;

/** Parameter positions in processed SQL statement text. */
final int[] userSQLParamPositions;

Expand Down Expand Up @@ -144,6 +146,11 @@ private void setPreparedStatementHandle(int handle) {
*/
private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b");

/**
* Regex for four part syntax
*/
private static final Pattern fourPartSyntaxPattern = Pattern.compile("(.+)\\.(.+)\\.(.+)\\.(.+)");

/** Returns the prepared statement SQL */
@Override
public String toString() {
Expand Down Expand Up @@ -271,6 +278,7 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
userSQL = parsedSQL.processedSQL;
isExecEscapeSyntax = isExecEscapeSyntax(sql);
isCallEscapeSyntax = isCallEscapeSyntax(sql);
isFourPartSyntax = isFourPartSyntax(sql);
userSQLParamPositions = parsedSQL.parameterPositions;
initParams(userSQLParamPositions.length);
useBulkCopyForBatchInsert = conn.getUseBulkCopyForBatchInsert();
Expand Down Expand Up @@ -1234,10 +1242,12 @@ boolean callRPCDirectly(Parameter[] params) throws SQLServerException {
// 2. There must be parameters
// 3. Parameters must not be a TVP type
// 4. Compliant CALL escape syntax
// If isExecEscapeSyntax is true, EXEC escape syntax is used then use prior behaviour to
// execute the procedure
// If isExecEscapeSyntax is true, EXEC escape syntax is used then use prior behaviour of
// wrapping call to execute the procedure
// If isFourPartSyntax is true, sproc is being executed against linked server, then
// use prior behaviour of wrapping call to execute procedure
return (null != procedureName && paramCount != 0 && !isTVPType(params) && isCallEscapeSyntax
&& !isExecEscapeSyntax);
&& !isExecEscapeSyntax && !isFourPartSyntax);
}

/**
Expand Down Expand Up @@ -1265,6 +1275,10 @@ private boolean isCallEscapeSyntax(String sql) {
return callEscapePattern.matcher(sql).find();
}

private boolean isFourPartSyntax(String sql) {
return fourPartSyntaxPattern.matcher(sql).find();
}

/**
* Executes sp_prepare to prepare a parameterized statement and sets the prepared statement handle
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,66 @@ public void testExecDocumentedSystemStoredProceduresIndexedParameters() throws S
}
}

@Test
@Tag(Constants.reqExternalSetup)
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.xAzureSQLMI)
public void testFourPartSyntaxCallEscapeSyntax() throws SQLException {
String table = "serverList";

try (Statement stmt = connection.createStatement()) {
stmt.execute("IF OBJECT_ID(N'" + table + "') IS NOT NULL DROP TABLE " + table);
stmt.execute("CREATE TABLE " + table + " (serverName varchar(100),network varchar(100),serverStatus varchar(4000), id int, collation varchar(100), connectTimeout int, queryTimeout int)");
stmt.execute("INSERT " + table + " EXEC sp_helpserver");

ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + table + " WHERE serverName = N'" + linkedServer + "'");
rs.next();

if (rs.getInt(1) == 1) {
stmt.execute("EXEC sp_dropserver @server='" + linkedServer + "';");
}

stmt.execute("EXEC sp_addlinkedserver @server='" + linkedServer + "';");
stmt.execute("EXEC sp_addlinkedsrvlogin @rmtsrvname=N'" + linkedServer + "', @rmtuser=N'" + remoteUser + "', @rmtpassword=N'" + remotePassword + "'");
stmt.execute("EXEC sp_serveroption '" + linkedServer + "', 'rpc', true;");
stmt.execute("EXEC sp_serveroption '" + linkedServer + "', 'rpc out', true;");
}
barryw-mssql marked this conversation as resolved.
Show resolved Hide resolved

SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName(linkedServer);
ds.setUser(remoteUser);
ds.setPassword(remotePassword);
ds.setEncrypt(false);
ds.setTrustServerCertificate(true);

try (Connection linkedServerConnection = ds.getConnection(); Statement stmt = linkedServerConnection.createStatement()) {
stmt.execute("create or alter procedure dbo.TestAdd(@Num1 int, @Num2 int, @Result int output) as begin set @Result = @Num1 + @Num2; end;");
}

try (CallableStatement cstmt = connection.prepareCall("{call [" + linkedServer + "].master.dbo.TestAdd(?,?,?)}")) {
int sum = 11;
int param0 = 1;
int param1 = 10;
cstmt.setInt(1, param0);
cstmt.setInt(2, param1);
cstmt.registerOutParameter(3, Types.INTEGER);
cstmt.execute();
assertEquals(sum, cstmt.getInt(3));
}

try (CallableStatement cstmt = connection.prepareCall("exec [" + linkedServer + "].master.dbo.TestAdd ?,?,?")) {
int sum = 11;
int param0 = 1;
int param1 = 10;
cstmt.setInt(1, param0);
cstmt.setInt(2, param1);
cstmt.registerOutParameter(3, Types.INTEGER);
cstmt.execute();
assertEquals(sum, cstmt.getInt(3));
}
}

/**
* Cleanup after test
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public abstract class AbstractTest {
protected static String tenantID;
protected static String[] keyIDs = null;

protected static String linkedServer = null;
protected static String remoteUser = null;
protected static String remotePassword = null;
protected static String[] enclaveServer = null;
protected static String[] enclaveAttestationUrl = null;
protected static String[] enclaveAttestationProtocol = null;
Expand Down Expand Up @@ -197,6 +200,10 @@ public static void setup() throws Exception {

clientKeyPassword = getConfiguredProperty("clientKeyPassword", "");

linkedServer = getConfiguredProperty("linkedServer", null);
remoteUser = getConfiguredProperty("remoteUser", null);
remotePassword = getConfiguredProperty("remotePassword", null);

kerberosServer = getConfiguredProperty("kerberosServer", null);
kerberosServerPort = getConfiguredProperty("kerberosServerPort", null);

Expand Down
Loading