-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcp_forward.py
96 lines (84 loc) · 3.23 KB
/
tcp_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import socket
import sys
import threading
import os
import signal
import ssl
# ---------------------------------------------------------------------------
# Command-line parsing. Parsed values live in module globals because main(),
# server() and forward() read them directly.
# ---------------------------------------------------------------------------
if len(sys.argv) < 6:
    print('Usage:\n\tpython tcp_forward.py <certificate file> <listen host> <listen port> <remote host> <remote port> [debug]', file=sys.stderr)
    # The remote side is always contacted over TLS, hence port 443.
    print('Example:\n\tpython tcp_forward.py cert.pem localhost 8080 www.google.com 443', file=sys.stderr)
    # Missing arguments are an error: report it through the exit status.
    sys.exit(1)
try:
    certFile = sys.argv[1]
    listenHost = sys.argv[2]
    listenPort = int(sys.argv[3])
    targetHost = sys.argv[4]
    targetPort = int(sys.argv[5])
    if not (0 <= listenPort <= 65535 and 0 <= targetPort <= 65535):
        raise ValueError("ports must be in the range 0-65535")
    # Any seventh argument (whatever its value) switches on payload dumps.
    debugOutput = len(sys.argv) > 6
except ValueError as e:
    # int() and the range check above are the only failure sources here.
    print("Invalid parameters: %s" % (e), file=sys.stderr)
    sys.exit(1)

# Listening socket shared by main() (bind/listen) and server() (accept).
listenSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# TLS context for the client-facing side, where this process acts as the TLS
# server. PROTOCOL_TLS_SERVER replaces the deprecated PROTOCOL_TLSv1_2, and the
# minimum version keeps TLS 1.2 as the floor (TLS 1.3 is also accepted).
contextClient = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
contextClient.minimum_version = ssl.TLSVersion.TLSv1_2
def main():
    """Load the certificate, start listening, then run the accept loop.

    Uses the module globals ``certFile``, ``listenHost``, ``listenPort``,
    ``contextClient`` and ``listenSocket``.

    Exits the process with status 1 if the certificate cannot be loaded or
    the listening socket cannot be bound. It calls ``sys.exit`` rather than
    SIGKILL-ing itself, which is portable (``signal.SIGKILL`` does not exist on
    Windows) and lets normal interpreter cleanup run.
    """
    try:
        # Certificate chain presented to connecting clients. No keyfile is
        # given, so the private key must be contained in the same PEM file.
        contextClient.load_cert_chain(certfile=certFile)
    except Exception as e:
        # Top-level boundary: report any load failure and stop.
        print("Invalid certificate: %s" % (e), file=sys.stderr)
        sys.exit(1)
    try:
        if os.name == 'posix':
            # Allow a quick restart while old connections sit in TIME_WAIT.
            # Skipped on Windows, where SO_REUSEADDR permits port hijacking.
            listenSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # start listening for clients
        listenSocket.bind((listenHost, listenPort))
        listenSocket.listen(5)
    except Exception as e:
        # bind() can raise OSError as well as OverflowError (bad port).
        print("Cannot bind/listen: %s" % (e), file=sys.stderr)
        sys.exit(1)
    print("*** listening on %s:%i" % (listenHost, listenPort))
    # Run the accept loop on this thread. Starting a thread only to join it
    # immediately gained nothing.
    server()
def _handle_connection(clientSocketHandler, clientAddress, contextServer):
    """Set up one proxied connection and start one relay thread per direction.

    Performs the client-facing TLS handshake, opens a TLS connection to
    ``targetHost:targetPort`` using ``contextServer`` and starts the two
    ``forward`` threads. It runs in its own thread so that a slow handshake
    or connect cannot stall the accept loop. On any setup failure the error
    is printed and every socket created so far is closed.
    """
    clientSocket = None
    upstream = None
    serverSocket = None
    try:
        # Convert to an SSL socket; this process is the TLS server here.
        clientSocket = contextClient.wrap_socket(clientSocketHandler, server_side=True)
        # create_connection resolves the name and tries every address family.
        upstream = socket.create_connection((targetHost, targetPort))
        serverSocket = contextServer.wrap_socket(upstream, server_hostname=targetHost)
    except Exception as e:
        # Thread boundary: log and clean up instead of dying with a traceback.
        print("*** connection from %s:%i failed: %s" % (clientAddress[0], clientAddress[1], e))
        # Closing a socket already detached by wrap_socket() is a no-op.
        for sock in (serverSocket, upstream, clientSocket, clientSocketHandler):
            if sock is not None:
                sock.close()
        return
    # Daemon threads: open connections do not keep the process alive once
    # the accept loop has ended.
    threading.Thread(target=forward, args=(clientSocket, serverSocket, "client -> server"), daemon=True).start()
    threading.Thread(target=forward, args=(serverSocket, clientSocket, "server -> client"), daemon=True).start()


