#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import socket
import argparse
import ipaddress
import threading
import time
from queue import Queue
from concurrent.futures import ThreadPoolExecutor

# Tool version; interpolated into the startup banner by display_banner().
VERSION = "0.8"

# ANSI escape sequences used to colour terminal output. ENDC resets the
# terminal back to its default colour after a coloured span.
BLUE = "\033[94m"
GREEN = "\033[92m"
RED = "\033[91m"
ORANGE = "\033[33m"
ENDC = "\033[0m"

# Shared scan-progress state. Worker threads increment progress_counter while
# holding progress_lock; main() sets total_hosts before the scan starts and
# reads both values to print the progress line and the final summary.
progress_lock = threading.Lock()
progress_counter = 0
total_hosts = 0

def display_banner():
    """Print the ASCII-art startup banner to standard output.

    The banner is a raw f-string (so the backslashes of the artwork are kept
    literally) that interpolates the BLUE/ENDC colour codes and VERSION.
    """
    banner = rf"""
{BLUE}
                                      _________ _________ ___ ___ .__
_______   ____   ___________   ____  /   _____//   _____//   |   \|__| ____   ____
\_  __ \_/ __ \ / ___\_  __ \_/ __ \ \_____  \ \_____  \/    ~    \  |/  _ \ /    \
 |  | \/\  ___// /_/  >  | \/\  ___/ /        \/        \    Y    /  (  <_> )   |  \
 |__|    \___  >___  /|__|    \___  >_______  /_______  /\___|_  /|__|\____/|___|  /
             \/_____/             \/        \/        \/       \/                \/
    CVE-2024-6387 Vulnerability Checker
    v{VERSION} / Alex Hagenah / @xaitax / ah@primepage.de
{ENDC}
"""
    print(banner)

def resolve_hostname(hostname):
    """Resolve *hostname* to the IP addresses it maps to.

    Args:
        hostname: Host name or literal IP address to look up.

    Returns:
        A list of unique address strings in the order the resolver returned
        them, or an empty list when the name cannot be resolved (an error
        line is printed in that case).
    """
    try:
        addr_info = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError):
        # UnicodeError: getaddrinfo IDNA-encodes the name first and rejects
        # empty or over-long labels (e.g. "a..b") before any lookup is made.
        print(f"❌ [-] Could not resolve hostname: {hostname}")
        return []

    # getaddrinfo yields one entry per socket type/protocol, so the same
    # address is normally repeated; dict.fromkeys de-duplicates while
    # preserving the resolver's ordering.
    return list(dict.fromkeys(entry[4][0] for entry in addr_info))

def create_socket(ip, port, timeout):
    """Open a TCP connection to ``(ip, port)``.

    The address family is chosen from the textual form of *ip*: anything
    containing a colon is treated as IPv6, everything else as IPv4.

    Args:
        ip: Target address (or host name) as a string.
        port: TCP port number.
        timeout: Timeout in seconds applied to the connect and to later
            operations on the returned socket.

    Returns:
        The connected socket, or ``None`` if the connection could not be
        established for any reason.
    """
    sock = None
    try:
        family = socket.AF_INET6 if ':' in ip else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        sock.connect((ip, port))
        return sock
    except Exception:
        # Release the descriptor immediately instead of leaving a
        # half-initialised socket for the garbage collector; with many
        # worker threads and mostly-closed ports these add up quickly.
        if sock is not None:
            sock.close()
        return None

def get_ssh_banner(sock, use_help_request):
    """Read the service banner from a connected socket and close the socket.

    Args:
        sock: Connected socket to read from; it is always closed on return.
        use_help_request: When true and the first read yields nothing, send
            ``HELP`` followed by a newline and read once more.

    Returns:
        The decoded, whitespace-stripped banner. If any exception occurs
        while talking to the peer, its string form is returned instead.
    """
    def _read_reply():
        # One read of up to 1024 bytes; undecodable bytes are dropped.
        raw = sock.recv(1024)
        return raw.decode(errors='ignore').strip()

    try:
        first_reply = _read_reply()
        if use_help_request and not first_reply:
            # Nothing was volunteered by the peer: nudge it and retry once.
            sock.sendall("HELP\n".encode())
            return _read_reply()
        return first_reply
    except Exception as error:
        return str(error)
    finally:
        sock.close()

