Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/shm #120

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/vul-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ jobs:

- name: Run Vul test
run: |
curl --fail --retry-delay 10 --retry 30 --retry-connrefused http://127.0.0.1:8003/api/django/demo/get_open?name=Data
cd ${{ github.workspace }}/DockerVulspace
curl --fail --retry-delay 10 --retry 30 --retry-connrefused http://127.0.0.1:8003/api/django/demo/get_open?name=Data
docker-compose logs djangoweb flaskweb
docker-compose exec -T djangoweb python -V
docker-compose exec -T djangoweb pip list
docker-compose exec -T flaskweb python -V
docker-compose exec -T flaskweb pip list
docker-compose logs djangoweb flaskweb
bash ${{ github.workspace }}/DongTai-agent-python/dongtai_agent_python/tests/vul-test.sh \
django http://127.0.0.1:8003/api/django ${{ github.run_id }}
bash ${{ github.workspace }}/DongTai-agent-python/dongtai_agent_python/tests/vul-test.sh \
Expand Down
2 changes: 2 additions & 0 deletions dongtai_agent_python/api/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def agent_register(self):
logger.error("register get agent id empty")
return resp

self.setting.set_shm('agent-' + str(self.agent_id))

if resp.get('data', {}).get('coreAutoStart', 0) != 1:
logger.info("agent is waiting for auditing")
self.setting.dt_manual_pause = True
Expand Down
56 changes: 51 additions & 5 deletions dongtai_agent_python/setting/setting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
from multiprocessing import Lock

from dongtai_agent_python import version
from .config import Config
from dongtai_agent_python.utils import Singleton
from dongtai_agent_python.utils.shm import SharedMemoryDict
from dongtai_agent_python.utils.lock import lock

_lock = Lock()


class Setting(Singleton):
Expand All @@ -13,10 +18,8 @@ def init(self):
return

self.version = version.__version__
self.paused = False
self.manual_paused = False
self.agent_id = 0
self.request_seq = 0
self.shm = None

self.auto_create_project = 0
self.use_local_policy = False
Expand All @@ -38,6 +41,12 @@ def init(self):
self.init_os_environ()
Setting.loaded = True

def __del__(self):
if self.shm is None:
return False
self.shm.close()
self.shm.unlink()

def set_container(self, container):
if container and isinstance(container, dict):
self.container = container
Expand Down Expand Up @@ -76,8 +85,45 @@ def init_os_environ(self):
for key in os_env.keys():
self.os_env_list.append(key + '=' + str(os_env[key]))

def set_shm(self, name):
if self.shm is None:
self.shm = SharedMemoryDict('dongtai-shm-python-' + name)

@property
def paused(self):
if self.shm is None:
return False
return self.shm.get('paused')

@paused.setter
def paused(self, status):
if self.shm is None:
return
self.shm['paused'] = status

@property
def manual_paused(self):
if self.shm is None:
return False
return self.shm.get('manual_paused')

@manual_paused.setter
def manual_paused(self, status):
if self.shm is None:
return
self.shm['manual_paused'] = status

def is_agent_paused(self):
return self.paused and self.manual_paused
return self.paused or self.manual_paused

@property
def request_seq(self):
if self.shm is None:
return 0
return self.shm.get('request_seq', 0)

@lock(_lock)
def incr_request_seq(self):
self.request_seq = self.request_seq + 1
if self.shm is None:
return
self.shm['request_seq'] = self.request_seq + 1
34 changes: 32 additions & 2 deletions dongtai_agent_python/tests/setting/test_setting.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
import multiprocessing
import os
import threading
import time
import unittest

from dongtai_agent_python.setting.setting import Setting


class TestSetting(unittest.TestCase):
def test_multithreading(self):
def test(name):
def test_mt(name):
st1 = Setting()
st1.shm = None
st1.set_shm("test-setting-001")
st1.set_container({'name': name, 'version': '0.1'})
st1.incr_request_seq()

thread_num = 5
for i in range(thread_num):
t = threading.Thread(target=test, args=['test' + str(i)])
t = threading.Thread(target=test_mt, args=['test' + str(i)])
t.start()

st = Setting()
st.shm = None
st.set_shm("test-setting-001")
time.sleep(1)
self.assertEqual(thread_num, st.request_seq)

def test_multiprocessing(self):
if os.name == "nt":
return

