Skip to content

Commit

Permalink
add cassandra integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoQuote committed Aug 10, 2023
1 parent 0841977 commit 0f387bf
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 90 deletions.
143 changes: 53 additions & 90 deletions sql/engines/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,40 @@

import sqlparse

from common.utils.timer import FuncTimer
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult

from sql.models import SqlWorkflow

logger = logging.getLogger("default")


def split_sql(db_name=None, sql=""):
# 切分语句,追加到检测结果中,默认全部检测通过
sql = sql.split("\n")
sql = filter(None, sql)
sql_result = []
if db_name:
sql_result += [f"""USE {db_name}"""]
sql_result += sql
return sql_result


class CassandraEngine(EngineBase):
def get_connection(self, db_name=None):
db_name = db_name or self.db_name

if self.conn:
if db_name:
self.conn.execute(f"use {db_name}")
return self.conn
auth_provider = PlainTextAuthProvider(username=self.user, password=self.password)
cluster = Cluster([self.host], port=self.port, auth_provider=auth_provider)
self.conn = cluster.connect(keyspace=db_name)
return self.conn

def close(self):
if self.conn:
self.conn.close()
self.conn.shutdown()
self.conn = None

@property
Expand All @@ -40,23 +55,33 @@ def info(self):
def test_connection(self):
return self.get_all_databases()

def escape_string(self, value: str) -> str:
return re.sub(r"[; ]", "", value)

def get_all_databases(self, **kwargs):
"""
获取所有的 keyspace/database
:return:
"""
result = ResultSet(full_sql="SELECT keyspace_name FROM system_schema.keyspaces;")
session = self.get_connection()
rows = session.execute(result.full_sql)
result = self.query(sql="SELECT keyspace_name FROM system_schema.keyspaces;")
result.rows = [x[0] for x in result.rows]
return result

db_list = [x.keyspace_name for x in rows]
result.rows = db_list
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
"""获取所有列, 返回一个ResultSet"""
sql = "select column_name, type from columns where keyspace_name=%s and table_name=%s"
result = self.query(db_name="system_schema", sql=sql, parameters=(db_name, tb_name))
return result

def describe_table(self, db_name, tb_name, **kwargs):
sql = f"describe {tb_name}"
return self.query(db_name=db_name, sql=sql)

def query_check(self, db_name=None, sql="", limit_num=0):
"""提交查询前的检查"""
# 查询语句的检查、注释去除、切分
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
result = {"msg": "", "bad_query": False,
"filtered_sql": sql, "has_star": False}
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
Expand All @@ -71,60 +96,47 @@ def query_check(self, db_name=None, sql="", limit_num=0):
if "*" in sql:
result["has_star"] = True
result["msg"] = "SQL语句中含有 * "
# 不应该查看mysql.user表
if re.match(
".*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*",
sql.lower().replace("\n", ""),
) or (
db_name == "mysql"
and re.match(
".*(\\s)+(user|`user`)((\\s)*|;).*", sql.lower().replace("\n", "")
)
):
result["bad_query"] = True
result["msg"] = "您无权查看该表"

return result

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(self, db_name=None, sql="", limit_num=0, close_conn=True, parameters=None, **kwargs):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
rows = conn.execute(sql)
rows = conn.execute(sql, parameters=parameters)
result_set.column_list = rows.column_names
result_set.rows = rows.all()
result_set.affected_rows = len(result_set.rows)
if limit_num > 0:
result_set.rows = result_set.rows[0:limit_num]
result_set.affected_rows = min(limit_num, result_set.affected_rows)
except Exception as e:
logger.warning(f"cassandra query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}")
logger.warning(f"{self.name} query 错误,语句:{sql}, 错误信息:{traceback.format_exc()}")
result_set.error = str(e)
return result_set

def get_all_tables(self, db_name, **kwargs):
sql = "SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s;"
parameters = [db_name]
result = self.query(db_name=db_name, sql=sql, parameters=parameters)
tb_list = [row[0] for row in result.rows]
result.rows = tb_list
return result