def server(*settings):
    """Accept client connections until interrupted or the listener closes.

    Each accepted connection is handed to ``_handle_connection`` in its own
    daemon thread. A failing connection never ends the loop. A failing
    ``accept()`` is logged and retried, unless ``listenSocket`` has been
    closed. Ctrl-C (when this runs on the main thread) stops the loop.
    ``listenSocket`` is closed when the function returns.

    ``settings`` is accepted for backward compatibility and ignored.
    """
    # One TLS context for all connections to the target, built once.
    # create_default_context() verifies the target's certificate chain and
    # hostname; a bare SSLContext performs no verification, which would allow
    # a man-in-the-middle. NOTE(review): targets using a private CA need
    # contextServer.load_verify_locations(...) added here.
    contextServer = ssl.create_default_context()
    contextServer.minimum_version = ssl.TLSVersion.TLSv1_2
    try:
        while True:
            try:
                # wait for new client connection
                clientSocketHandler, clientAddress = listenSocket.accept()
            except OSError as e:
                if listenSocket.fileno() == -1:
                    # Listener was closed: nothing more can be accepted.
                    return
                print("*** accept failed: %s" % (e))
                continue
            print("*** new connection from %s:%i to %s:%i" % (clientAddress[0], clientAddress[1], targetHost, targetPort))
            worker = threading.Thread(
                target=_handle_connection,
                args=(clientSocketHandler, clientAddress, contextServer),
                daemon=True,
            )
            try:
                worker.start()
            except RuntimeError as e:
                # e.g. "can't start new thread": drop this client, keep serving.
                print("*** cannot handle connection: %s" % (e))
                clientSocketHandler.close()
    except KeyboardInterrupt:
        print("*** interrupted, shutting down")
    finally:
        listenSocket.close()
def forward(source, destination, description):
    """Relay bytes from ``source`` to ``destination`` until either side stops.

    Reads up to 1024 bytes at a time from ``source`` and writes each chunk in
    full to ``destination``. When debugging is enabled, every read is printed
    with ``description``; an end-of-stream read is reported as -1 bytes.

    The relay stops when ``source`` reports end-of-stream or when a receive or
    send fails. Both sockets are then shut down in both directions and
    closed. Using shutdown (not just close) wakes a recv that is blocked in
    the opposite-direction relay thread, so both relays end together.

    Half-close is deliberately not used. In CPython, ``SSLSocket.shutdown()``
    discards the socket's TLS state. The opposite relay would then move raw,
    undecrypted TLS records.

    Args:
        source: socket to read from.
        destination: socket to write to.
        description: label used in log lines, e.g. "client -> server".
    """
    try:
        while True:
            try:
                data = source.recv(1024)
            except Exception as e:
                # When the peer relay tears the sockets down, this can surface
                # as OSError, ValueError or AttributeError (from SSLSocket
                # internals), so catch broadly at this thread boundary.
                print("*** %s: receive failed: %s" % (description, e))
                break
            if debugOutput:
                print("*** %s [%d bytes]: %s" % ( description, len(data) if data else -1, data ))
            if not data:
                # Orderly end-of-stream from the source.
                break
            try:
                destination.sendall(data)
            except Exception as e:
                # Data that cannot be delivered must not be read and dropped.
                print("Failed sending data: %s" % (e))
                break
    finally:
        for sock in (source, destination):
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                # Best effort: the socket may already be shut down or closed
                # by the opposite relay.
                pass
            sock.close()
if __name__ == '__main__':
    # Start the forwarder only when executed as a script, not on import.
    main()