def test_mp(name):
st1 = Setting()
st1.shm = None
st1.set_shm("test-setting-002")
st1.set_container({'name': name, 'version': '0.1'})
st1.incr_request_seq()

process_num = 5
for i in range(process_num):
p = multiprocessing.Process(target=test_mp, args=('test' + str(i),))
p.start()

st = Setting()
st.shm = None
st.set_shm("test-setting-002")
time.sleep(1)
self.assertEqual(process_num, st.request_seq)


if __name__ == '__main__':
unittest.main()
14 changes: 14 additions & 0 deletions dongtai_agent_python/utils/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from functools import wraps


def lock(_lock):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
_lock.acquire()
try:
return func(*args, **kwargs)
finally:
_lock.release()
return wrapper
return decorator
1 change: 1 addition & 0 deletions dongtai_agent_python/utils/shm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .shm import SharedMemoryDict
148 changes: 148 additions & 0 deletions dongtai_agent_python/utils/shm/shm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import logging
import pickle
import sys
from contextlib import contextmanager
from multiprocessing import Lock

from dongtai_agent_python.utils.lock import lock

if sys.version_info[:3] < (3, 8):
from shared_memory.shared_memory import SharedMemory
else:
from multiprocessing.shared_memory import SharedMemory

NULL_BYTE = b"\x00"

logger = logging.getLogger(__name__)
_lock = Lock()
DEFAULT_OBJ = object()


class SharedMemoryDict:
def __init__(self, name, size=1024):
self.name = name
self.mem_block = self.get_or_create(size)
self.init_memory()

@lock(_lock)
def get_or_create(self, size):
try:
return SharedMemory(name=self.name)
except FileNotFoundError:
return SharedMemory(name=self.name, create=True, size=size)

def init_memory(self):
memory_is_empty = (bytes(self.mem_block.buf).split(NULL_BYTE, 1)[0] == b'')
if memory_is_empty:
self.save_memory({})

def close(self) -> None:
if not hasattr(self, 'mem_block'):
return
self.mem_block.close()

def unlink(self) -> None:
if not hasattr(self, 'mem_block'):
return
self.mem_block.unlink()

@lock
def clear(self) -> None:
self.save_memory({})

def popitem(self):
with self.modify_db() as db:
return db.popitem()

def save_memory(self, db) -> None:
data = pickle.dumps(db)
try:
self.mem_block.buf[:len(data)] = data
except ValueError as exc:
logging.error("failed save to memory", exc_info=exc)

def read_memory(self):
return pickle.loads(self.mem_block.buf.tobytes())

@contextmanager
@lock(_lock)
def modify_db(self):
db = self.read_memory()
yield db
self.save_memory(db)

def __getitem__(self, key: str):
return self.read_memory()[key]

def __setitem__(self, key: str, value) -> None:
with self.modify_db() as db:
db[key] = value

def __len__(self) -> int:
return len(self.read_memory())

def __delitem__(self, key: str) -> None:
with self.modify_db() as db:
del db[key]

def __iter__(self):
return iter(self.read_memory())

def __reversed__(self):
return reversed(self.read_memory())

def __del__(self) -> None:
self.close()

def __contains__(self, key: str) -> bool:
return key in self.read_memory()

def __eq__(self, other) -> bool:
return self.read_memory() == other

def __ne__(self, other) -> bool:
return self.read_memory() != other

if sys.version_info > (3, 8):
def __or__(self, other):
return self.read_memory() | other

def __ror__(self, other):
return other | self.read_memory()

def __ior__(self, other):
with self.modify_db() as db:
db |= other
return db

def __str__(self):
return str(self.read_memory())

def __repr__(self):
return repr(self.read_memory())

def get(self, key: str, default=None):
return self.read_memory().get(key, default)

def keys(self):
return self.read_memory().keys()

def values(self):
return self.read_memory().values()

def items(self):
return self.read_memory().items()

def pop(self, key: str, default=DEFAULT_OBJ):
with self.modify_db() as db:
if default is DEFAULT_OBJ:
return db.pop(key)
return db.pop(key, default)

def update(self, other=(), **kwargs):
with self.modify_db() as db:
db.update(other, **kwargs)

def setdefault(self, key: str, default=None):
with self.modify_db() as db:
return db.setdefault(key, default)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ install_requires =
psutil >= 5.8.0
requests >= 2.25.1
pip >= 19.2.3
shared-memory38 >= 0.1.2; python_version < '3.8'
; regexploit >= 1.0.0