def filter_sql(self, sql="", limit_num=0):
return sql.strip()

def query_masking(self, db_name=None, sql="", resultset=None):
"""不做脱敏"""
return resultset

def split_sql(self, db_name=None, sql=""):
# 切分语句,追加到检测结果中,默认全部检测通过
sql = sql.split("\n")
sql = filter(None, sql)
split_sql = [f"""USE {db_name}"""]
split_sql += sql
return split_sql

def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
# 切分语句,追加到检测结果中,默认全部检测通过
split_sql = self.split_sql(db_name, sql)
sql_result = split_sql(db_name, sql)
rowid = 1
for statement in split_sql:
for statement in sql_result:
check_result.rows.append(
ReviewResult(
id=rowid,
Expand All @@ -143,11 +155,11 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
"""执行sql语句 返回 Review set"""
execute_result = ReviewSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
split_sql = self.split_sql(db_name, sql)
sql_result = split_sql(db_name, sql)
rowid = 1
for statement in split_sql:
for statement in sql_result:
try:
result = conn.execute(statement)
conn.execute(statement)
execute_result.rows.append(
ReviewResult(
id=rowid,
Expand All @@ -160,7 +172,7 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
)
)
except Exception as e:
logger.warning(f"Mssql命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
logger.warning(f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
execute_result.error = str(e)
execute_result.rows.append(
ReviewResult(
Expand All @@ -174,10 +186,8 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
)
)
break

rowid += 1
if execute_result.error:
# 如果失败, 将剩下的部分加入结果集, 并将语句回滚
for statement in split_sql[rowid:]:
execute_result.rows.append(
ReviewResult(
Expand All @@ -191,57 +201,10 @@ def execute(self, db_name=None, sql="", close_conn=True, parameters=None):
)
)
rowid += 1
cursor.rollback()
for row in execute_result.rows:
if row.stagestatus == "Execute Successfully":
row.stagestatus += "\nRollback Successfully"
else:
cursor.commit()
if close_conn:
self.close()
return execute_result


def execute_workflow(self, workflow):
def execute_workflow(self, workflow: SqlWorkflow):
"""执行上线单,返回Review set"""
sql = workflow.sqlworkflowcontent.sql_content
split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()]
execute_result = ReviewSet(full_sql=sql)
line = 1
cmd = None
try:
conn = self.get_connection(db_name=workflow.db_name)
conn.execute(workflow.sqlworkflowcontent)
except Exception as e:
logger.warning(
f"cassandra 执行错误,语句:{cmd or sql}, 错误信息:{traceback.format_exc()}"
)
# 追加当前报错语句信息到执行结果中
execute_result.error = str(e)
execute_result.rows.append(
ReviewResult(
id=line,
errlevel=2,
stagestatus="Execute Failed",
errormessage=f"异常信息:{e}",
sql=cmd,
affected_rows=0,
execute_time=0,
)
)
line += 1
# 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
for statement in split_sql[line - 1:]:
execute_result.rows.append(
ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage=f"前序语句失败, 未执行",
sql=statement,
affected_rows=0,
execute_time=0,
)
)
line += 1
return execute_result
return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content)
133 changes: 133 additions & 0 deletions sql/engines/test_cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import unittest
from unittest.mock import patch, Mock

from django.test import TestCase
from sql.models import Instance
from sql.engines.cassandra import CassandraEngine
from sql.engines.models import ResultSet

# 启用后, 会运行全部测试, 包括一些集成测试
integration_test_enabled = False
integration_test_host = "localhost"


class CassandraEngineTest(TestCase):
def setUp(self) -> None:
self.ins = Instance.objects.create(
instance_name="some_ins",
type="slave",
db_type="cassandra",
host="localhost",
port=9200,
user="cassandra",
password="cassandra",
db_name="some_db",
)
self.engine = CassandraEngine(instance=self.ins)

def tearDown(self) -> None:
self.ins.delete()
if integration_test_enabled:
self.engine.execute(sql="drop keyspace test;")

@patch("sql.engines.cassandra.Cluster.connect")
def test_get_connection(self, mock_connect):
_ = self.engine.get_connection()
mock_connect.assert_called_once()

@patch("sql.engines.cassandra.CassandraEngine.get_connection")
def test_query(self, mock_get_connection):
test_sql = """select 123"""
self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet)

