Refactor: Config dataclass, more unit tests, ruff/mypy setup

- Replace module-level env var globals with frozen Config dataclass
- Add pyproject.toml with project metadata, ruff, mypy, pytest config
- Expand unit tests from 11 to 30 (sanitize edge cases, SSH failure
  modes, proxy forwarding, cleanup, auth loop edge cases)
- Fix all ruff and mypy findings
- Integration tests no longer need importlib.reload hack
- Dockerfile installs from pyproject.toml for consistency

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Joshua Hirsig 2026-03-14 18:56:35 +01:00
parent 37caa7842d
commit c9151df7e0
5 changed files with 381 additions and 87 deletions

View file

@ -1,10 +1,9 @@
FROM python:3.12-slim
RUN pip install --no-cache-dir asyncssh
COPY proxy.py /app/proxy.py
WORKDIR /app
COPY pyproject.toml .
RUN pip install --no-cache-dir .
COPY proxy.py .
EXPOSE 10011
CMD ["python", "-u", "proxy.py"]

View file

@ -1,7 +1,11 @@
from __future__ import annotations
import asyncio
import os
import contextlib
import logging
import os
import sys
from dataclasses import dataclass
import asyncssh
@ -12,10 +16,21 @@ logging.basicConfig(
)
log = logging.getLogger("ts3query-proxy")
TS6_HOST = os.environ.get("TS6_HOST", "teamspeak6")
TS6_SSH_PORT = int(os.environ.get("TS6_SSH_PORT", "10022"))
LISTEN_HOST = "0.0.0.0"
LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "10011"))
@dataclass(frozen=True)
class Config:
ts6_host: str = "teamspeak6"
ts6_ssh_port: int = 10022
listen_host: str = "0.0.0.0"
listen_port: int = 10011
@classmethod
def from_env(cls) -> Config:
return cls(
ts6_host=os.environ.get("TS6_HOST", "teamspeak6"),
ts6_ssh_port=int(os.environ.get("TS6_SSH_PORT", "10022")),
listen_port=int(os.environ.get("LISTEN_PORT", "10011")),
)
TS3_BANNER = (
b"TS3\n\r"
@ -26,12 +41,18 @@ TS3_BANNER = (
class ClientHandler:
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
config: Config | None = None,
):
self.reader = reader
self.writer = writer
self.config = config or Config()
self.ssh_conn: asyncssh.SSHClientConnection | None = None
self.ssh_process: asyncssh.SSHClientProcess | None = None
self.addr = writer.get_extra_info("peername")
self.addr: tuple[str, int] | None = writer.get_extra_info("peername")
async def handle(self) -> None:
log.info("Client connected from %s", self.addr)
@ -39,7 +60,7 @@ class ClientHandler:
self.writer.write(TS3_BANNER)
await self.writer.drain()
await self._auth_loop()
except asyncio.TimeoutError:
except TimeoutError:
log.warning("Client %s timed out", self.addr)
except (ConnectionResetError, BrokenPipeError):
log.info("Client %s disconnected abruptly", self.addr)
@ -90,11 +111,16 @@ class ClientHandler:
async def _connect_ssh(self, username: str, password: str) -> bool:
try:
log.info("SSH connecting to %s:%d as %s", TS6_HOST, TS6_SSH_PORT, username)
log.info(
"SSH connecting to %s:%d as %s",
self.config.ts6_host,
self.config.ts6_ssh_port,
username,
)
self.ssh_conn = await asyncio.wait_for(
asyncssh.connect(
TS6_HOST,
TS6_SSH_PORT,
self.config.ts6_host,
self.config.ts6_ssh_port,
username=username,
password=password,
known_hosts=None,
@ -115,7 +141,7 @@ class ClientHandler:
banner += chunk
if "help" in banner.lower():
break
except asyncio.TimeoutError:
except TimeoutError:
break
log.info("SSH session established for %s", self.addr)
@ -124,11 +150,18 @@ class ClientHandler:
log.warning("SSH auth failed for %s (bad credentials)", self.addr)
return False
except Exception as e:
log.error("SSH connection to %s:%d failed: %s", TS6_HOST, TS6_SSH_PORT, e)
log.error(
"SSH connection to %s:%d failed: %s",
self.config.ts6_host,
self.config.ts6_ssh_port,
e,
)
return False
async def _proxy(self) -> None:
assert self.ssh_process is not None
log.info("Proxying traffic for %s", self.addr)
ssh_process = self.ssh_process
async def client_to_ssh() -> None:
try:
@ -136,7 +169,7 @@ class ClientHandler:
line = await self.reader.readline()
if not line:
break
self.ssh_process.stdin.write(
ssh_process.stdin.write(
line.decode("utf-8", errors="replace")
)
except (ConnectionResetError, BrokenPipeError):
@ -145,7 +178,7 @@ class ClientHandler:
async def ssh_to_client() -> None:
try:
while True:
data = await self.ssh_process.stdout.read(4096)
data = await ssh_process.stdout.read(4096)
if not data:
break
raw = data.encode() if isinstance(data, str) else data
@ -161,12 +194,9 @@ class ClientHandler:
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for t in pending:
t.cancel()
# Await cancelled tasks to suppress warnings
for t in pending:
try:
with contextlib.suppress(asyncio.CancelledError):
await t
except asyncio.CancelledError:
pass
async def _cleanup(self) -> None:
try:
@ -191,16 +221,27 @@ def _sanitize(cmd: str) -> str:
async def handle_client(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
config: Config | None = None,
) -> None:
handler = ClientHandler(reader, writer)
handler = ClientHandler(reader, writer, config)
await handler.handle()
async def main() -> None:
server = await asyncio.start_server(handle_client, LISTEN_HOST, LISTEN_PORT)
log.info("TS3 Query Proxy listening on %s:%d", LISTEN_HOST, LISTEN_PORT)
log.info("Forwarding to %s:%d (SSH Query)", TS6_HOST, TS6_SSH_PORT)
config = Config.from_env()
async def _on_connect(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
await handle_client(reader, writer, config)
server = await asyncio.start_server(
_on_connect, config.listen_host, config.listen_port
)
log.info("TS3 Query Proxy listening on %s:%d", config.listen_host, config.listen_port)
log.info("Forwarding to %s:%d (SSH Query)", config.ts6_host, config.ts6_ssh_port)
async with server:
await server.serve_forever()

36
pyproject.toml Normal file
View file

@ -0,0 +1,36 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "ts3-query-proxy"
version = "1.0.0"
description = "Translation layer bridging TS3 raw ServerQuery (TCP 10011) to TS6 SSH Query (TCP 10022)"
readme = "README.md"
license = "MIT"
requires-python = ">=3.11"
authors = [{ name = "Joshua Hirsig" }]
dependencies = ["asyncssh>=2.17"]
[project.optional-dependencies]
dev = ["pytest>=8", "pytest-asyncio>=0.24", "ruff>=0.9", "mypy>=1.14"]
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["proxy.py"]
[tool.pytest.ini_options]
testpaths = ["tests"]
[tool.ruff]
target-version = "py311"
line-length = 100
[tool.ruff.lint]
select = ["E", "W", "F", "I", "UP", "B", "SIM"]
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true

0
tests/__init__.py Normal file
View file

View file

@ -10,11 +10,11 @@ import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import sys
from proxy import TS3_BANNER, ClientHandler, Config, _sanitize
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from proxy import ClientHandler, TS3_BANNER, _sanitize
# ---------------------------------------------------------------------------
# Unit tests — _sanitize
# ---------------------------------------------------------------------------
class TestSanitize(unittest.TestCase):
@ -27,6 +27,59 @@ class TestSanitize(unittest.TestCase):
def test_short_login(self):
self.assertEqual(_sanitize("login serveradmin"), "login serveradmin")
def test_empty_string(self):
self.assertEqual(_sanitize(""), "")
def test_login_keyword_only(self):
self.assertEqual(_sanitize("login"), "login")
def test_case_insensitive(self):
self.assertEqual(_sanitize("LOGIN serveradmin secret"), "login serveradmin ***")
def test_password_with_spaces(self):
self.assertEqual(
_sanitize("login serveradmin pass word with spaces"),
"login serveradmin ***",
)
# ---------------------------------------------------------------------------
# Unit tests — Config
# ---------------------------------------------------------------------------
class TestConfig(unittest.TestCase):
def test_defaults(self):
cfg = Config()
self.assertEqual(cfg.ts6_host, "teamspeak6")
self.assertEqual(cfg.ts6_ssh_port, 10022)
self.assertEqual(cfg.listen_host, "0.0.0.0")
self.assertEqual(cfg.listen_port, 10011)
def test_from_env(self):
env = {"TS6_HOST": "myhost", "TS6_SSH_PORT": "9999", "LISTEN_PORT": "5555"}
with patch.dict(os.environ, env, clear=False):
cfg = Config.from_env()
self.assertEqual(cfg.ts6_host, "myhost")
self.assertEqual(cfg.ts6_ssh_port, 9999)
self.assertEqual(cfg.listen_port, 5555)
def test_from_env_defaults(self):
with patch.dict(os.environ, {}, clear=True):
cfg = Config.from_env()
self.assertEqual(cfg.ts6_host, "teamspeak6")
self.assertEqual(cfg.ts6_ssh_port, 10022)
def test_frozen(self):
cfg = Config()
with self.assertRaises(AttributeError):
cfg.ts6_host = "other"
# ---------------------------------------------------------------------------
# Unit tests — Banner
# ---------------------------------------------------------------------------
class TestBanner(unittest.TestCase):
def test_banner_starts_with_ts3(self):
@ -35,6 +88,14 @@ class TestBanner(unittest.TestCase):
def test_banner_contains_help(self):
self.assertIn(b"help", TS3_BANNER)
def test_banner_ends_with_crlf(self):
self.assertTrue(TS3_BANNER.endswith(b"\n\r"))
# ---------------------------------------------------------------------------
# Unit tests — ClientHandler
# ---------------------------------------------------------------------------
class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
def _make_handler(self, input_lines: list[bytes]):
@ -49,28 +110,70 @@ class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
handler = ClientHandler(reader, writer)
return handler, writer
def _get_writes(self, writer):
return [call[0][0] for call in writer.write.call_args_list]
# --- Banner ---
async def test_sends_banner_on_connect(self):
handler, writer = self._make_handler([b"quit\n"])
await handler.handle()
first_write = writer.write.call_args_list[0][0][0]
self.assertEqual(first_write, TS3_BANNER)
# --- Pre-auth commands ---
async def test_quit_before_login(self):
handler, writer = self._make_handler([b"quit\n"])
await handler.handle()
writes = [call[0][0] for call in writer.write.call_args_list]
self.assertIn(b"error id=0 msg=ok\n\r", writes)
self.assertIn(b"error id=0 msg=ok\n\r", self._get_writes(writer))
async def test_command_before_login_returns_not_logged_in(self):
handler, writer = self._make_handler([b"serverinfo\n", b"quit\n"])
await handler.handle()
writes = [call[0][0] for call in writer.write.call_args_list]
self.assertIn(b"error id=1794 msg=not\\slogged\\sin\n\r", writes)
self.assertIn(b"error id=1794 msg=not\\slogged\\sin\n\r", self._get_writes(writer))
async def test_multiple_commands_before_login(self):
handler, writer = self._make_handler(
[b"serverinfo\n", b"clientlist\n", b"quit\n"]
)
await handler.handle()
writes = self._get_writes(writer)
not_logged_in = b"error id=1794 msg=not\\slogged\\sin\n\r"
count = writes.count(not_logged_in)
self.assertEqual(count, 2, f"Expected 2 not-logged-in errors, got {count}")
async def test_empty_line_ignored(self):
handler, writer = self._make_handler([b"\n", b"quit\n"])
await handler.handle()
writes = self._get_writes(writer)
self.assertIn(b"error id=0 msg=ok\n\r", writes)
self.assertNotIn(b"error id=1794", b"".join(writes))
# --- Disconnect ---
async def test_empty_disconnect(self):
handler, writer = self._make_handler([b""])
await handler.handle()
# Should not raise, just disconnect
async def test_timeout_in_auth_loop(self):
reader = AsyncMock(spec=asyncio.StreamReader)
reader.readline = AsyncMock(side_effect=asyncio.TimeoutError)
writer = MagicMock(spec=asyncio.StreamWriter)
writer.write = MagicMock()
writer.drain = AsyncMock()
writer.close = MagicMock()
writer.wait_closed = AsyncMock()
writer.get_extra_info = MagicMock(return_value=("127.0.0.1", 12345))
handler = ClientHandler(reader, writer)
await handler.handle() # should not raise
# --- Login ---
async def test_login_incomplete_command(self):
handler, writer = self._make_handler([b"login serveradmin\n", b"quit\n"])
await handler.handle()
self.assertIn(b"error id=256 msg=command\\snot\\sfound\n\r", self._get_writes(writer))
@patch("proxy.asyncssh")
async def test_login_bad_credentials(self, mock_asyncssh):
@ -82,19 +185,144 @@ class TestClientHandlerUnit(unittest.IsolatedAsyncioTestCase):
)
handler, writer = self._make_handler([b"login serveradmin wrongpw\n"])
await handler.handle()
writes = [call[0][0] for call in writer.write.call_args_list]
self.assertIn(
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r", writes
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
self._get_writes(writer),
)
async def test_login_incomplete_command(self):
handler, writer = self._make_handler([b"login serveradmin\n", b"quit\n"])
@patch("proxy.asyncssh")
async def test_login_ssh_timeout(self, mock_asyncssh):
import asyncssh as real_asyncssh
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
mock_asyncssh.connect = AsyncMock(side_effect=asyncio.TimeoutError)
handler, writer = self._make_handler([b"login serveradmin pass\n"])
await handler.handle()
writes = [call[0][0] for call in writer.write.call_args_list]
self.assertIn(b"error id=256 msg=command\\snot\\sfound\n\r", writes)
self.assertIn(
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
self._get_writes(writer),
)
@patch("proxy.asyncssh")
async def test_login_ssh_connection_refused(self, mock_asyncssh):
import asyncssh as real_asyncssh
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
mock_asyncssh.connect = AsyncMock(side_effect=OSError("Connection refused"))
handler, writer = self._make_handler([b"login serveradmin pass\n"])
await handler.handle()
self.assertIn(
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
self._get_writes(writer),
)
@patch("proxy.asyncssh")
async def test_login_case_insensitive(self, mock_asyncssh):
"""LOGIN (uppercase) should be handled the same as login."""
import asyncssh as real_asyncssh
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
mock_asyncssh.connect = AsyncMock(
side_effect=real_asyncssh.PermissionDenied("bad")
)
handler, writer = self._make_handler([b"LOGIN serveradmin wrongpw\n"])
await handler.handle()
self.assertIn(
b"error id=520 msg=invalid\\sloginname\\sor\\spassword\n\r",
self._get_writes(writer),
)
# --- Proxy (bidirectional forwarding) ---
@patch("proxy.asyncssh")
async def test_proxy_forwards_command_to_ssh(self, mock_asyncssh):
"""After login, client commands should be forwarded to SSH stdin."""
import asyncssh as real_asyncssh
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
# Mock SSH process
mock_process = MagicMock()
mock_process.stdin = MagicMock()
mock_process.stdin.write = MagicMock()
mock_process.stdout = AsyncMock()
mock_process.stdout.read = AsyncMock(
side_effect=["TS3 banner help\n\r", ""]
)
mock_process.close = MagicMock()
mock_conn = AsyncMock()
mock_conn.create_process = AsyncMock(return_value=mock_process)
mock_conn.close = MagicMock()
mock_asyncssh.connect = AsyncMock(return_value=mock_conn)
handler, writer = self._make_handler(
[b"login serveradmin pass\n", b"serverinfo\n", b""]
)
await handler.handle()
# Verify login success response
self.assertIn(b"error id=0 msg=ok\n\r", self._get_writes(writer))
# Verify command was forwarded to SSH stdin
stdin_writes = [call[0][0] for call in mock_process.stdin.write.call_args_list]
self.assertIn("serverinfo\n", stdin_writes)
@patch("proxy.asyncssh")
async def test_proxy_forwards_ssh_response_to_client(self, mock_asyncssh):
"""SSH stdout data should be forwarded to the TCP client."""
import asyncssh as real_asyncssh
mock_asyncssh.PermissionDenied = real_asyncssh.PermissionDenied
mock_process = MagicMock()
mock_process.stdin = MagicMock()
mock_process.stdin.write = MagicMock()
mock_process.stdout = AsyncMock()
mock_process.stdout.read = AsyncMock(
side_effect=[
"TS3 banner help\n\r",
"virtualserver_name=Test error id=0 msg=ok\n\r",
"",
]
)
mock_process.close = MagicMock()
mock_conn = AsyncMock()
mock_conn.create_process = AsyncMock(return_value=mock_process)
mock_conn.close = MagicMock()
mock_asyncssh.connect = AsyncMock(return_value=mock_conn)
handler, writer = self._make_handler(
[b"login serveradmin pass\n", b""]
)
await handler.handle()
writes = self._get_writes(writer)
joined = b"".join(writes)
self.assertIn(b"virtualserver_name=Test", joined)
# --- Cleanup ---
async def test_cleanup_with_no_ssh(self):
"""Cleanup should not raise when SSH was never established."""
handler, writer = self._make_handler([b""])
self.assertIsNone(handler.ssh_conn)
self.assertIsNone(handler.ssh_process)
await handler._cleanup()
async def test_cleanup_when_writer_close_raises(self):
"""Cleanup should swallow exceptions from writer.close."""
handler, writer = self._make_handler([b""])
writer.wait_closed = AsyncMock(side_effect=OSError("already closed"))
await handler._cleanup() # should not raise
# --- Integration tests (only run when env vars are set) ---
# ---------------------------------------------------------------------------
# Integration tests (only run when env vars are set)
# ---------------------------------------------------------------------------
TS6_TEST_HOST = os.environ.get("TS6_TEST_HOST")
TS6_TEST_SSH_PORT = int(os.environ.get("TS6_TEST_SSH_PORT", "10022"))
@ -111,18 +339,15 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
"""Integration tests that start a real proxy and connect through it."""
async def asyncSetUp(self):
os.environ["TS6_HOST"] = TS6_TEST_HOST
os.environ["TS6_SSH_PORT"] = str(TS6_TEST_SSH_PORT)
config = Config(ts6_host=TS6_TEST_HOST, ts6_ssh_port=TS6_TEST_SSH_PORT)
# Re-import to pick up env vars
import importlib
import proxy as proxy_mod
async def on_connect(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
handler = ClientHandler(reader, writer, config)
await handler.handle()
importlib.reload(proxy_mod)
self.server = await asyncio.start_server(
proxy_mod.handle_client, "127.0.0.1", PROXY_TEST_PORT
)
self.server = await asyncio.start_server(on_connect, "127.0.0.1", PROXY_TEST_PORT)
async def asyncTearDown(self):
self.server.close()
@ -130,11 +355,33 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
async def _connect(self):
reader, writer = await asyncio.open_connection("127.0.0.1", PROXY_TEST_PORT)
# Read banner
banner = await asyncio.wait_for(reader.read(4096), timeout=5)
self.assertIn(b"TS3", banner)
return reader, writer
async def _login_and_use(self):
"""Helper: connect, login, select virtual server."""
reader, writer = await self._connect()
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
await writer.drain()
await asyncio.wait_for(reader.readline(), timeout=10)
writer.write(b"use sid=1\n")
await writer.drain()
await asyncio.wait_for(reader.readline(), timeout=5)
return reader, writer
async def _send_cmd(self, reader, writer, cmd):
"""Send a command and return the full response including error line."""
writer.write(f"{cmd}\n".encode())
await writer.drain()
data = b""
for _ in range(20):
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
data += chunk
if b"error id=" in chunk:
break
return data
async def test_login_and_version(self):
reader, writer = await self._connect()
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
@ -144,7 +391,6 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
writer.write(b"version\n")
await writer.drain()
# version response + error line
data = b""
for _ in range(10):
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
@ -186,44 +432,18 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
await writer.drain()
writer.close()
async def _login_and_use(self):
"""Helper: connect, login, select virtual server."""
reader, writer = await self._connect()
writer.write(f"login {TS6_TEST_USER} {TS6_TEST_PASS}\n".encode())
await writer.drain()
await asyncio.wait_for(reader.readline(), timeout=10)
writer.write(b"use sid=1\n")
await writer.drain()
await asyncio.wait_for(reader.readline(), timeout=5)
return reader, writer
async def _send_cmd(self, reader, writer, cmd):
"""Send a command and return the full response including error line."""
writer.write(f"{cmd}\n".encode())
await writer.drain()
data = b""
for _ in range(20):
chunk = await asyncio.wait_for(reader.readline(), timeout=5)
data += chunk
if b"error id=" in chunk:
break
return data
async def test_channellist_and_channelinfo(self):
reader, writer = await self._login_and_use()
# Get channel list and extract first channel ID
data = await self._send_cmd(reader, writer, "channellist")
self.assertIn(b"error id=0", data)
self.assertIn(b"cid=", data)
# Extract first cid
for part in data.decode(errors="replace").split("|")[0].split():
if part.startswith("cid="):
cid = part.split("=")[1]
break
# channelinfo with valid cid
data = await self._send_cmd(reader, writer, f"channelinfo cid={cid}")
self.assertIn(b"error id=0", data)
self.assertIn(b"channel_name=", data)
@ -235,24 +455,22 @@ class TestIntegration(unittest.IsolatedAsyncioTestCase):
async def test_ban_add_list_delete(self):
reader, writer = await self._login_and_use()
# Add a temporary ban
data = await self._send_cmd(reader, writer, "banadd ip=254.253.252.251 banreason=proxytest time=10")
data = await self._send_cmd(
reader, writer, "banadd ip=254.253.252.251 banreason=proxytest time=10"
)
self.assertIn(b"error id=0", data)
self.assertIn(b"banid=", data)
# Extract ban ID
banid = None
for part in data.decode(errors="replace").split():
if part.startswith("banid="):
banid = part.split("=")[1].strip()
break
# List bans
data = await self._send_cmd(reader, writer, "banlist")
self.assertIn(b"error id=0", data)
self.assertIn(b"ip=254.253.252.251", data)
# Delete the ban
data = await self._send_cmd(reader, writer, f"bandel banid={banid}")
self.assertIn(b"error id=0", data)