# Banner substrings of the OpenSSH releases this checker reports as vulnerable.
_VULNERABLE_VERSIONS = (
    'SSH-2.0-OpenSSH_1',
    'SSH-2.0-OpenSSH_2',
    'SSH-2.0-OpenSSH_3',
    'SSH-2.0-OpenSSH_4.0',
    'SSH-2.0-OpenSSH_4.1',
    'SSH-2.0-OpenSSH_4.2',
    'SSH-2.0-OpenSSH_4.3',
    'SSH-2.0-OpenSSH_4.4',
    'SSH-2.0-OpenSSH_8.5',
    'SSH-2.0-OpenSSH_8.6',
    'SSH-2.0-OpenSSH_8.7',
    'SSH-2.0-OpenSSH_8.8',
    'SSH-2.0-OpenSSH_8.9',
    'SSH-2.0-OpenSSH_9.0',
    'SSH-2.0-OpenSSH_9.1',
    'SSH-2.0-OpenSSH_9.2',
    'SSH-2.0-OpenSSH_9.3',
    'SSH-2.0-OpenSSH_9.4',
    'SSH-2.0-OpenSSH_9.5',
    'SSH-2.0-OpenSSH_9.6',
    'SSH-2.0-OpenSSH_9.7',
)

# Complete banners that are excluded from the vulnerable set even though
# their version number matches an entry above (compared by exact equality).
_PATCHED_VERSIONS = frozenset({
    'SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.10',
    'SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.13',
    'SSH-2.0-OpenSSH_9.3p1 Ubuntu-3ubuntu3.6',
    'SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.3',
    'SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.4',
    'SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.5',
    'SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.11',
    'SSH-2.0-OpenSSH_9.3p1 Ubuntu-1ubuntu3.6',
    'SSH-2.0-OpenSSH_9.2p1 Debian-2+deb12u3',
    'SSH-2.0-OpenSSH_8.4p1 Debian-5+deb11u3',
    'SSH-2.0-OpenSSH_9.7p1 Debian-7',
    'SSH-2.0-OpenSSH_9.6 FreeBSD-20240701',
    'SSH-2.0-OpenSSH_9.7 FreeBSD-20240701',
})


def _probe_login_grace_time(ip, port, timeout, grace_time_check):
    """Reconnect and wait to see whether the server ends the idle session.

    Returns:
        ``(closed_by_server, elapsed_seconds)`` when the wait was carried
        out, or ``None`` when the probe was impossible (the reconnect failed
        or the connection broke before the wait could start).
    """
    probe = create_socket(ip, port, timeout)
    if not probe:
        return None
    try:
        started = time.time()
        # Discard the banner so the next recv() only returns once the server
        # sends something further or closes the connection.
        probe.recv(1024)
        # Wait for the remainder of the grace period plus 4 seconds of slack.
        probe.settimeout(grace_time_check - (time.time() - started) + 4)
        try:
            probe.recv(1024)
        except socket.timeout:
            return False, time.time() - started
        return True, time.time() - started
    except OSError:
        # Timeout on the banner read, connection reset, etc.
        return None
    finally:
        probe.close()


def _classify_host(ip, port, timeout, grace_time_check, use_help_request, dns_resolve):
    """Build the ``(ip, port, status, message, hostname)`` result for a host."""
    sshsock = create_socket(ip, port, timeout)
    if not sshsock:
        return (ip, port, 'closed', "Port closed", ip)

    banner = get_ssh_banner(sshsock, use_help_request)
    if "SSH-2.0" not in banner:
        return (ip, port, 'failed', f"Failed to retrieve SSH banner: {banner}", ip)
    if "SSH-2.0-OpenSSH" not in banner:
        return (ip, port, 'unknown', f"(banner: {banner})", ip)

    # Reverse DNS is only attempted for hosts that turned out to run OpenSSH.
    hostname = resolve_ip(ip) if dns_resolve else None

    flagged = (banner not in _PATCHED_VERSIONS
               and any(version in banner for version in _VULNERABLE_VERSIONS))
    if not flagged:
        return (ip, port, 'not_vulnerable', f"(running {banner})", hostname)

    probe = None
    if grace_time_check:
        probe = _probe_login_grace_time(ip, port, timeout, grace_time_check)
    if probe is None:
        # No grace-time check requested, or it could not be performed:
        # report on the strength of the banner alone.
        return (ip, port, 'vulnerable', f"(running {banner})", hostname)

    closed_by_server, time_elapsed = probe
    if closed_by_server:
        return (ip, port, 'vulnerable', f"(running {banner}) vulnerable and LoginGraceTime remediation not done (Session was closed by server at {time_elapsed:.1f} seconds)", hostname)
    return (ip, port, 'likely_not_vulnerable', f"(running {banner} False negative possible depending on LoginGraceTime)", hostname)


def check_vulnerability(ip, port, timeout, grace_time_check, use_help_request, dns_resolve, result_queue):
    """Check one ``ip:port`` and put the verdict on *result_queue*.

    Args:
        ip: Target address.
        port: TCP port to probe.
        timeout: Socket timeout in seconds for connecting and banner reads.
        grace_time_check: Seconds to wait for the server to drop an idle,
            unauthenticated session; falsy disables that extra probe.
        use_help_request: Forwarded to get_ssh_banner().
        dns_resolve: When true, reverse-resolve hosts that run OpenSSH.
        result_queue: Queue receiving one ``(ip, port, status, message,
            hostname)`` tuple. *status* is one of ``'closed'``, ``'failed'``,
            ``'unknown'``, ``'vulnerable'``, ``'likely_not_vulnerable'`` or
            ``'not_vulnerable'``. For the first three the last element is
            *ip* itself; otherwise it is the resolved name or ``None``.

    The global progress counter is incremented exactly once per call, even
    if the check raises, so the progress display cannot fall behind.
    """
    global progress_counter

    try:
        result_queue.put(_classify_host(
            ip, port, timeout, grace_time_check, use_help_request, dns_resolve))
    finally:
        with progress_lock:
            progress_counter += 1

def process_ip_list(ip_list_file):
    """Read scan targets from a text file, one per line.

    Lines containing ``/`` are parsed as CIDR networks and expanded to their
    host addresses; every other non-empty line is taken verbatim. Blank lines
    and lines starting with ``#`` are ignored. Invalid CIDR entries are
    reported and skipped.

    Args:
        ip_list_file: Path of the file to read.

    Returns:
        A list of whitespace-stripped target strings. If the file cannot be
        read, an error is printed and whatever was collected so far (usually
        nothing) is returned.
    """
    ips = []

    try:
        with open(ip_list_file, 'r') as file:
            for line in file:
                target = line.strip()
                # An empty target must never reach the scanner: the socket
                # layer interprets an empty host string as INADDR_ANY
                # instead of rejecting it.
                if not target or target.startswith('#'):
                    continue
                if '/' in target:
                    try:
                        network = ipaddress.ip_network(target, strict=False)
                    except ValueError as e:
                        print(f"❌ [-] Invalid CIDR notation {target} {e}")
                        continue
                    ips.extend(str(ip) for ip in network.hosts())
                else:
                    ips.append(target)
    except IOError:
        print(f"❌ [-] Could not read file: {ip_list_file}")
    return ips

def resolve_ip(ip):
    """Return the primary host name for *ip* via a reverse lookup.

    Returns ``None`` when the lookup fails with ``socket.herror`` or
    ``socket.gaierror``.
    """
    try:
        lookup = socket.gethostbyaddr(ip)
    except (socket.herror, socket.gaierror):
        return None
    else:
        # gethostbyaddr() returns (hostname, aliases, addresses).
        primary_name = lookup[0]
        return primary_name

def _expand_target(target):
    """Expand one command-line target into a list of address strings.

    A readable file is expanded with process_ip_list(), CIDR notation is
    expanded to its host addresses, and anything else is resolved as a host
    name (falling back to the literal target when resolution fails).
    """
    try:
        # Probe only: process_ip_list() does the actual reading so that file
        # targets and --list files are interpreted identically.
        with open(target, 'r'):
            pass
    except IOError:
        pass
    else:
        return process_ip_list(target)

    if '/' in target:
        try:
            network = ipaddress.ip_network(target.strip(), strict=False)
        except ValueError:
            print(f"❌ [-] Invalid CIDR notation: {target}")
            return []
        return [str(ip) for ip in network.hosts()]

    return resolve_hostname(target) or [target.strip()]


