Refactor: Config dataclass, more unit tests, ruff/mypy setup
- Replace module-level env var globals with frozen Config dataclass - Add pyproject.toml with project metadata, ruff, mypy, pytest config - Expand unit tests from 11 to 30 (sanitize edge cases, SSH failure modes, proxy forwarding, cleanup, auth loop edge cases) - Fix all ruff and mypy findings - Integration tests no longer need importlib.reload hack - Dockerfile installs from pyproject.toml for consistency Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
37caa7842d
commit
c9151df7e0
5 changed files with 381 additions and 87 deletions
|
|
@ -1,10 +1,9 @@
|
|||
FROM python:3.12-slim
|
||||
|
||||
RUN pip install --no-cache-dir asyncssh
|
||||
|
||||
COPY proxy.py /app/proxy.py
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir .
|
||||
COPY proxy.py .
|
||||
|
||||
EXPOSE 10011
|
||||
|
||||
CMD ["python", "-u", "proxy.py"]
|
||||
|
|
|
|||
89
proxy.py
89
proxy.py
|
|
@ -1,7 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import asyncssh
|
||||
|
||||
|
|
@ -12,10 +16,21 @@ logging.basicConfig(
|
|||
)
|
||||
log = logging.getLogger("ts3query-proxy")
|
||||
|
||||
TS6_HOST = os.environ.get("TS6_HOST", "teamspeak6")
|
||||
TS6_SSH_PORT = int(os.environ.get("TS6_SSH_PORT", "10022"))
|
||||
LISTEN_HOST = "0.0.0.0"
|
||||
LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "10011"))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
ts6_host: str = "teamspeak6"
|
||||
ts6_ssh_port: int = 10022
|
||||
listen_host: str = "0.0.0.0"
|
||||
listen_port: int = 10011
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Config:
|
||||
return cls(
|
||||
ts6_host=os.environ.get("TS6_HOST", "teamspeak6"),
|
||||
ts6_ssh_port=int(os.environ.get("TS6_SSH_PORT", "10022")),
|
||||
listen_port=int(os.environ.get("LISTEN_PORT", "10011")),
|
||||
)
|
||||
|
||||
TS3_BANNER = (
|
||||
b"TS3\n\r"
|
||||
|
|
@ -26,12 +41,18 @@ TS3_BANNER = (
|
|||
|
||||
|
||||
class ClientHandler:
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
config: Config | None = None,
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.config = config or Config()
|
||||
self.ssh_conn: asyncssh.SSHClientConnection | None = None
|
||||
self.ssh_process: asyncssh.SSHClientProcess | None = None
|
||||
self.addr = writer.get_extra_info("peername")
|
||||
self.addr: tuple[str, int] | None = writer.get_extra_info("peername")
|
||||
|
||||
async def handle(self) -> None:
|
||||
log.info("Client connected from %s", self.addr)
|
||||
|
|
@ -39,7 +60,7 @@ class ClientHandler:
|
|||
self.writer.write(TS3_BANNER)
|
||||
await self.writer.drain()
|
||||
await self._auth_loop()
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
log.warning("Client %s timed out", self.addr)
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
log.info("Client %s disconnected abruptly", self.addr)
|
||||
|
|
@ -90,11 +111,16 @@ class ClientHandler:
|
|||
|
||||
async def _connect_ssh(self, username: str, password: str) -> bool:
|
||||
try:
|
||||
log.info("SSH connecting to %s:%d as %s", TS6_HOST, TS6_SSH_PORT, username)
|
||||
log.info(
|
||||
"SSH connecting to %s:%d as %s",
|
||||
self.config.ts6_host,
|
||||
self.config.ts6_ssh_port,
|
||||
username,
|
||||
)
|
||||
self.ssh_conn = await asyncio.wait_for(
|
||||
asyncssh.connect(
|
||||
TS6_HOST,
|
||||
TS6_SSH_PORT,
|
||||
self.config.ts6_host,
|
||||
self.config.ts6_ssh_port,
|
||||
username=username,
|
||||
password=password,
|
||||
known_hosts=None,
|
||||
|
|
@ -115,7 +141,7 @@ class ClientHandler:
|
|||
banner += chunk
|
||||
if "help" in banner.lower():
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
break
|
||||
|
||||
log.info("SSH session established for %s", self.addr)
|
||||
|
|
@ -124,11 +150,18 @@ class ClientHandler:
|
|||
log.warning("SSH auth failed for %s (bad credentials)", self.addr)
|
||||
return False
|
||||
except Exception as e:
|
||||
log.error("SSH connection to %s:%d failed: %s", TS6_HOST, TS6_SSH_PORT, e)
|
||||
log.error(
|
||||
"SSH connection to %s:%d failed: %s",
|
||||
self.config.ts6_host,
|
||||
self.config.ts6_ssh_port,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
async def _proxy(self) -> None:
|
||||
assert self.ssh_process is not None
|
||||
log.info("Proxying traffic for %s", self.addr)
|
||||
ssh_process = self.ssh_process
|
||||
|
||||
async def client_to_ssh() -> None:
|
||||
try:
|
||||
|
|
@ -136,7 +169,7 @@ class ClientHandler:
|
|||
line = await self.reader.readline()
|
||||
if not line:
|
||||
break
|
||||
self.ssh_process.stdin.write(
|
||||
ssh_process.stdin.write(
|
||||
line.decode("utf-8", errors="replace")
|
||||
)
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
|
|
@ -145,7 +178,7 @@ class ClientHandler:
|
|||
async def ssh_to_client() -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await self.ssh_process.stdout.read(4096)
|
||||
data = await ssh_process.stdout.read(4096)
|
||||
if not data:
|
||||
break
|
||||
raw = data.encode() if isinstance(data, str) else data
|
||||
|
|
@ -161,12 +194,9 @@ class ClientHandler:
|
|||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
# Await cancelled tasks to suppress warnings
|
||||
for t in pending:
|
||||
try:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
try:
|
||||
|
|
@ -191,16 +221,27 @@ def _sanitize(cmd: str) -> str:
|
|||
|
||||
|
||||
async def handle_client(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
config: Config | None = None,
|
||||
) -> None:
|
||||
handler = ClientHandler(reader, writer)
|
||||
handler = ClientHandler(reader, writer, config)
|
||||
await handler.handle()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
server = await asyncio.start_server(handle_client, LISTEN_HOST, LISTEN_PORT)
|
||||
log.info("TS3 Query Proxy listening on %s:%d", LISTEN_HOST, LISTEN_PORT)
|
||||
log.info("Forwarding to %s:%d (SSH Query)", TS6_HOST, TS6_SSH_PORT)
|
||||
config = Config.from_env()
|
||||
|
||||
async def _on_connect(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
await handle_client(reader, writer, config)
|
||||
|
||||
server = await asyncio.start_server(
|
||||
_on_connect, config.listen_host, config.listen_port
|
||||
)
|
||||
log.info("TS3 Query Proxy listening on %s:%d", config.listen_host, config.listen_port)
|
||||
log.info("Forwarding to %s:%d (SSH Query)", config.ts6_host, config.ts6_ssh_port)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
|
|
|||
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "ts3-query-proxy"
|
||||
version = "1.0.0"
|
||||
description = "Translation layer bridging TS3 raw ServerQuery (TCP 10011) to TS6 SSH Query (TCP 10022)"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
requires-python = ">=3.11"
|
||||
authors = [{ name = "Joshua Hirsig" }]
|
||||
dependencies = ["asyncssh>=2.17"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=8", "pytest-asyncio>=0.24", "ruff>=0.9", "mypy>=1.14"]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
only-include = ["proxy.py"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "W", "F", "I", "UP", "B", "SIM"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
|
|
@ -10,11 +10,11 @@ import os
|
|||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import sys
|
||||
from proxy import TS3_BANNER, ClientHandler, Config, _sanitize
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from proxy import ClientHandler, TS3_BANNER, _sanitize
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — _sanitize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitize(unittest.TestCase):
|
||||
|
|
@ -27,6 +27,59 @@ class TestSanitize(unittest.TestCase):
|
|||
def test_short_login(self):
|
||||
self.assertEqual(_sanitize("login serveradmin"), "login serveradmin")
|
||||
|
||||
def test_empty_string(self):
|
||||
self.assertEqual(_sanitize(""), "")
|
||||
|
||||
def test_login_keyword_only(self):
|
||||
self.assertEqual(_sanitize("login"), "login")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
self.assertEqual(_sanitize("LOGIN serveradmin secret"), "login serveradmin ***")
|
||||
|
||||
def test_password_with_spaces(self):
|
||||
self.assertEqual(
|
||||
_sanitize("login serveradmin pass word with spaces"),
|
||||
"login serveradmin ***",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
def test_defaults(self):
|
||||
cfg = Config()
|
||||
self.assertEqual(cfg.ts6_host, "teamspeak6")
|
||||
self.assertEqual(cfg.ts6_ssh_port, 10022)
|
||||
self.assertEqual(cfg.listen_host, "0.0.0.0")
|
||||
self.assertEqual(cfg.listen_port, 10011)
|
||||
|
||||
def test_from_env(self):
|
||||
env = {"TS6_HOST": "myhost", "TS6_SSH_PORT": "9999", "LISTEN_PORT": "5555"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
cfg = Config.from_env()
|
||||
self.assertEqual(cfg.ts6_host, "myhost")
|
||||
self.assertEqual(cfg.ts6_ssh_port, 9999)
|
||||
self.assertEqual(cfg.listen_port, 5555)
|
||||
|
||||
def test_from_env_defaults(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
cfg = Config.from_env()
|
||||
self.assertEqual(cfg.ts6_host, "teamspeak6")
|
||||
self.assertEqual(cfg.ts6_ssh_port, 10022)
|
||||
|
||||
def test_frozen(self):
|
||||
cfg = Config()
|
||||
with self.assertRaises(AttributeError):
|
||||
cfg.ts6_host = "other"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — Banner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBanner(unittest.TestCase):
|
||||
def test_banner_starts_with_ts3(self):
|
||||
|
|
@ -35,6 +88,14 @@ class TestBanner(unittest.TestCase):
|
|||
def test_banner_contains_help(self):
|
||||
self.assertIn(b"help", TS3_BANNER)
|
||||
|
||||
def test_banner_ends_with_crlf(self):
|
||||
self.assertTrue(TS3_BANNER.endswith(b"\n\r"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — ClientHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
|
||||
def _make_handler(self, input_lines: list[bytes]):
|
||||
|
|
@ -49,28 +110,70 @@ class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
|
|||
handler = ClientHandler(reader, writer)
|
||||
return handler, writer
|
||||
|
||||
def _get_writes(self, writer):
|
||||
return [call[0][0] for call in writer.write.call_args_list]
|
||||
|
||||
# --- Banner ---
|
||||
|
||||
async def test_sends_banner_on_connect(self):
|
||||
handler, writer = self._make_handler([b"quit\n"])
|
||||
await handler.handle()
|
||||
first_write = writer.write.call_args_list[0][0][0]
|
||||
self.assertEqual(first_write, TS3_BANNER)
|
||||
|
||||
# --- Pre-auth commands ---
|
||||
|
||||
async def test_quit_before_login(self):
|
||||
handler, writer = self._make_handler([b"quit\n"])
|
||||
await handler.handle()
|
||||
writes = [call[0][0] for call in writer.write.call_args_list]
|
||||
self.assertIn(b"error id=0 msg=ok\n\r", writes)
|
||||
self.assertIn(b"error id=0 msg=ok\n\r", self._get_writes(writer))
|
||||
|
||||
async def test_command_before_login_returns_not_logged_in(self):
|
||||
handler, writer = self._make_handler([b"serverinfo\n", b"quit\n"])
|
||||
await handler.handle()
|
||||
writes = [call[0][0] for call in writer.write.call_args_list]
|
||||
self.assertIn(b"error id=1794 msg=not\\slogged\\sin\n\r", writes)
|
||||
self.assertIn(b"error id=1794 msg=not\\slogged\\sin\n\r", self._get_writes(writer))
|
||||
|
||||
async def test_multiple_commands_before_login(self):
|
||||
handler, writer = self._make_handler(
|
||||
[b"serverinfo\n", b"clientlist\n", b"quit\n"]
|
||||
)
|
||||
await handler.handle()
|
||||
writes = self._get_writes(writer)
|
||||
not_logged_in = b"error id=1794 msg=not\\slogged\\sin\n\r"
|
||||
count = writes.count(not_logged_in)
|
||||
self.assertEqual(count, 2, f"Expected 2 not-logged-in errors, got {count}")
|
||||
|
||||
async def test_empty_line_ignored(self):
|
||||
handler, writer = self._make_handler([b"\n", b"quit\n"])
|
||||
await handler.handle()
|
||||
writes = self._get_writes(writer)
|
||||
self.assertIn(b"error id=0 msg=ok\n\r", writes)
|
||||
self.assertNotIn(b"error id=1794", b"".join(writes))
|
||||
|
||||
# --- Disconnect ---
|
||||
|
||||
async def test_empty_disconnect(self):
|
||||
handler, writer = self._make_handler([b""])
|
||||
await handler.handle()
|
||||
# Should not raise, just disconnect
|
||||
|
||||
async def test_timeout_in_auth_loop(self):
|
||||
reader = AsyncMock(spec=asyncio.StreamReader)
|
||||
reader.readline = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
writer = MagicMock(spec=asyncio.StreamWriter)
|
||||
writer.write = MagicMock()
|
||||
writer.drain = AsyncMock()
|
||||
writer.close = MagicMock()
|
||||
writer.wait_closed = AsyncMock()
|
||||
writer.get_extra_info = MagicMock(return_value=("127.0.0.1", 12345))
|
||||
handler = ClientHandler(reader, writer)
|
||||
await handler.handle() # should not raise
|
||||
|
||||
# --- Login ---
|
||||
|
||||
async def test_login_incomplete_command(self):
|
||||
handler, writer = self._make_handler([b"login serveradmin\n", b"quit\n"])
|
||||
await handler.handle()
|
||||
self.assertIn(b"error id=256 msg=command\\snot\\sfound\n\r", self._get_writes(writer))
|
||||
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_login_bad_credentials(self, mock_asyncssh):
|
||||
|
|
@ -82,19 +185,144 @@ class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
|
|||
)
|
||||
handler, writer = self._make_handler([b"login serveradmin wrongpw\n"])
|
||||
await handler.handle()
|
||||
writes = [call[0][0] for call in writer.write.call_args_list]
|
||||
self.assertIn(
|
||||
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r", writes
|
||||
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
|
||||
self._get_writes(writer),
|
||||
)
|
||||
|
||||
async def test_login_incomplete_command(self):
|
||||
handler, writer = self._make_handler([b"login serveradmin\n", b"quit\n"])
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_login_ssh_timeout(self, mock_asyncssh):
|
||||
import asyncssh as real_asyncssh
|
||||
|
||||
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
|
||||
mock_asyncssh.connect = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
handler, writer = self._make_handler([b"login serveradmin pass\n"])
|
||||
await handler.handle()
|
||||
writes = [call[0][0] for call in writer.write.call_args_list]
|
||||
self.assertIn(b"error id=256 msg=command\\snot\\sfound\n\r", writes)
|
||||
self.assertIn(
|
||||
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
|
||||
self._get_writes(writer),
|
||||
)
|
||||
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_login_ssh_connection_refused(self, mock_asyncssh):
|
||||
import asyncssh as real_asyncssh
|
||||
|
||||
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
|
||||
mock_asyncssh.connect = AsyncMock(side_effect=OSError("Connection refused"))
|
||||
handler, writer = self._make_handler([b"login serveradmin pass\n"])
|
||||
await handler.handle()
|
||||
self.assertIn(
|
||||
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
|
||||
self._get_writes(writer),
|
||||
)
|
||||
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_login_case_insensitive(self, mock_asyncssh):
|
||||
"""LOGIN (uppercase) should be handled the same as login."""
|
||||
import asyncssh as real_asyncssh
|
||||
|
||||
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
|
||||
mock_asyncssh.connect = AsyncMock(
|
||||
side_effect=real_asyncssh.PermissionDenied("bad")
|
||||
)
|
||||
handler, writer = self._make_handler([b"LOGIN serveradmin wrongpw\n"])
|
||||
await handler.handle()
|
||||
self.assertIn(
|
||||
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
|
||||
self._get_writes(writer),
|
||||
)
|
||||
|
||||
# --- Proxy (bidirectional forwarding) ---
|
||||
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_proxy_forwards_command_to_ssh(self, mock_asyncssh):
|
||||
"""After login, client commands should be forwarded to SSH stdin."""
|
||||
import asyncssh as real_asyncssh
|
||||
|
||||
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
|
||||
|
||||
# Mock SSH process
|
||||
mock_process = MagicMock()
|
||||
mock_process.stdin = MagicMock()
|
||||
mock_process.stdin.write = MagicMock()
|
||||
mock_process.stdout = AsyncMock()
|
||||
mock_process.stdout.read = AsyncMock(
|
||||
side_effect=["TS3 banner help\n\r", ""]
|
||||
)
|
||||
mock_process.close = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.create_process = AsyncMock(return_value=mock_process)
|
||||
mock_conn.close = MagicMock()
|
||||
|
||||
mock_asyncssh.connect = AsyncMock(return_value=mock_conn)
|
||||
|
||||
handler, writer = self._make_handler(
|
||||
[b"login serveradmin pass\n", b"serverinfo\n", b""]
|
||||
)
|
||||
await handler.handle()
|
||||
|
||||
# Verify login success response
|
||||
self.assertIn(b"error id=0 msg=ok\n\r", self._get_writes(writer))
|
||||
|
||||
# Verify command was forwarded to SSH stdin
|
||||
stdin_writes = [call[0][0] for call in mock_process.stdin.write.call_args_list]
|
||||
self.assertIn("serverinfo\n", stdin_writes)
|
||||
|
||||
@patch("proxy.asyncssh")
|
||||
async def test_proxy_forwards_ssh_response_to_client(self, mock_asyncssh):
|
||||
"""SSH stdout data should be forwarded to the TCP client."""
|
||||
import asyncssh as real_asyncssh
|
||||
|
||||
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.stdin = MagicMock()
|
||||
mock_process.stdin.write = MagicMock()
|
||||
mock_process.stdout = AsyncMock()
|
||||
mock_process.stdout.read = AsyncMock(
|
||||
side_effect=[
|
||||
"TS3 banner help\n\r",
|
||||
"virtualserver_name=Test error id=0 msg=ok\n\r",
|
||||
"",
|
||||
]
|
||||
)
|
||||
mock_process.close = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.create_process = AsyncMock(return_value=mock_process)
|
||||
mock_conn.close = MagicMock()
|
||||
|
||||
mock_asyncssh.connect = AsyncMock(return_value=mock_conn)
|
||||
|
||||
handler, writer = self._make_handler(
|
||||
[b"login serveradmin pass\n", b""]
|
||||
)
|
||||
await handler.handle()
|
||||
|
||||
writes = self._get_writes(writer)
|
||||
joined = b"".join(writes)
|
||||
self.assertIn(b"virtualserver_name=Test", joined)
|
||||
|
||||
# --- Cleanup ---
|
||||
|
||||
async def test_cleanup_with_no_ssh(self):
|
||||
"""Cleanup should not raise when SSH was never established."""
|
||||
handler, writer = self._make_handler([b""])
|
||||
self.assertIsNone(handler.ssh_conn)
|
||||
self.assertIsNone(handler.ssh_process)
|
||||
await handler._cleanup()
|
||||
|
||||
async def test_cleanup_when_writer_close_raises(self):
|
||||
"""Cleanup should swallow exceptions from writer.close."""
|
||||
handler, writer = self._make_handler([b""])
|
||||
writer.wait_closed = AsyncMock(side_effect=OSError("already closed"))
|
||||
await handler._cleanup() # should not raise
|
||||
|
||||
|
||||
# --- Integration tests (only run when env vars are set) ---
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests (only run when env vars are set)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TS6_TEST_HOST = os.environ.get("TS6_TEST_HOST")
|
||||
TS6_TEST_SSH_PORT = int(os.environ.get("TS6_TEST_SSH_PORT", "10022"))
|
||||
|
|
@ -111,18 +339,15 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
|
|||
"""Integration tests that start a real proxy and connect through it."""
|
||||
|
||||
async def asyncSetUp(self):
|
||||
os.environ["TS6_HOST"] = TS6_TEST_HOST
|
||||
os.environ["TS6_SSH_PORT"] = str(TS6_TEST_SSH_PORT)
|
||||
config = Config(ts6_host=TS6_TEST_HOST, ts6_ssh_port=TS6_TEST_SSH_PORT)
|
||||
|
||||
# Re-import to pick up env vars
|
||||
import importlib
|
||||
import proxy as proxy_mod
|
||||
async def on_connect(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
handler = ClientHandler(reader, writer, config)
|
||||
await handler.handle()
|
||||
|
||||
importlib.reload(proxy_mod)
|
||||
|
||||
self.server = await asyncio.start_server(
|
||||
proxy_mod.handle_client, "127.0.0.1", PROXY_TEST_PORT
|
||||
)
|
||||
self.server = await asyncio.start_server(on_connect, "127.0.0.1", PROXY_TEST_PORT)
|
||||
|
||||
async def asyncTearDown(self):
|
||||
self.server.close()
|
||||
|
|
@ -130,11 +355,33 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
|
|||
|
||||
async def _connect(self):
|
||||
reader, writer = await asyncio.open_connection("127.0.0.1", PROXY_TEST_PORT)
|
||||
# Read banner
|
||||
banner = await asyncio.wait_for(reader.read(4096), timeout=5)
|
||||
self.assertIn(b"TS3", banner)
|
||||
return reader, writer
|
||||
|
||||
async def _login_and_use(self):
|
||||
"""Helper: connect, login, select virtual server."""
|
||||
reader, writer = await self._connect()
|
||||
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
|
||||
await writer.drain()
|
||||
await asyncio.wait_for(reader.readline(), timeout=10)
|
||||
writer.write(b"use sid=1\n")
|
||||
await writer.drain()
|
||||
await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
return reader, writer
|
||||
|
||||
async def _send_cmd(self, reader, writer, cmd):
|
||||
"""Send a command and return the full response including error line."""
|
||||
writer.write(f"{cmd}\n".encode())
|
||||
await writer.drain()
|
||||
data = b""
|
||||
for _ in range(20):
|
||||
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
data += chunk
|
||||
if b"error id=" in chunk:
|
||||
break
|
||||
return data
|
||||
|
||||
async def test_login_and_version(self):
|
||||
reader, writer = await self._connect()
|
||||
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
|
||||
|
|
@ -144,7 +391,6 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
|
|||
|
||||
writer.write(b"version\n")
|
||||
await writer.drain()
|
||||
# version response + error line
|
||||
data = b""
|
||||
for _ in range(10):
|
||||
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
|
|
@ -186,44 +432,18 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
|
|||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
async def _login_and_use(self):
|
||||
"""Helper: connect, login, select virtual server."""
|
||||
reader, writer = await self._connect()
|
||||
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
|
||||
await writer.drain()
|
||||
await asyncio.wait_for(reader.readline(), timeout=10)
|
||||
writer.write(b"use sid=1\n")
|
||||
await writer.drain()
|
||||
await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
return reader, writer
|
||||
|
||||
async def _send_cmd(self, reader, writer, cmd):
|
||||
"""Send a command and return the full response including error line."""
|
||||
writer.write(f"{cmd}\n".encode())
|
||||
await writer.drain()
|
||||
data = b""
|
||||
for _ in range(20):
|
||||
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
data += chunk
|
||||
if b"error id=" in chunk:
|
||||
break
|
||||
return data
|
||||
|
||||
async def test_channellist_and_channelinfo(self):
|
||||
reader, writer = await self._login_and_use()
|
||||
|
||||
# Get channel list and extract first channel ID
|
||||
data = await self._send_cmd(reader, writer, "channellist")
|
||||
self.assertIn(b"error id=0", data)
|
||||
self.assertIn(b"cid=", data)
|
||||
|
||||
# Extract first cid
|
||||
for part in data.decode(errors="replace").split("|")[0].split():
|
||||
if part.startswith("cid="):
|
||||
cid = part.split("=")[1]
|
||||
break
|
||||
|
||||
# channelinfo with valid cid
|
||||
data = await self._send_cmd(reader, writer, f"channelinfo cid={cid}")
|
||||
self.assertIn(b"error id=0", data)
|
||||
self.assertIn(b"channel_name=", data)
|
||||
|
|
@ -235,24 +455,22 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
|
|||
async def test_ban_add_list_delete(self):
|
||||
reader, writer = await self._login_and_use()
|
||||
|
||||
# Add a temporary ban
|
||||
data = await self._send_cmd(reader, writer, "banadd ip=254.253.252.251 banreason=proxytest time=10")
|
||||
data = await self._send_cmd(
|
||||
reader, writer, "banadd ip=254.253.252.251 banreason=proxytest time=10"
|
||||
)
|
||||
self.assertIn(b"error id=0", data)
|
||||
self.assertIn(b"banid=", data)
|
||||
|
||||
# Extract ban ID
|
||||
banid = None
|
||||
for part in data.decode(errors="replace").split():
|
||||
if part.startswith("banid="):
|
||||
banid = part.split("=")[1].strip()
|
||||
break
|
||||
|
||||
# List bans
|
||||
data = await self._send_cmd(reader, writer, "banlist")
|
||||
self.assertIn(b"error id=0", data)
|
||||
self.assertIn(b"ip=254.253.252.251", data)
|
||||
|
||||
# Delete the ban
|
||||
data = await self._send_cmd(reader, writer, f"bandel banid={banid}")
|
||||
self.assertIn(b"error id=0", data)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue