Skip to content

Commit

Permalink
修复mongo上线问题 (#2803)
Browse files Browse the repository at this point in the history
* 修复未开启认证的mongo执行上线失败问题

* 优化是否存在账号密码的逻辑,封装buildcmd函数

* _build_cmd逻辑优化,提取公共参数,根据不同情况判断进行其他参数拼接

* reformat mongo.py

* refomart mongo by black

* 添加单元测试
  • Loading branch information
smqyisyy authored Sep 18, 2024
1 parent 54bb3dc commit 5a6871d
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 37 deletions.
84 changes: 47 additions & 37 deletions sql/engines/mongo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: UTF-8 -*-
import re, time
import re
import time
import pymongo
import logging
import traceback
Expand Down Expand Up @@ -286,15 +287,14 @@ def test_connection(self):
def exec_cmd(self, sql, db_name=None, slave_ok=""):
"""审核时执行的语句"""

if self.user and self.password and self.port and self.host:
if self.port and self.host:
msg = ""
auth_db = self.instance.db_name or "admin"
sql_len = len(sql)
is_load = False # 默认不使用load方法执行mongodb sql语句
try:
if (
not sql.startswith("var host=") and sql_len > 4000
): # 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法
if not sql.startswith("var host=") and sql_len > 4000:
# 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法
# 因为用mongo load方法执行js脚本,所以需要重新改写一下sql,以便回显js执行结果
sql = "var result = " + sql + "\nprintjson(result);"
# 因为要知道具体的临时文件位置,所以用了NamedTemporaryFile模块
Expand All @@ -303,43 +303,17 @@ def exec_cmd(self, sql, db_name=None, slave_ok=""):
)
fp.write(sql.encode("utf-8"))
fp.seek(0) # 把文件指针指向开始,这样写的sql内容才能落到磁盘文件上
cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}load('{tempfile_}')\nEOF".format(
mongo=mongo,
uname=self.user,
password=self.password,
host=self.host,
port=self.port,
db_name=db_name,
sql=sql,
auth_db=auth_db,
slave_ok=slave_ok,
tempfile_=fp.name,
cmd = self._build_cmd(
db_name, auth_db, slave_ok, fp.name, is_load=True
)
is_load = True # 标记使用了load方法,用来在finally里面判断是否需要强制删除临时文件
elif (
not sql.startswith("var host=") and sql_len < 4000
): # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能
cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}printjson({sql})\nEOF".format(
mongo=mongo,
uname=self.user,
password=self.password,
host=self.host,
port=self.port,
db_name=db_name,
sql=sql,
auth_db=auth_db,
slave_ok=slave_ok,
)
cmd = self._build_cmd(db_name, auth_db, slave_ok, sql=sql)
else:
cmd = "{mongo} --quiet -u {user} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\nrs.slaveOk();{sql}\nEOF".format(
mongo=mongo,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
db_name=db_name,
sql=sql,
auth_db=auth_db,
cmd = self._build_cmd(
db_name, auth_db, sql=sql, slave_ok="rs.slaveOk();"
)
p = subprocess.Popen(
cmd,
Expand Down Expand Up @@ -371,6 +345,42 @@ def exec_cmd(self, sql, db_name=None, slave_ok=""):
fp.close()
return msg

# 用来进行判断是否有用户名与密码以及是否需要临时文件的情况,进而返回要执行的mongo命令
def _build_cmd(
self, db_name, auth_db, slave_ok="", tempfile_=None, sql=None, is_load=False
):
# 提取公共参数
common_params = {
"mongo": "mongo",
"host": self.host,
"port": self.port,
"db_name": db_name,
"auth_db": auth_db,
"slave_ok": slave_ok,
}
if is_load:
cmd_template = (
"{mongo} --quiet {auth_options} {host}:{port}/{auth_db} <<\\EOF\n"
"db=db.getSiblingDB('{db_name}');{slave_ok}load('{tempfile_}')\nEOF"
)
# 长度超限使用loadjs的方式运行,使用临时文件
common_params["tempfile_"] = tempfile_
else:
cmd_template = (
"{mongo} --quiet {auth_options} {host}:{port}/{auth_db} <<\\EOF\n"
"db=db.getSiblingDB('{db_name}');{slave_ok}printjson({sql})\nEOF"
)
# 长度不超限直接mongo shell,无需临时文件
common_params["sql"] = sql
# 如果有账号密码,则添加选项
if self.user and self.password:
common_params["auth_options"] = "-u {uname} -p '{password}'".format(
uname=self.user, password=self.password
)
else:
common_params["auth_options"] = ""
return cmd_template.format(**common_params)

def get_master(self):
"""获得主节点的port和host"""

Expand Down Expand Up @@ -795,7 +805,6 @@ def execute_check(self, db_name=None, sql=""):
def get_connection(self, db_name=None):
self.db_name = db_name or self.instance.db_name or "admin"
auth_db = self.instance.db_name or "admin"

if self.user and self.password:
self.conn = pymongo.MongoClient(
self.host,
Expand All @@ -814,6 +823,7 @@ def get_connection(self, db_name=None):
connect=True,
connectTimeoutMS=10000,
)

return self.conn

def close(self):
Expand Down
103 changes: 103 additions & 0 deletions sql/engines/test_mongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
from unittest.mock import patch, MagicMock
from sql.engines.mongo import MongoEngine


@pytest.fixture
def mongo_engine():
engine = MongoEngine()
engine.host = "localhost"
engine.port = 27017
engine.user = "test_user"
engine.password = "test_password"
engine.instance = MagicMock()
engine.instance.db_name = "test_db"
return engine


def test_build_cmd_with_load(mongo_engine):
# Call the method with is_load=True
cmd = mongo_engine._build_cmd(
db_name="test_db",
auth_db="admin",
slave_ok="rs.slaveOk();",
tempfile_="/tmp/test.js",
is_load=True,
)

# Expected command template
expected_cmd = (
"mongo --quiet -u test_user -p 'test_password' localhost:27017/admin <<\\EOF\n"
"db=db.getSiblingDB('test_db');rs.slaveOk();load('/tmp/test.js')\nEOF"
)

# Assertions
assert cmd == expected_cmd


def test_build_cmd_without_load(mongo_engine):
# Call the method with is_load=False
cmd = mongo_engine._build_cmd(
db_name="test_db",
auth_db="admin",
slave_ok="rs.slaveOk();",
sql="db.test_collection.find()",
is_load=False,
)

# Expected command template
expected_cmd = (
"mongo --quiet -u test_user -p 'test_password' localhost:27017/admin <<\\EOF\n"
"db=db.getSiblingDB('test_db');rs.slaveOk();printjson(db.test_collection.find())\nEOF"
)

# Assertions
assert cmd == expected_cmd


def test_build_cmd_without_auth(mongo_engine):
# Set user and password to None
mongo_engine.user = None
mongo_engine.password = None

# Call the method with is_load=False
cmd = mongo_engine._build_cmd(
db_name="test_db",
auth_db="admin",
slave_ok="rs.slaveOk();",
sql="db.test_collection.find()",
is_load=False,
)

# Expected command template
expected_cmd = (
"mongo --quiet localhost:27017/admin <<\\EOF\n"
"db=db.getSiblingDB('test_db');rs.slaveOk();printjson(db.test_collection.find())\nEOF"
)

# Assertions
assert cmd == expected_cmd


def test_build_cmd_with_load_without_auth(mongo_engine):
# Set user and password to None
mongo_engine.user = None
mongo_engine.password = None

# Call the method with is_load=True
cmd = mongo_engine._build_cmd(
db_name="test_db",
auth_db="admin",
slave_ok="rs.slaveOk();",
tempfile_="/tmp/test.js",
is_load=True,
)

# Expected command template
expected_cmd = (
"mongo --quiet localhost:27017/admin <<\\EOF\n"
"db=db.getSiblingDB('test_db');rs.slaveOk();load('/tmp/test.js')\nEOF"
)

# Assertions
assert cmd == expected_cmd

0 comments on commit 5a6871d

Please sign in to comment.