def main():
    """Command-line entry point: parse arguments, scan, print a report.

    Targets come from positional arguments (addresses, host names, CIDR
    ranges or files) and from ``--list``. Every target/port pair is checked
    by check_vulnerability() on a pool of worker threads while a progress
    line is refreshed once per second; afterwards the results are grouped by
    status and summarised.
    """
    global total_hosts
    display_banner()

    parser = argparse.ArgumentParser(
        description="Check if servers are running a vulnerable version of OpenSSH (CVE-2024-6387).")
    parser.add_argument(
        "targets", nargs='*', help="IP addresses, domain names, file paths containing IP addresses, or CIDR network ranges.")
    parser.add_argument("-p", "--ports", type=str, default="22",
                        help="Comma-separated list of port numbers to check (default: 22).")
    parser.add_argument("-t", "--timeout", type=float, default=1.0,
                        help="Connection timeout in seconds (default: 1 second).")
    parser.add_argument("-l", "--list", help="File containing a list of IP addresses to check.")
    parser.add_argument("-g", "--grace-time-check", nargs='?', const=120, type=int,
                        help="Time in seconds to wait after identifying the version to check for LoginGraceTime mitigation (default: 120 seconds).")
    parser.add_argument("-d", "--dns-resolve", action='store_true',
                    help="Resolve and display hostnames for IP addresses.")
    parser.add_argument("-u", "--use-help-request", action='store_true',
                        help="Enable sending a HELP request if the initial SSH banner retrieval fails.")

    args = parser.parse_args()
    timeout = args.timeout
    grace_time_check = args.grace_time_check
    use_help_request = args.use_help_request

    # Reject malformed port lists with a usage error instead of a traceback.
    try:
        ports = [int(p) for p in args.ports.split(',')]
    except ValueError:
        parser.error(f"invalid port list: {args.ports!r}")
    if not all(0 < port < 65536 for port in ports):
        parser.error(f"ports must be between 1 and 65535: {args.ports!r}")

    collected = set()
    if args.list:
        collected.update(process_ip_list(args.list))
    for target in args.targets:
        collected.update(_expand_target(target))

    # Normalise before counting so that stray whitespace does not create
    # duplicate hosts and empty entries are never scanned.
    ips = list({ip.strip() for ip in collected} - {""})
    total_hosts = len(ips)
    total_checks = total_hosts * len(ports)

    result_queue = Queue()
    max_workers = 100

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(check_vulnerability, ip, port, timeout, grace_time_check, use_help_request, args.dns_resolve, result_queue)
                   for ip in ips for port in ports]

        # Poll done() rather than running(): queued futures that have not
        # started yet are not "running", which could end the loop early.
        while not all(future.done() for future in futures):
            with progress_lock:
                print(f"\rProgress: {progress_counter}/{total_checks} checks performed", end="")
            time.sleep(1)

    print(f"\rProgress: {progress_counter}/{total_checks} checks performed")

    # A single unexpected failure must not discard the results of all the
    # other checks, so report it and carry on.
    for future in futures:
        try:
            future.result()
        except Exception as exc:
            print(f"⚠️ [!] A check failed unexpectedly: {exc!r}")

    closed_ports = 0
    grouped = {
        'unknown': [],
        'vulnerable': [],
        'not_vulnerable': [],
        'likely_not_vulnerable': [],
    }

    while not result_queue.empty():
        ip, port, status, message, hostname = result_queue.get()
        # Some results carry the IP itself in the hostname slot; avoid
        # rendering those as "ip (ip)".
        display_ip = f"{ip} ({hostname})" if hostname and hostname != ip else ip
        if status == 'closed':
            closed_ports += 1
        elif status in grouped:
            grouped[status].append((display_ip, port, message))
        else:
            print(f"⚠️ [!] Server at {display_ip} is {message}")

    not_vulnerable = grouped['not_vulnerable']
    likely_not_vulnerable = grouped['likely_not_vulnerable']
    vulnerable = grouped['vulnerable']
    unknown = grouped['unknown']

    print(f"\n🛡️ Servers not vulnerable: {len(not_vulnerable)}")
    for ip, port, msg in not_vulnerable:
        print(f"   [+] Server at {GREEN}{ip}{ENDC} {msg}")
    if grace_time_check:
        print(f"\n🛡️ Servers likely not vulnerable (possible LoginGraceTime remediation): {len(likely_not_vulnerable)}")
        for ip, port, msg in likely_not_vulnerable:
            print(f"   [+] Server at {GREEN}{ip}{ENDC} {msg}")
    print(f"\n🚨 Servers likely vulnerable: {len(vulnerable)}")
    for ip, port, msg in vulnerable:
        print(f"   [+] Server at {RED}{ip}{ENDC} {msg}")
    print(f"\n⚠️ Servers with unknown SSH version: {len(unknown)}")
    for ip, port, msg in unknown:
        print(f"   [+] Server at {ORANGE}{ip}{ENDC} {msg}")
    print(f"\n🔒 Servers with port(s) closed: {closed_ports}")
    print(f"\n📊 Total scanned hosts: {total_hosts}")
    print(f"📊 Total port checks performed: {progress_counter}")

# Run the scanner only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
