# The Perl Toolchain Summit needs more sponsors. If your company depends on Perl, please support this very important event.
use strict;

use Getopt::Long;
use Carp qw/croak/;
use Digest::MD5 qw/md5/;

use Net::SSH::Perl;
use Net::SSH::Perl::Packet;
use Net::SSH::Perl::Cipher;
use Net::SSH::Perl::Auth;
use Net::SSH::Perl::Util qw( :ssh1mp :rsa :hosts _load_private_key );
use Net::SSH::Perl::Constants qw( :msg :hosts PROTOCOL_MAJOR PROTOCOL_MINOR );
use IO::Socket;
use Math::GMP;

## Package globals.  $VERSION is sent in the identification banner and
## $DEBUG switches on diagnostics; the remaining variables are used below
## but never assigned in this file (presumably set by test-common.pl).
use vars qw( $VERSION $DEBUG );
use vars qw( $USER_AUTHORIZED_KEYS $DUMMY_PASSWD $KNOWN_HOSTS $PID_FILE );
$VERSION = "0.01";

## Command-line switches:
##   --port/-p N       TCP port to listen on
##   --debug/-d        enable debugging output
##   --bits/-b N       size of the generated server key
##   --gen-hostkey/-g  generate a host key instead of loading one from disk
my($port, $key_bits, $generate_host_key);
GetOptions(
    "port|p=i"      => \$port,
    "debug|d"       => \$DEBUG,
    "bits|b=i"      => \$key_bits,
    "gen-hostkey|g" => \$generate_host_key,
);

## Defaults for anything not given (or given as a false value).
$port     = 60000 unless $port;
$DEBUG    = 0     unless defined $DEBUG;
$key_bits = 768   unless $key_bits;

## Shared test settings live next to the test scripts.
BEGIN { unshift @INC, 't/' }
require 'test-common.pl';
## Remember the pid of the listening (parent) process.  Forked children
## inherit the TERM handler below, and comparing $$ against $PID keeps them
## from removing the parent's pid file when they are terminated.
my $PID = $$;
$SIG{TERM} = sub {
    if ($$ == $PID && -e $PID_FILE) {
        ## Report the file that could not be removed (the message used to
        ## interpolate the pid number instead of the path).
        unlink $PID_FILE or die "Can't unlink $PID_FILE: $!";
    }
    die "Terminated.\n";
};

## The Net::SSH::Perl object is used for its debug(), packet and cipher
## helpers; "dummy" is only a placeholder host name.
my $ssh = Net::SSH::Perl->new(
    "dummy",
    debug            => $DEBUG,
    user_known_hosts => $KNOWN_HOSTS,
);

## Clear out a pid file left behind by an earlier run.
if (-e $PID_FILE) {
    $ssh->debug("Removing old pid file $PID_FILE");
    unlink($PID_FILE)
        or die "Can't unlink old pid file $PID_FILE: $!";
}

## Listening socket, bound to localhost only.
my $SERVER = IO::Socket::INET->new(
    Proto     => 'tcp',
    LocalAddr => 'localhost',
    LocalPort => $port,
    Listen    => 5,
    Reuse     => 1,
) or die "$0: Couldn't create server on port $port: $@";
$ssh->debug("Server listening on port $port.");

## Generate the server key pair; its size comes from --bits.
$ssh->debug("Generating $key_bits bit RSA key.");
my($public, $private) = _rsa_generate_key($key_bits);
$ssh->debug("Key generation complete.");

## %keys holds everything the key exchange needs:
##   server  => public half of the server key
##   private => private half of the server key
##   host    => private host key (generated here or loaded from disk)
my %keys = (
    server  => $public,
    private => $private,
);

if (!$generate_host_key) {
    $keys{host} = _load_private_key("/etc/ssh_host_key");
}
else {
    $ssh->debug("Generating 1024 bit host key.");
    my($host_public, $host_private) = _rsa_generate_key(1024);
    $ssh->debug("Host key generation complete.");

    $keys{host} = $host_private;

    ## Start from a fresh known-hosts file containing only the new key,
    ## registered under the name "localhost".
    if (-e $KNOWN_HOSTS) {
        unlink($KNOWN_HOSTS) or die "Can't unlink $KNOWN_HOSTS: $!";
    }
    _add_host_to_hostfile("localhost", $KNOWN_HOSTS, $host_public);
}

## Record our pid so that whoever started the server can signal it later.
$ssh->debug("Writing pid file with pid $PID.");
{
    ## Three-argument open with a lexical handle: the file name can never
    ## be mistaken for an open mode, and the handle is scoped to this block.
    open my $fh, '>', $PID_FILE or die "Can't open $PID_FILE: $!";
    print {$fh} $PID or die "Can't write to $PID_FILE: $!";
    ## Buffered write errors only surface at close, so check it as well.
    close $fh or die "Can't close $PID_FILE: $!";
}

$ssh->debug("Ready to serve connections.");

## SIGCHLD handling.  $waitedpid holds the pid of the most recently reaped
## child (0 until a child has been reaped).
use POSIX qw( WNOHANG );

my $waitedpid = 0;

## REAPER: signal handler that collects exited children.
## Several children may exit while only one SIGCHLD is delivered, so reap
## in a non-blocking loop until nothing is left instead of calling wait
## exactly once (which would leave zombies behind).
sub REAPER {
    ## Keep the interrupted code's view of errno and child status intact.
    local($!, $?);
    while ((my $kid = waitpid(-1, WNOHANG)) > 0) {
        $waitedpid = $kid;
        $ssh->debug("Reaped $waitedpid");
    }
    ## Re-install the handler for systems that reset it on delivery.
    $SIG{CHLD} = \&REAPER;
}
$SIG{CHLD} = \&REAPER;

## Main loop: accept a connection and fork a child to serve it.
##
## accept() returns false when it is interrupted by SIGCHLD (i.e. whenever
## REAPER runs for an exiting child).  That must not end the server, so a
## failed accept is retried when a child was just reaped or errno is EINTR;
## any other failure is logged and ends the loop.
while (1) {
    $waitedpid = 0;
    my $client = $SERVER->accept;
    unless ($client) {
        next if $waitedpid || $!{EINTR};
        $ssh->debug("accept failed: $!");
        last;
    }

    my $pid = fork;
    if (!defined $pid) {
        $ssh->debug("Cannot fork: $!");
    }
    elsif ($pid) {
        ## Parent: $client goes out of scope at the end of this iteration,
        ## which closes the parent's copy of the socket.
        $ssh->debug("Forked $pid to handle client");
    }
    else {
        ## Child: it runs commands of its own, so the parent's reaper must
        ## not stay installed here.
        $SIG{CHLD} = 'DEFAULT';
        _handle_child($SERVER, $ssh, \%keys, $client);
        exit;
    }
}

## _handle_child($server, $ssh, $keys, $client)
##
## Runs in the forked child.  Exchanges identification strings with the
## client on $client, stores the socket in $ssh->{session}{sock} and hands
## over to _do_connection for the rest of the session.
##
##   $server - listening socket (only passed through)
##   $ssh    - Net::SSH::Perl object used for debugging and packet I/O
##   $keys   - hash ref with the 'server', 'host' and 'private' keys
##   $client - connected client socket
##
## Returns without starting a session if the client's identification line
## is missing or malformed.
sub _handle_child {
    my($server, $ssh, $keys, $client) = @_;

    $client->autoflush(1);

    printf {$client} "SSH-%d.%d-%s\n",
        PROTOCOL_MAJOR, PROTOCOL_MINOR, $VERSION;

    ## The client may hang up before sending anything, in which case the
    ## read yields undef and there is nothing to parse.
    my $remote_id = <$client>;
    my($remote_major, $remote_minor, $remote_version) =
        defined $remote_id
            ? $remote_id =~ /^SSH-(\d+)\.(\d+)-([^\n]+)\n$/
            : ();

    ## Test definedness, not truth: "0" is a legitimate version component.
    ## All three captures are set together, so checking one is enough.
    unless (defined $remote_version) {
        print {$client} "Protocol mismatch.\n";
        $ssh->debug("Bad or missing client identification string; dropping connection.");
        close $client;
        return;
    }
    $ssh->debug("Client protocol version $remote_major.$remote_minor, remote software version $remote_version");

    $ssh->{session}{sock} = $client;

    _do_connection($server, $ssh, $keys);
}

## _do_connection($server, $ssh, $keys)
##
## Serves one complete session on the socket stored in $ssh:
##   1. send the public key packet (check bytes, server key, host key,
##      cipher and authentication masks);
##   2. receive and decrypt the session key, then switch on encryption;
##   3. read the user name and loop over authentication attempts;
##   4. execute the single requested command and return its output and
##      exit status, then wait for the client's exit confirmation.
##
##   $server - listening socket (unused here)
##   $ssh    - Net::SSH::Perl object carrying the session state
##   $keys   - hash ref with 'server' (public), 'private' and 'host' keys
##
## Croaks if the client selects an unsupported cipher or returns the wrong
## check bytes.
sub _do_connection {
    my($server, $ssh, $keys) = @_;
    my($packet);

    ## Eight random check bytes which the client must echo back.  rand 256
    ## (not 255) so that every byte value, including 0xFF, can occur.
    my $check_bytes = join '', map { chr int rand 256 } 1..8;
    $packet = $ssh->packet_start(SSH_SMSG_PUBLIC_KEY);
    $packet->put_chars($check_bytes);

    for my $k (qw( server host )) {
        my $key = $keys->{$k};
        $packet->put_int32($key->{bits});
        $packet->put_mp_int($key->{e});
        $packet->put_mp_int($key->{n});
    }

    $packet->put_int32(0);  ## Protocol flags.
    $packet->put_int32( Net::SSH::Perl::Cipher::mask() );
    $packet->put_int32( Net::SSH::Perl::Auth::mask() );
    $packet->send;

    $ssh->debug(sprintf "Sent %d bit public key and %d bit host key.",
        $keys->{server}{bits}, $keys->{host}{bits});

    $packet = Net::SSH::Perl::Packet->read_expect($ssh, SSH_CMSG_SESSION_KEY);
    my $cipher = $packet->get_int8;

    unless (Net::SSH::Perl::Cipher::supported($cipher)) {
        croak "Client selected unsupported cipher $cipher";
    }

    for my $cb (split //, $check_bytes) {
        if ($cb ne $packet->get_char) {
            croak "IP Spoofing check bytes do not match.";
        }
    }

    my $cipher_name = Net::SSH::Perl::Cipher::name($cipher);
    $ssh->debug("Encryption type: $cipher_name");

    my $session_key = $packet->get_mp_int;
    my $protocol_flags = $packet->get_int32;

    ## The session key arrives encrypted with both keys; undo the key with
    ## the larger modulus first.
    if ($keys->{private}{n} > $keys->{host}{n}) {
        $session_key = _rsa_private_decrypt($session_key, $keys->{private});
        $session_key = _rsa_private_decrypt($session_key, $keys->{host});
    }
    else {
        $session_key = _rsa_private_decrypt($session_key, $keys->{host});
        $session_key = _rsa_private_decrypt($session_key, $keys->{private});
    }

    my $session_id = _compute_session_id($check_bytes, $keys->{host}, $keys->{private});
    $ssh->{session}{id} = $session_id;

    ## The first 16 bytes of the key were XORed with the session id.
    $session_key = _mp_linearize(32, $session_key);
    for my $i (0..15) {
        substr($session_key, $i, 1) ^= substr($session_id, $i, 1);
    }

    $ssh->set_cipher($cipher_name, $session_key);
    $ssh->debug("Received session key; encryption turned on.");

    $packet = $ssh->packet_start(SSH_SMSG_SUCCESS);
    $packet->send;

    $packet = Net::SSH::Perl::Packet->read_expect($ssh, SSH_CMSG_USER);
    $ssh->{user} = $packet->get_str;

    ## Always answer the user name with a failure, so the client has to go
    ## through one of the authentication methods below.
    $packet = $ssh->packet_start(SSH_SMSG_FAILURE);
    $packet->send;

    my $authenticated;
    my %auth = (
        SSH_CMSG_AUTH_RHOSTS() => \&auth_rhosts,
        SSH_CMSG_AUTH_RHOSTS_RSA() => \&auth_rhosts_rsa,
        SSH_CMSG_AUTH_RSA() => \&auth_rsa,
        SSH_CMSG_AUTH_PASSWORD() => \&auth_password,
    );

    ## Unknown packet types count as a failed attempt.
    while (!$authenticated) {
        $packet = Net::SSH::Perl::Packet->read($ssh);

        if (my $code = $auth{ $packet->type }) {
            $authenticated = $code->($ssh, $packet);
        }

        unless ($authenticated) {
            $packet = $ssh->packet_start(SSH_SMSG_FAILURE);
            $packet->send;
        }
    }

    $packet = $ssh->packet_start(SSH_SMSG_SUCCESS);
    $packet->send;

    $packet = Net::SSH::Perl::Packet->read_expect($ssh, SSH_CMSG_EXEC_CMD);
    my $cmd = $packet->get_str;

    ## NOTE: this hands a client-supplied string to the shell.  Executing
    ## the requested command is the purpose of this server, but it must
    ## never be exposed to untrusted clients.
    my $out = `$cmd`;
    my $status = $?;
    $out = '' unless defined $out;  ## Command could not be started.

    ## $? is a raw wait status, not an exit code: the exit code lives in
    ## the high byte and a terminating signal in the low 7 bits.
    my $exit = $status == -1 ? 255
             : $status & 127 ? 128 + ($status & 127)
             :                 $status >> 8;
    $ssh->debug("Command exited with status $exit.");

    $packet = $ssh->packet_start(SSH_SMSG_STDOUT_DATA);
    $packet->put_str($out);
    $packet->send;

    $packet = $ssh->packet_start(SSH_SMSG_EXITSTATUS);
    $packet->put_int32($exit);
    $packet->send;

    $packet = Net::SSH::Perl::Packet->read_expect($ssh, SSH_CMSG_EXIT_CONFIRMATION);
    $ssh->debug("Received exit confirmation.");
}

## auth_rhosts($ssh, $packet)
##
## Rhosts authentication: succeeds (returns 1) only when the client
## connected from a port below 1024; otherwise logs the reason and
## returns 0.  $packet is not examined.
sub auth_rhosts {
    my($ssh, $packet) = @_;

    return 1 if $ssh->sock->peerport < 1024;

    $ssh->debug("Rhosts authentication not available for connections from unprivileged port.");
    return 0;
}

## auth_rhosts_rsa($ssh, $packet)
##
## Rhosts authentication backed by an RSA host key.  The client must
## connect from a port below 1024, present a host key that matches the
## "localhost" entry in the user's known-hosts file, and pass an RSA
## challenge for that key.  Returns 1 on success, 0 otherwise.
sub auth_rhosts_rsa {
    my($ssh, $packet) = @_;
    my $from_low_port = $ssh->sock->peerport < 1024;

    $ssh->debug("Trying rhosts with RSA host authentication for $ssh->{user}");

    if (!$from_low_port) {
        $ssh->debug("RhostsRsa authentication not available for connections from unprivileged port.");
        return 0;
    }

    ## Packet layout: user name, then the client's host key.  The fields
    ## have to be pulled off the packet in exactly this order.
    $packet->get_str;  ## User; we already have it.
    my %host_key;
    $host_key{bits} = $packet->get_int32;
    $host_key{e}    = $packet->get_mp_int;
    $host_key{n}    = $packet->get_mp_int;

    my $known_hosts = $ssh->config->get('user_known_hosts');
    my $status = _check_host_in_hostfile("localhost", $known_hosts, \%host_key);
    if ($status != HOST_OK) {
        $ssh->debug("Rhosts with RSA host authentication denied: unknown or invalid host key.");
        return 0;
    }

    my $passed = _auth_rsa_challenge_dialog($ssh, @host_key{qw( bits e n )});
    if (!$passed) {
        $ssh->debug("Client failed to respond correctly to host authentication.");
        return 0;
    }

    $ssh->debug("Rhosts with RSA host authentication accepted.");
    return 1;
}

## auth_rsa($ssh, $packet)
##
## RSA user authentication.  Reads the modulus offered by the client from
## $packet, looks for a key with that modulus in $USER_AUTHORIZED_KEYS
## (lines of the form "bits e n comment") and, for each match, runs the RSA
## challenge dialog.  Returns 1 as soon as a challenge succeeds, 0 if the
## file cannot be opened or no key authenticates the client.
sub auth_rsa {
    my($ssh, $packet) = @_;

    my $client_n = $packet->get_mp_int;

    my $key_file = $USER_AUTHORIZED_KEYS;
    open my $fh, '<', $key_file or do {
        $ssh->debug("Could not open $key_file for reading.");
        return 0;
    };

    my $accepted = 0;
    while (my $line = <$fh>) {
        chomp $line;

        ## Blank lines and comments carry no key.
        next if $line =~ /^\s*(?:#|$)/;

        ## split ' ' ignores leading whitespace, so an indented line still
        ## yields the bit count as its first field.
        my($bits, $e, $n, $comment) = split ' ', $line;

        ## Skip anything that is not three decimal numbers rather than
        ## feeding garbage to Math::GMP.
        next if grep { !defined($_) || /\D/ } $bits, $e, $n;

        $n = Math::GMP->new($n);
        $e = Math::GMP->new($e);
        next unless $n == $client_n;

        if (_auth_rsa_challenge_dialog($ssh, $bits, $e, $n)) {
            $accepted = 1;
            last;
        }
    }
    close $fh;

    if ($accepted) {
        $ssh->debug("RSA authentication for $ssh->{user} accepted.");
        return 1;
    }

    $ssh->debug("RSA authentication for $ssh->{user} failed.");
    return 0;
}

## _auth_rsa_challenge_dialog($ssh, $bits, $e, $n)
##
## Challenges the client to prove possession of the private key belonging
## to the public key ($bits, $e, $n): a random challenge is encrypted with
## the public key and sent, and the client has to answer with the MD5 of
## the linearized challenge followed by the session id.  Returns true when
## the 16-byte response matches.
sub _auth_rsa_challenge_dialog {
    my($ssh, $bits, $e, $n) = @_;

    ## Random 256-bit value, reduced modulo the key's modulus.
    my $challenge = Math::GMP::mod_two(_rsa_random_integer(256), $n);

    my $public_key = { bits => $bits, e => $e, n => $n };
    my $encrypted = _rsa_public_encrypt($challenge, $public_key);

    my $request = $ssh->packet_start(SSH_SMSG_AUTH_RSA_CHALLENGE);
    $request->put_mp_int($encrypted);
    $request->send;

    ## What a client holding the private key is expected to send back.
    my $expected = md5(_mp_linearize(32, $challenge), $ssh->session_id);

    my $reply = Net::SSH::Perl::Packet->read_expect($ssh, SSH_CMSG_AUTH_RSA_RESPONSE);
    my $response = '';
    for my $i (1..16) {
        $response .= $reply->get_char;
    }

    return $response eq $expected;
}

## auth_password($ssh, $packet)
##
## Password authentication.  Refused (returns 0) when the session has no
## receive cipher, i.e. is unencrypted; otherwise returns true exactly
## when the password carried in $packet equals $DUMMY_PASSWD.
sub auth_password {
    my($ssh, $packet) = @_;

    if (!$ssh->receive_cipher) {
        $ssh->debug("Password authentication not available for unencrypted session.");
        return 0;
    }

    my $supplied = $packet->get_str;
    return $supplied eq $DUMMY_PASSWD;
}

## _rsa_generate_key($bits)
##
## Generates an RSA key pair with a modulus built from two primes of
## roughly $bits/2 bits each.  Returns the list ($pub, $prv):
##   $pub - hash ref with keys bits, n, e
##   $prv - hash ref with keys p, q, e, d, u, n, bits
## Emits progress messages via warn when $DEBUG is set.
sub _rsa_generate_key {
    my $bits = shift;

    my($prv, $pub) = ({}, {});

    my($aux);
    ## Bare block so that 'redo' below can restart prime generation.
    {
        ## Split the requested size between the two primes.
        my $pbits = int($bits / 2);
        my $qbits = $bits - $pbits;

        warn "Generating p...\n" if $DEBUG;
        $prv->{p} = _rsa_random_prime($pbits);

        warn "Generating q...\n" if $DEBUG;
        $prv->{q} = _rsa_random_prime($qbits);

        ## Order the primes so that p is the smaller one.
        unless ($prv->{p} < $prv->{q}) {
            ($prv->{p}, $prv->{q}) = ($prv->{q}, $prv->{p});
        }

        ## Start over unless gcd(p, q) is 1 (this also rejects p == q).
        $aux = Math::GMP::gcd_two($prv->{p}, $prv->{q});
        warn("The primes are not relatively prime!\n"), redo
            unless "$aux" eq '1';
    }

    my $p_minus_1 = $prv->{p} - Math::GMP->new(1);
    my $q_minus_1 = $prv->{q} - Math::GMP->new(1);

    ## phi = (p-1)(q-1); G = gcd(p-1, q-1); F = phi / G.
    my $phi = $p_minus_1 * $q_minus_1;
    my $G = Math::GMP::gcd_two($p_minus_1, $q_minus_1);
    warn "Warning: G=$G is large (many spare key sets); key may be bad!\n"
        if $G > 100;
    my $F = $phi / $G;

    ## Public exponent: presumably starts at 2^5 - 1 = 31 (assuming
    ## mul_2exp_gmp multiplies by 2^5), then steps by 2 until it is
    ## relatively prime to phi, so the first candidate tested is 33.
    my $e = Math::GMP->new(1);
    $e = Math::GMP::mul_2exp_gmp($e, 5);
    $e -= Math::GMP->new(1);
    do {
        $e += Math::GMP->new(2);
        $aux = Math::GMP::gcd_two($e, $phi);
    } while "$aux" ne '1';
    $prv->{e} = $e;

    ## d is the inverse of e modulo F; u is the inverse of p modulo q.
    $prv->{d} = _mpz_mod_inverse($prv->{e}, $F);
    $prv->{u} = _mpz_mod_inverse($prv->{p}, $prv->{q});

    $prv->{n} = $prv->{p} * $prv->{q};
    $prv->{bits} = $bits;

    ## The public key shares the modulus and exponent of the private key.
    $pub->{bits} = $bits;
    $pub->{n} = $prv->{n};
    $pub->{e} = $prv->{e};

    ($pub, $prv);
}

## _mpz_mod_inverse($a, $n)
##
## Computes the multiplicative inverse of $a modulo $n with the extended
## Euclidean algorithm and returns it as a non-negative value.
sub _mpz_mod_inverse {
    my($value, $modulus) = @_;

    ## Remainder sequence (g0, g1) and the matching coefficients (v0, v1).
    my($g0, $g1) = (Math::GMP->new($modulus), Math::GMP->new($value));
    my($v0, $v1) = (Math::GMP->new(0), Math::GMP->new(1));

    ## Scratch values; divmod_gmp appears to store quotient and remainder
    ## in its first two arguments, hence the fresh copies taken below.
    my($div, $mod) = (Math::GMP->new(0), Math::GMP->new(0));

    until ("$g1" eq '0') {
        Math::GMP::divmod_gmp($div, $mod, $g0, $g1);
        my $next_v = $v0 - ($div * $v1);
        $v0 = Math::GMP->new($v1);
        $v1 = Math::GMP->new($next_v);
        $g0 = Math::GMP->new($g1);
        $g1 = Math::GMP->new($mod);
    }

    ## Fold a negative coefficient back into the range 0 .. n-1.
    return $v0 < 0 ? $v0 + $modulus : $v0;
}

## _rsa_random_prime($bits)
##
## Starts from a random integer of up to $bits bits and steps upwards by
## one until probab_prime_p (called with 10 as its second argument)
## accepts the candidate, which is then returned.
sub _rsa_random_prime {
    my $bits = shift;

    my $candidate = _rsa_random_integer($bits);
    until (Math::GMP::probab_prime_p($candidate, 10)) {
        Math::GMP::add_ui_gmp($candidate, 1);
    }
    return $candidate;
}

## _rsa_random_integer($bits)
##
## Builds a random integer from enough random bytes to cover $bits bits
## and returns the result of reducing it with Math::GMP::mod_2exp_gmp
## (presumably modulo 2**$bits).
##
## NOTE: the bytes come from Perl's rand(), which is not a
## cryptographically secure generator; acceptable only for test keys.
sub _rsa_random_integer {
    my $bits = shift;
    my $bytes = int(($bits + 7) / 8);

    ## rand 256 (not 255) so that every byte value 0x00..0xff can occur.
    my $hex = join '', map { sprintf "%02x", int rand 256 } 1..$bytes;

    ## Zero requested bits must still give a valid literal ("0x0").
    $hex = '0' unless length $hex;

    my $ret = Math::GMP->new("0x$hex");
    return Math::GMP::mod_2exp_gmp($ret, $bits);
}