Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

删除 oracle engine 中 filter_sql 实现 #2797

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions sql/engines/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,16 +643,6 @@ def query_check(self, db_name=None, sql=""):
result["msg"] = keyword_warning
return result

def filter_sql(self, sql="", limit_num=0):
sql_lower = sql.lower()
# 对查询sql增加limit限制
if re.match(r"^select|^with", sql_lower) and not (
re.match(r"^select\s+sql_audit.", sql_lower)
and sql_lower.find(" sql_audit where rownum <= ") != -1
):
sql = f"select sql_audit.* from ({sql.rstrip(';')}) sql_audit where rownum <= {limit_num}"
return sql.strip()

def query(
self,
db_name=None,
Expand Down
36 changes: 0 additions & 36 deletions sql/engines/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,42 +1238,6 @@ def test_query_check_IndexError(self):
},
)

def test_filter_sql_with_delimiter(self):
sql = "select * from xx;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
self.assertEqual(
check_result,
"select sql_audit.* from (select * from xx) sql_audit where rownum <= 100",
)

def test_filter_sql_with_delimiter_and_where(self):
sql = "select * from xx where id>1;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
self.assertEqual(
check_result,
"select sql_audit.* from (select * from xx where id>1) sql_audit where rownum <= 100",
)

def test_filter_sql_without_delimiter(self):
sql = "select * from xx;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
self.assertEqual(
check_result,
"select sql_audit.* from (select * from xx) sql_audit where rownum <= 100",
)

def test_filter_sql_with_limit(self):
sql = "select * from xx limit 10;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=1)
self.assertEqual(
check_result,
"select sql_audit.* from (select * from xx limit 10) sql_audit where rownum <= 1",
)

def test_query_masking(self):
query_result = ResultSet()
new_engine = OracleEngine(instance=self.ins)
Expand Down
3 changes: 2 additions & 1 deletion sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def query(request):
limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num

# 对查询sql增加limit限制或者改写语句
sql_content = query_engine.filter_sql(sql=sql_content, limit_num=limit_num)
if instance.db_type != "oracle":
hhw007 marked this conversation as resolved.
Show resolved Hide resolved
sql_content = query_engine.filter_sql(sql=sql_content, limit_num=limit_num)

# 先获取查询连接,用于后面查询复用连接以及终止会话
query_engine.get_connection(db_name=db_name)
Expand Down
Loading