def test_query_check(self):
test_sql = """select 123; -- this is comment
select 456;"""

result_sql = "select 123;"

check_result = self.engine.query_check(sql=test_sql)

self.assertIsInstance(check_result, dict)
self.assertEqual(False, check_result.get("bad_query"))
self.assertEqual(result_sql, check_result.get("filtered_sql"))

def test_query_check_error(self):
test_sql = """drop table table_a"""

check_result = self.engine.query_check(sql=test_sql)

self.assertIsInstance(check_result, dict)
self.assertEqual(True, check_result.get("bad_query"))

@patch("sql.engines.cassandra.CassandraEngine.query")
def test_get_all_databases(self, mock_query):
mock_query.return_value = ResultSet(rows=[("some_db",)])

result = self.engine.get_all_databases()

self.assertIsInstance(result, ResultSet)
self.assertEqual(result.rows, ["some_db"])

@patch("sql.engines.cassandra.CassandraEngine.query")
def test_get_all_tables(self, mock_query):
# 下面是查表示例返回结果
mock_query.return_value = ResultSet(rows=[("u",), ("v",), ("w",)])

table_list = self.engine.get_all_tables("some_db")

self.assertEqual(table_list.rows, ["u", "v", "w"])

@patch("sql.engines.cassandra.CassandraEngine.query")
def test_describe_table(self, mock_query):
mock_query.return_value = ResultSet()
self.engine.describe_table("some_db", "some_table")
mock_query.assert_called_once_with(db_name="some_db", sql="describe some_table")

@patch("sql.engines.cassandra.CassandraEngine.query")
def test_get_all_columns_by_tb(self, mock_query):
mock_query.return_value = ResultSet(
rows=[("name", "text")],
column_list=["column_name", "type"]
)

result = self.engine.get_all_columns_by_tb("some_db", "some_table")
self.assertEqual(result.rows, [("name", "text")])
self.assertEqual(
result.column_list, ["column_name", "type"]
)


@unittest.skipIf(not integration_test_enabled, "cassandra integration test is not enabled")
class CassandraIntegrationTest(TestCase):
def setUp(self):
self.instance = Instance.objects.create(
instance_name="int_ins",
type="slave",
db_type="cassandra",
host=integration_test_host,
port=9042,
user="cassandra",
password="cassandra",
db_name="",
)
self.engine = CassandraEngine(instance=self.instance)

self.keyspace = "test"
self.table = "test_table"
# 新建 keyspace
self.engine.execute(sql=f"create keyspace {self.keyspace} with replication = "
"{'class': 'org.apache.cassandra.locator.SimpleStrategy', "
"'replication_factor': '1'};")
# 建表
self.engine.execute(db_name=self.keyspace, sql=f"""create table if not exists {self.table}( name text primary key );""")

def tearDown(self):
self.engine.execute(sql="drop keyspace test;")

def test_integrate_query(self):
self.engine.execute(db_name=self.keyspace, sql=f"insert into {self.table} (name) values ('test')")

result = self.engine.query(db_name=self.keyspace, sql=f"select * from {self.table}")

self.assertEqual(result.rows[0][0], "test")
1 change: 1 addition & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class Meta:
("odps", "ODPS"),
("clickhouse", "ClickHouse"),
("goinception", "goInception"),
("cassandra", "Cassandra"),
)


Expand Down

0 comments on commit 0f387bf

Please sign in to comment.