Add SSH and TCP keepalive to prevent silent disconnects
- SSH keepalive every 30s with max 3 retries (configurable via env vars) - TCP keepalive on client sockets to detect dead connections - Increase pre-auth timeout from 30s to 300s to match TS6 query timeout - Add keepalive settings to Config dataclass
This commit is contained in:
parent
c9151df7e0
commit
58f817af98
1 changed files with 23 additions and 2 deletions
25
proxy.py
25
proxy.py
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
|
@ -23,6 +24,8 @@ class Config:
|
|||
ts6_ssh_port: int = 10022
|
||||
listen_host: str = "0.0.0.0"
|
||||
listen_port: int = 10011
|
||||
ssh_keepalive_interval: int = 30
|
||||
ssh_keepalive_count_max: int = 3
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Config:
|
||||
|
|
@ -30,6 +33,8 @@ class Config:
|
|||
ts6_host=os.environ.get("TS6_HOST", "teamspeak6"),
|
||||
ts6_ssh_port=int(os.environ.get("TS6_SSH_PORT", "10022")),
|
||||
listen_port=int(os.environ.get("LISTEN_PORT", "10011")),
|
||||
ssh_keepalive_interval=int(os.environ.get("SSH_KEEPALIVE_INTERVAL", "30")),
|
||||
ssh_keepalive_count_max=int(os.environ.get("SSH_KEEPALIVE_COUNT_MAX", "3")),
|
||||
)
|
||||
|
||||
TS3_BANNER = (
|
||||
|
|
@ -57,6 +62,11 @@ class ClientHandler:
|
|||
async def handle(self) -> None:
|
||||
log.info("Client connected from %s", self.addr)
|
||||
try:
|
||||
# Enable TCP keepalive on the client socket
|
||||
sock = self.writer.get_extra_info("socket")
|
||||
if sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
|
||||
self.writer.write(TS3_BANNER)
|
||||
await self.writer.drain()
|
||||
await self._auth_loop()
|
||||
|
|
@ -71,7 +81,7 @@ class ClientHandler:
|
|||
|
||||
async def _auth_loop(self) -> None:
|
||||
while True:
|
||||
line = await asyncio.wait_for(self.reader.readline(), timeout=30)
|
||||
line = await asyncio.wait_for(self.reader.readline(), timeout=300)
|
||||
if not line:
|
||||
return
|
||||
cmd = line.decode("utf-8", errors="replace").strip()
|
||||
|
|
@ -124,6 +134,8 @@ class ClientHandler:
|
|||
username=username,
|
||||
password=password,
|
||||
known_hosts=None,
|
||||
keepalive_interval=self.config.ssh_keepalive_interval,
|
||||
keepalive_count_max=self.config.ssh_keepalive_count_max,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
|
@ -144,7 +156,11 @@ class ClientHandler:
|
|||
except TimeoutError:
|
||||
break
|
||||
|
||||
log.info("SSH session established for %s", self.addr)
|
||||
log.info(
|
||||
"SSH session established for %s (keepalive=%ds)",
|
||||
self.addr,
|
||||
self.config.ssh_keepalive_interval,
|
||||
)
|
||||
return True
|
||||
except asyncssh.PermissionDenied:
|
||||
log.warning("SSH auth failed for %s (bad credentials)", self.addr)
|
||||
|
|
@ -242,6 +258,11 @@ async def main() -> None:
|
|||
)
|
||||
log.info("TS3 Query Proxy listening on %s:%d", config.listen_host, config.listen_port)
|
||||
log.info("Forwarding to %s:%d (SSH Query)", config.ts6_host, config.ts6_ssh_port)
|
||||
log.info(
|
||||
"SSH keepalive: interval=%ds, max_count=%d",
|
||||
config.ssh_keepalive_interval,
|
||||
config.ssh_keepalive_count_max,
|
||||
)
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue