# Source code for proxy.core.ssh.listener

# -*- coding: utf-8 -*-
"""
    proxy.py
    ~~~~~~~~
    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
    Network monitoring, controls & Application development, testing, debugging.

    :copyright: (c) 2013-present by Abhinav Singh and contributors.
    :license: BSD, see LICENSE for more details.
"""
import sys
import socket
import logging
import argparse
from typing import TYPE_CHECKING, Any, Set, Optional, cast


try:
    if TYPE_CHECKING:  # pragma: no cover
        from ...common.types import HostPort
except ImportError:     # pragma: no cover
    pass

from .base import BaseSshTunnelHandler, BaseSshTunnelListener
from ...common.flag import flags


logger = logging.getLogger(__name__)


# Command line flags describing the SSH tunnel.  Each entry is
# (option string, value type, default value, help text); entries are
# registered with the shared ``flags`` parser in the order listed.
_TUNNEL_FLAG_SPECS = (
    (
        '--tunnel-hostname',
        str,
        None,
        'Default: None. Remote hostname or IP address to which SSH tunnel will be established.',
    ),
    (
        '--tunnel-port',
        int,
        22,
        'Default: 22. SSH port of the remote host.',
    ),
    (
        '--tunnel-username',
        str,
        None,
        'Default: None. Username to use for establishing SSH tunnel.',
    ),
    (
        '--tunnel-ssh-key',
        str,
        None,
        'Default: None. Private key path in pem format',
    ),
    (
        '--tunnel-ssh-key-passphrase',
        str,
        None,
        'Default: None. Private key passphrase',
    ),
    (
        '--tunnel-remote-port',
        int,
        8899,
        'Default: 8899. Remote port which will be forwarded locally for proxy.',
    ),
)

for _option, _value_type, _default_value, _help_text in _TUNNEL_FLAG_SPECS:
    flags.add_argument(
        _option,
        type=_value_type,
        default=_default_value,
        help=_help_text,
    )


class SshTunnelListener(BaseSshTunnelListener):
    """Connects over SSH and forwards a remote port to local host.

    Connections arriving on the forwarded remote port are delegated to
    ``handler.on_connection``.  Call :meth:`setup` to connect and start
    forwarding, and :meth:`shutdown` to release everything again.
    """

    def __init__(
        self,
        flags: argparse.Namespace,
        handler: BaseSshTunnelHandler,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Store configuration; no network activity happens here.

        :param flags: Parsed command line flags; the ``tunnel_*`` attributes
            are read later by :meth:`setup`.
        :param handler: Object whose ``on_connection`` is invoked for
            forwarded connections and whose ``shutdown`` is invoked from
            :meth:`shutdown`.
        :param args: Accepted but not used by this class.
        :param kwargs: Accepted but not used by this class.
        """
        # Only surface warnings and errors from paramiko's own logger.
        paramiko_logger = logging.getLogger('paramiko')
        paramiko_logger.setLevel(logging.WARNING)
        # Imported lazily so that paramiko is only required when a tunnel
        # listener is actually instantiated.
        # pylint: disable=import-outside-toplevel
        from paramiko import SSHClient
        from paramiko.transport import Transport
        self.flags = flags
        self.handler = handler
        # Populated by setup(); None while disconnected.
        self.ssh: Optional[SSHClient] = None
        self.transport: Optional[Transport] = None
        # Remote (host, port) pairs that are currently being forwarded.
        self.forwarded: Set['HostPort'] = set()

    def start_port_forward(self, remote_addr: 'HostPort') -> None:
        """Ask the SSH server to forward ``remote_addr`` to this listener.

        :param remote_addr: ``(host, port)`` pair on the remote side.
        :raises AssertionError: If no transport exists, i.e. :meth:`setup`
            has not completed successfully.
        """
        assert self.transport is not None
        self.transport.request_port_forward(
            *remote_addr,
            handler=self.handler.on_connection,
        )
        # Only recorded once the request went through, so shutdown() never
        # tries to cancel a forward that was not established.
        self.forwarded.add(remote_addr)
        logger.debug('%s:%d forwarding successful...', *remote_addr)

    def stop_port_forward(self, remote_addr: 'HostPort') -> None:
        """Cancel a forward previously started by :meth:`start_port_forward`.

        :param remote_addr: ``(host, port)`` pair on the remote side.
        :raises AssertionError: If no transport exists.
        :raises KeyError: If ``remote_addr`` is not currently recorded as
            forwarded.
        """
        assert self.transport is not None
        self.transport.cancel_port_forward(*remote_addr)
        self.forwarded.remove(remote_addr)

    def setup(self) -> None:
        """Connect to the SSH server and start forwarding the remote port.

        Connection parameters come from the ``tunnel_*`` flags.  TCP
        keep-alive is enabled on the underlying socket.  If any step fails,
        the SSH client is closed and ``self.ssh`` / ``self.transport`` are
        reset to ``None`` before the exception is re-raised, so a failed
        setup does not leave a half-open connection behind.
        """
        # pylint: disable=import-outside-toplevel
        from paramiko import SSHClient, AutoAddPolicy
        client = SSHClient()
        self.ssh = client
        try:
            client.load_system_host_keys()
            # NOTE(security): AutoAddPolicy accepts host keys that are not
            # already known, so the first connection to a host is not
            # protected against man-in-the-middle attacks.
            client.set_missing_host_key_policy(AutoAddPolicy())
            client.connect(
                hostname=self.flags.tunnel_hostname,
                port=self.flags.tunnel_port,
                username=self.flags.tunnel_username,
                key_filename=self.flags.tunnel_ssh_key,
                passphrase=self.flags.tunnel_ssh_key_passphrase,
                compress=True,
                timeout=10,
                auth_timeout=7,
            )
            logger.debug(
                'SSH connection established to %s:%d...',
                self.flags.tunnel_hostname,
                self.flags.tunnel_port,
            )
            self.transport = client.get_transport()
            assert self.transport
            sock = cast(socket.socket, self.transport.sock)  # type: ignore[redundant-cast]
            # Enable TCP keep-alive
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Idle time (in seconds) before the first keep-alive probe.
            # Skipped on macOS, matching the original platform guard.
            if sys.platform != 'darwin':
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            # Keep-alive probe interval (in seconds)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5)
            # Number of keep-alive probes before timeout
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
            # Empty host: the bind address is left to the SSH server.
            self.start_port_forward(('', self.flags.tunnel_remote_port))
        except BaseException:
            # Release the connection before propagating the failure;
            # BaseException so that an interrupt during connect is covered.
            self.transport = None
            self.ssh = None
            client.close()
            raise

    def shutdown(self) -> None:
        """Cancel all forwards, close the SSH connection, stop the handler.

        Every cleanup step is attempted even if an earlier one raises; the
        exception still propagates after the remaining steps have run.
        """
        try:
            # Iterate over a copy: stop_port_forward mutates the set.
            for remote_addr in list(self.forwarded):
                self.stop_port_forward(remote_addr)
        finally:
            self.forwarded.clear()
            try:
                if self.transport is not None:
                    self.transport.close()
                if self.ssh is not None:
                    self.ssh.close()
            finally:
                self.handler.shutdown()

    def is_alive(self) -> bool:
        """Return the transport's ``is_alive()``; ``False`` without a transport."""
        return self.transport.is_alive() if self.transport else False

    def is_active(self) -> bool:
        """Return the transport's ``is_active()``; ``False`` without a transport."""
        return self.transport.is_active() if self.transport else False