#! /usr/bin/env perl
# vim: ts=4 sw=4 expandtab nosmarttab
#
# vw generic optimal parameter golden-section search
#   http://en.wikipedia.org/wiki/Golden_section_search
#
# Call without parameters for a full usage message
#
# Credits:
#   * Paul Mineiro ('vowpalwabbit/demo') wrote the original script
#   * Ariel Faigon
#       - Generalize, rewrite, refactor
#       - Add usage message
#       - Add tolerance as optional parameter
#       - Add documentation
#       - Better error messages when underlying command fails for any reason
#       - Add '-t test-set' to optimize on test error rather than
#         train-set error
#       - Add integral-value option support
#       - More reliable/informative progress indication
#       - Add -e <external-eval-plugin-script> support
#       Bug fixes:
#           - Golden-section search bug-fix
#           - Loss value capture bug-fix (can be in scientific notation)
#           - Handle special cases where certain options don't make sense
#   * Alex Hudek:
#       - Support log-space search which seems to work better
#         with --l1 (very small values) and/or hinge-loss
#   * Alex Trufanov:
#       - A bunch of very useful bug reports:
#           https://github.com/JohnLangford/vowpal_wabbit/issues/406
#
use warnings;
use strict;
use Getopt::Std;
use Scalar::Util qw(looks_like_number);
use File::Temp;
use vars qw($opt_v $opt_b $opt_t $opt_c $opt_L $opt_e);

# Approximately the largest finite double: the initial "worse than anything"
# loss for $BestLoss, also used as a sentinel value in BrentsSearch().
my $MaxDouble = 1.79769e+308;
# Model file used to evaluate on a test-set/test-cache (-t / -c);
# assigned in process_args().
my $ModelFile;
# True when $ModelFile is a scratch file generated by this script (as opposed
# to a '-f <model>' given by the user); get_loss() removes scratch models.
my $ScratchModel = 0;
# True when the option being optimized expects an integral argument
# (set from is_integer_option(), consulted by opt_value()).
my $IntegerOpt = 0;

#
# v(message)  or  v(format, args...)
#   Verbose diagnostics on STDERR: silent unless -v ($opt_v) is in effect.
#   A lone argument is printed verbatim (so a literal '%' in it is safe),
#   several arguments are handed to printf as (format, args...).
#
sub v(@) {
    $opt_v or return;

    return print STDERR @_ if (@_ == 1);
    return printf STDERR @_;
}

#
# ExecCmd(cmd, timeout)
#  cmd     - a command or reference to an array of command + arguments
#            (an array is joined with spaces into one string, so the
#            command goes through the shell whenever it contains
#            shell meta-characters)
#  timeout - number of seconds to wait (0 = forever)
#
# returns:
#   In case of success:
#       reference to an array holding the output lines of cmd
#       (STDERR and STDOUT merged)
#
#   In case of an empty/false cmd:
#       the list (0, []) - in scalar context the caller simply sees
#       a reference to an empty array
#
#   In case of any failure:
#       (timeout, death from signal, non-zero exit status,
#        failure of fork or exec, ...)
#       print a user-helpful error message and abort
#
sub ExecCmd($;$) {
    my $cmd = shift || return(0, []);
    my $timeout = shift || 0;

    if (ref($cmd) eq 'ARRAY') {
        # Always convert the command to one simple string
        # Otherwise meta-chars in @$cmd don't trigger 'sh -c'
        $cmd = "@$cmd";
    }

    # opening a pipe creates a forked process
    # (the parent reads from $pipe whatever the child writes to its STDOUT)
    my $pid = open(my $pipe, '-|');
    die "Can't fork: $!\n" unless defined $pid;

    if ($pid) {     # parent

        my @result = ();

        if ($timeout) {
            # $failed stays true unless the whole read completes in time
            my $failed = 1;
            eval {
                # set a signal to die if the timeout is reached
                local $SIG{ALRM} = sub { die "alarm\n" };
                alarm $timeout;
                @result = <$pipe>;
                alarm 0;
                $failed = 0;
            };
            die "command timeout. Result so far:\n@result\n" if $failed;

        } else {
            while (<$pipe>) {
                push @result, $_;
                # User feedback on training progress: ......
                # only on real SGD (avg loss) progress
                # (keep 'quiet' on startup/informative lines)
                print STDERR '.' if (/^\d/);
            }
        }
        # closing the pipe waits for the child and sets $? to its status
        close($pipe);

        # decode the wait status: exit code, terminating signal, core flag
        my $exitcode = $? >> 8;
        my $signal = $? & 127;
        my $core = ($? & 128) ? ' (core dumped)' : '';

        if ($signal) {
            warn "\n\t$cmd died from signal $signal$core\n";
            die "Try to run:\n\t$cmd\nmanually to figure out why.\n";
        } elsif ($exitcode) {
            warn "\n\t$cmd failed: (exitcode=$exitcode)\n";
            warn "Output from $cmd is:\n@result\n";
            # the 'run it manually' hint is only shown when the command
            # produced no output at all
            warn "You may try to run:\n\t$cmd\nmanually to figure out why.\n"
                unless (@result);
            exit 1;
        }
        # success: return (a reference to) the captured output lines
        return \@result;
    }

    # Child...
    no warnings;

    # redirect STDERR to STDOUT
    open(STDERR, '>&STDOUT');

    # Run the command in child process - if successful, it should never return
    exec($cmd);

    # this code will not execute unless exec fails!
    die "$0: can't exec '$cmd': $!";
}

#
# real_param($)
#   Translate a search-space value into the real value of the parameter,
#   so that whatever we show to the user is correct whether -L is
#   present or not: exp(param) in log-space mode, param itself otherwise.
#
sub real_param($) {
    my ($search_value) = @_;

    return exp($search_value) if ($opt_L);
    return $search_value;
}

#
# loss(param)   (the sub itself is defined further below, after get_loss();
#                the state it shares with get_loss() is declared here)
#   Input:
#       @ARGV - A vw command where the parameter we want to optimize
#               for appears as a '%' placeholder
#       param - Value for parameter for which we check the loss
#
#   Returns the vw average loss for the vw command @ARGV with '%'
#       set to param value
#
my %Loss;   # x -> f(x) cache
# Best (lowest loss) parameter value seen so far and its loss,
# maintained by get_loss()
my $BestParam;
my $BestLoss = $MaxDouble;
# Matches a decimal number with an optional sign and an optional exponent
# (scientific notation); get_loss() uses it to pick the number out of the
# output of an external (-e) evaluator.
my $NumberPat = qr{
    (?:                 # mantissa
        [-+]?               # optional sign
        (?:
            \d*\.\d+        # numeric part
            |\d+(?:\.\d*)?
            |\d+
        )
    )
    (?:[Ee][-+]?\d+)?   # optional exponent
}ox;

#
# get_loss($param, @command)
#   Run command (and optional external eval plug-in)
#   and extract the loss from the output.
#   Keep track of best (argmin($param)) result so far
#   in $BestParam / $BestLoss.
#
#   $param   - the (real-space) parameter value being evaluated
#   @command - the vw command with '%' already replaced by $param
#
#   Returns the loss (a number).
#   Dies when no numeric loss can be found in the output; this includes
#   vw reporting a non-numeric loss such as 'undefined', which would
#   otherwise be silently treated as a perfect loss of 0.
#
sub get_loss($@) {
    my ($param, @command) = @_;

    my $rv = ExecCmd \@command;

    if ($opt_e) {
        # the plug-in output, not the vw output, is where the loss comes from
        $rv = ExecCmd $opt_e;
    }

    my $loss;

    # Read from the end, so if we run a test after train,
    # we get the last (test) loss, rather than the train loss.
    foreach my $line (reverse @$rv) {
        if (
                ($opt_e && $line =~ /($NumberPat)/)
                    or
                ($line =~ /^average loss\s*=\s*(\S+)/)
            ) {

            # Found a loss: bail out of the loop on the 1st loss line found
            $loss = $1;
            last;
        }
    }

    # Validate before the value takes part in any numeric comparison
    die "\n$0: failed to parse average loss from vw output: ",
        join('', @$rv),
        "\n\nTry to run:\n\t@command\nmanually to figure out why.\n"
            unless (defined($loss) && looks_like_number($loss));

    my $best_msg = '';
    if ($loss <= $BestLoss) {
        $BestParam = $param;
        $BestLoss = $loss;
        $best_msg = ' (best)';
    }

    # A scratch model is only needed between the train and the test phase
    unlink($ModelFile) if (($opt_t || $opt_c) && $ScratchModel && -e $ModelFile);

    printf STDERR " %.6g%s\n", $loss, $best_msg;
    $loss;
}

#
# loss(param)
#   Evaluate the vw command template in @ARGV with every '%' replaced
#   by the parameter value and return the resulting loss.
#   In log-space mode (-L) param is a log-space value: it is mapped back
#   with exp() first, so vw (and the cache) only ever see real values.
#   Results are memoized in %Loss: each distinct value is run only once.
#
sub loss($) {
    my ($param) = @_;

    # -- map back from log-space to real value of the parameter
    $param = exp($param) if ($opt_L);

    # already evaluated: answer from the cache
    return $Loss{$param} if (exists $Loss{$param});

    # otherwise we have to run an evaluation
    printf STDERR "trying %s ", $param;

    # instantiate the command template (on a copy, @ARGV stays intact)
    my @command = map {
        (my $arg = $_) =~ s/%/$param/g;
        $arg;
    } @ARGV;

    # Run command (and optional external eval plug-in),
    # extract the loss from the output and cache it.
    return $Loss{$param} = get_loss($param, @command);
}

#
# argmin3($$$$$$)
#   Given 3 value pairs: (arg, loss)  as a six-tuple
#   returns the pair for which the loss is the smallest.
#   On equal losses the later pair wins.
#
sub argmin3 ($$$$$$) {
    my ($x1, $loss1, $x2, $loss2, $x3, $loss3) = @_;

    # best of the first two pairs (the first one only if strictly better)
    my ($x, $fx) = ($loss1 < $loss2) ? ($x1, $loss1) : ($x2, $loss2);

    # ... against the third pair
    return ($fx < $loss3) ? ($x, $fx) : ($x3, $loss3);
}

#
# usage(@complaint)
#   Print the (optional) complaint followed by a newline on STDERR,
#   then die with the full usage message. Never returns.
#
sub usage(@) {
    my @complaint = @_;

    if (@complaint) {
        print STDERR @complaint, "\n";
    }

    die <<"END_OF_USAGE";
Usage: $0 [options] <lower_bound> <upper_bound> [tolerance] vw-command...

    Options:
        -t <TS>     Use <TS> as test-set file + evaluate goodness on it
        -c <TC>     Use <TC> as test-cache file + evaluate goodness on it
                    (This implies '-t' except the test-cache <TC> will be used
                     instead of <TS>)
        -L          Use log-space for mid-point bisection
        -e <prog>   Use <prog> (external utility) to evaluate how good
                    the result is instead of vw's own 'average loss' line
                    The protocol is: <prog> simply prints a number (last
                    line numeric output is used) and the search attempts
                    to find argmin() for this output.  You may also pass
                    arguments to the plugin using: -e '<prog> <args...>'

    lower_bound     lower bound of the parameter search range
    upper_bound     upper bound of the parameter search range
    tolerance       termination condition for optimal parameter search
                    expressed as an improvement fraction of the range

    vw-command ...  a vw command to repeat in a loop until optimum
                    is found, where '%' is used as a placeholder
                    for the optimized parameter

    NOTE:
            -t ... for doing a line-search on a test-set is an
            argument to '$0', not to 'vw'.  vw-command must
            represent the training-phase only.
END_OF_USAGE
}

#
# opt_value($)
#   Make sure integer-args remain integers: when the optimized option
#   expects an integral argument ($IntegerOpt) the value is rounded to
#   the nearest integer (halves round up), otherwise it is returned
#   unchanged.
#
sub opt_value($) {
    my ($value) = @_;

    return $value unless ($IntegerOpt);

    # int() truncates toward zero, so int($value + 0.5) alone would round
    # negative values the wrong way (-1.7 => -1). Turn the truncation into
    # a floor to round correctly on both sides of zero.
    my $rounded = int($value + 0.5);
    $rounded-- if ($rounded > $value + 0.5);

    # v("opt_value($value): $rounded\n");
    return $rounded;
}

# Upper limit on the number of iterations of the BrentsSearch() loop
my $MaxIter = 50;

#
# BrentsSearch($low, $high, $tau)
#   As transcribed from wikipedia.
#   Doesn't give a better result than GoldenSection (yet)
#   Need to figure out why.
#
#   $low, $high - end-points of the search range
#   $tau        - the loop stops once |a - b| <= $tau
#
#   Returns the final estimate (passed through opt_value()).
#   Dies after more than $MaxIter iterations.
#
#   NOTE(review): this looks like the root-finding form of Brent's method
#   (it brackets a sign change of f), not the minimization form. When
#   f(a) * f(b) >= 0, e.g. when both losses are positive, the code below
#   returns the better end-point right away without searching at all.
#   This may well be the reason for the remark above - to be confirmed.
#
sub BrentsSearch($$$) {
    my ($low, $high, $tau) = @_;
    my ($a, $b, $c, $d) = ($low, $high, 0.0, $MaxDouble);

    my $fa = loss($a);
    my $fb = loss($b);

    my ($fc, $s, $fs) = (0.0, 0.0, 0.0);

    # No sign change between the end-points: nothing is bracketed,
    # settle for the end-point with the smaller loss
    if ($fa * $fb >= 0) {
        if ($fa < $fb) {
            return opt_value($a);
        } else {
            return opt_value($b);
        }
    }

    # if |f(a)| < |f(b)| then swap (a,b)
    if (abs($fa) < abs($fb)) {
        my $tmp = $a; $a = $b; $b = $tmp;
        $tmp = $fa; $fa = $fb; $fb = $tmp;
    }

    # c: previous value of b, d: the value of c before that;
    # mflag: true when the previous step was a bisection
    $c = $a;
    $fc = $fa;
    my $mflag = 1;
    my $i = 0;

    while (($fb != 0) and (abs($a-$b) > $tau)) {
        if (($fa != $fc) && ($fb != $fc)) {
            # Inverse quadratic interpolation
            $s =    $a * $fb * $fc / ($fa - $fb) / ($fa - $fc)
                                +
                    $b * $fa * $fc / ($fb - $fa) / ($fb - $fc)
                                +
                    $c * $fa * $fb / ($fc - $fa) / ($fc - $fb);
        } else {
            # Secant Rule
            $s =    $b - $fb * ($b - $a) / ($fb - $fa);
        }

        # Reject the interpolated point and bisect instead when it is not
        # strictly between (3a + b)/4 and b, or when it does not shrink the
        # step enough compared to the previous step(s)
        my $tmp2 = (3.0 * $a + $b) / 4.0;
        if (
               ( ! ((($s > $tmp2) && ($s < $b)) ||
                    (($s < $tmp2) && ($s > $b)))
               )
                        or
               ($mflag && (abs($s - $b) >= (abs($b - $c) / 2.0)))
                        or
               (! $mflag && (abs($s - $b) >= (abs($c - $d) / 2.0)))
            )
        {
            $s = ($a + $b) / 2.0;
            $mflag = 1;
        } else {
            # also bisect when the previous step was already below $tau
            if (($mflag && (abs($b - $c) < $tau)) ||
                (! $mflag && (abs($c - $d) < $tau))) {

                $s = ($a + $b) / 2.0;
                $mflag = 1;
            } else {
                $mflag = 0;
            }
        }
        $fs = loss($s);
        $d = $c;
        $c = $b;
        $fc = $fb;

        # keep the sub-interval over which f changes sign
        if ($fa * $fs < 0.0) {
            $b = $s; $fb = $fs;
        } else {
            $a = $s; $fa = $fs;
        }

        # if |f(a)| < |f(b)| then swap (a,b)
        if (abs($fa) < abs($fb)) {
            my $tmp = $a; $a = $b; $b = $tmp;
            $tmp = $fa; $fa = $fb; $fb = $tmp;
        }
        $i++;
        if ($i > $MaxIter) {
            die "Brent's method error too many iterations: $i: f(b): $fb\n";
        }
    }
    opt_value($b);
}

# The golden ratio, (1 + sqrt(5)) / 2
my $Phi = (1.0 + sqrt (5.0)) / 2.0;
# 2 - Phi: the fraction of a (sub-)range at which the golden-section
# search places its next probe point
my $ResPhi = 2.0 - $Phi;

# Forward declaration: goldenSectionSearch() calls itself recursively, so
# its prototype has to be known before its body is compiled
sub goldenSectionSearch($$$$);

#
# goldenSectionSearch($low, $mid, $high, $tau)
#   Recursive golden-section search for the argument with the lowest
#   loss() inside the bracket [$low .. $high] with interior point $mid.
#   Each step probes a new point $x inside the larger of the two
#   sub-ranges and recurses on the narrower bracket around the better
#   of ($x, $mid).
#   Stops when the bracket is narrower than $tau * (|mid| + |x|), or
#   when both interior points have exactly the same loss.
#   All points go through opt_value() (integer rounding when needed).
#
sub goldenSectionSearch($$$$) {
    my ($low, $mid, $high, $tau) = @_;

    $mid = opt_value($mid);
    my $upper_range = $high - $mid;
    my $lower_range = $mid - $low;

    # always probe inside the larger of the two sub-ranges
    my $probe_above = ($upper_range > $lower_range);
    my $x = $probe_above
                ? $mid + $ResPhi * $upper_range
                : $mid - $ResPhi * $lower_range;
    $x = opt_value($x);

    # bracket is narrow enough: settle for its middle
    if (abs($high - $low) < $tau * (abs($mid) + abs($x))) {
        return opt_value(($high + $low) / 2.0);
    }

    my $loss_x   = loss($x);
    my $loss_mid = loss($mid);

    # no way to tell which side is better: settle between the two points
    if ($loss_x == $loss_mid) {
        printf STDERR "loss(%g) == loss(%g): %g\n",
                            real_param($x), real_param($mid), $loss_x;
        return opt_value(($x + $mid) / 2.0);
    }

    # narrow the bracket around the better of the two interior points
    my @bracket;
    if ($loss_x < $loss_mid) {
        @bracket = $probe_above ? ($mid, $x, $high) : ($low, $x, $mid);
    } else {
        @bracket = $probe_above ? ($low, $mid, $x) : ($x, $mid, $high);
    }

    return opt_value(
        goldenSectionSearch($bracket[0], $bracket[1], $bracket[2], $tau)
    );
}


#
# best_hyperparam(lower_bound, upper_bound, tolerance)
#
#   Searches between lower_bound and upper_bound for the best parameter
#   (lowest loss): golden-ratio section search by default, Brent's
#   method when -b is given.
#
#   tolerance is handed to the search as its termination threshold.
#
#   Returns the pair (best parameter in real space, its loss).
#
sub best_hyperparam($$$) {
    my ($lb, $ub, $tol) = @_;

    # log-space search trick: map the initial end-points of the range
    # to log-space and let the search work in the new range as if it
    # was linear. Only 'loss()' knows about this trick and reverses it,
    # the search algorithms are oblivious to it.
    if ($opt_L) {
        ($lb, $ub) = (log($lb), log($ub));
    }

    my $best;
    if ($opt_b) {
        $best = BrentsSearch($lb, $ub, $tol);
    } else {
        my $first_probe = $lb + $ResPhi * ($ub - $lb);
        $best = goldenSectionSearch($lb, $first_probe, $ub, $tol);
    }
    my $fbest = loss($best);

    # remap to real range
    $best = exp($best) if ($opt_L);

    # Still happens that what the search returns is not the best
    # point seen along the way: prefer the recorded best in that case.
    return ($BestParam, $BestLoss) if ($BestLoss < $fbest);

    return ($best, $fbest);
}

#
# fullpath($exe)
#   Look for the executable $exe in the directories of $ENV{PATH}
#   (in order) and finally in the current directory.
#   On Windows the PATH separator is ';' and "$exe.exe" is tried as well.
#
#   Returns the path of the first match which is a regular, executable
#   file, or '' when there is none.
#
sub fullpath($) {
    my $exe = shift;
    # I'm told on Windows, the PATH separator is ';' instead of ':'
    my ($path_sep, $ext) = (':', '');
    if ($^O =~ /MSWin/i) {
        $path_sep = ';';
        $ext = '.exe';
    }

    # PATH may be missing from the environment altogether
    my $search_path = defined($ENV{PATH}) ? $ENV{PATH} : '';

    foreach my $entry (split(/\Q$path_sep\E/, $search_path), '.') {
        # an empty PATH entry stands for the current directory, not for '/'
        my $dir = ($entry eq '') ? '.' : $entry;

        # On windows, the executable may be called 'vw.exe'
        # so try that as well
        my @candidates = ("$dir/$exe");
        push(@candidates, "$dir/$exe$ext") if ($ext);

        foreach my $fp (@candidates) {
            # -x alone is also true for (searchable) directories,
            # so insist on a plain file
            return $fp if (-f $fp && -x _);
        }
    }
    return '';
}

#
# is_integer_option($option)
#   Return true (1) for all options which expect an integral argument,
#   false (0) otherwise.
#
#   $option - the option name, with or without its leading dash(es)
#
#   The whole option name must match one of the known names: a name that
#   merely contains one of them (e.g. --lda_alpha, --search_alpha) or that
#   merely starts with 'b' is not an integer option.
#
sub is_integer_option($) {
    my $opt = shift;
    $opt = '' unless (defined $opt);

    # NOTE: the alternation is grouped so that the anchors apply to
    # every alternative, not just to the first and the last one.
    my $expects_integer = ($opt =~ m{
            \A -*               # optional leading dash(es)
            (?:
                bs?
                |bit_precision
                |affix
                |autolink
                |bag
                |batch_sz
                |boosting
                |bootstrap|B
                |(?:csoaa|wap)(?:_ldf)?
                |cb(?:ify|_explore)?
                |num_children
                |holdout_(?:period|after)
                |initial_pass_length
                |lda
                |log_multi
                |lrq
                |nn
                |oaa
                |ect
                |passes
                |rank
                |ring_size
                |examples
                |searn
                |search
                |total
                |top
                |ngram
                |skips
            )
            \z
        }x) ? 1 : 0;

    v("is_integer_option(%s): %d\n", $opt, $expects_integer);

    $expects_integer;
}

#
# process_args()
#   Parse the command line: the options (getopts), the search range,
#   the optional tolerance and the vw command template, which is left
#   in @ARGV with '%' marking the parameter to optimize.
#
#   Side effects:
#       - strips the directory part from $0
#       - replaces the command name by its full path when found on PATH
#       - removes '--quiet' from the vw command
#       - sets $IntegerOpt and normalizes $opt_L to 0/1
#       - with -t/-c: sets $ModelFile/$ScratchModel and appends saving of
#         the model plus a test run ('&& vw -t -i <model> ...') to @ARGV
#
#   Returns the list (lower_bound, upper_bound, tolerance).
#   Aborts through usage() on invalid arguments.
#
sub process_args {
    $0 =~ s{.*/}{};

    getopts('vbLc:t:e:');

    # Brent is not ready for prime-time yet so don't advertise it.
    warn "$0: using Brent's method search\n" if ($opt_b);
    if ($opt_t) {       # evaluate on test-set
        usage("-t $opt_t: $!") unless (-e $opt_t && -r $opt_t);
    }
    if ($opt_c) {       # evaluate on test-cache
        usage("-c $opt_c: $!") unless (-e $opt_c && -r $opt_c);
    }
    if ($opt_e) {       # external plugin 'loss like' evaluator
        # Allow users to pass args to the eval script,
        # assume first non-space sequence is the command
        my ($first_word) = ($opt_e =~ /^\s*(\S+)/);
        my $is_ex = (-e $first_word && -x $first_word) || qx{which $first_word};
        usage("-e $opt_e: command not found") unless $is_ex;
        warn "Using external evaluation plug-in: '$opt_e'\n";
    }

    # v("after getopts: \@ARGV=(@ARGV)\n");

    usage("Too few arguments...")
        unless (@ARGV > 2);

    usage("1st argument: $ARGV[0]: expecting a number")
        unless (looks_like_number($ARGV[0]));

    usage("2nd argument: $ARGV[1]: expecting a number")
        unless (looks_like_number($ARGV[1]));

    my $lower_bound = shift @ARGV;
    my $upper_bound = shift @ARGV;

    if ($lower_bound > $upper_bound) {
        warn "$0: lower bound $lower_bound > $upper_bound upper bound: swapping them for you\n";
        my $t = $lower_bound;
        $lower_bound = $upper_bound;
        $upper_bound = $t;
    }
    my $tolerance = 1e-4;
    if (looks_like_number($ARGV[0])) {
        $tolerance = shift @ARGV;
        usage("3rd argument (tolerance): $tolerance: not in the (0, 1) range")
            unless (0 < $tolerance and $tolerance < 1.0);
    }

    # The tolerance may have been the last argument: without this check
    # an undefined $ARGV[0] would be looked up as the command below.
    usage("Too few arguments: the vw command is missing")
        unless (@ARGV);

    if (! -x $ARGV[0]) {
        # Not directly executable, look along PATH:
        my $fp = fullpath($ARGV[0]);
        if ($fp && -x $fp) {
            $ARGV[0] = $fp;
        }
    }
    # retry after looking along PATH, warn if still no go
    if (! -x $ARGV[0]) {
        warn "Couldn't find '$ARGV[0]' executable: " .
             "(is \$PATH correct?) " . "Will try anyway.\n";
    }

    # --- Now @ARGV contains the full vw command

    # --quiet prevents progress from being detected so don't allow it
    my @saved_argv = @ARGV; @ARGV = ();
    foreach my $arg (@saved_argv) {
        next if ($arg eq '--quiet');
        push(@ARGV, $arg);
    }

    my $vw_command_line = "@ARGV";
    usage("command '@ARGV': must include a '%' wildcard to optimize on")
        unless ($vw_command_line =~ /%/);

    # Some vw option arguments must be integers, if we want to
    # optimize those, we must bisect while rounding to an integer
    my ($option_before_pct) = ($vw_command_line =~ /(\S+)\s+\S*%/);
    # (no word before the one holding '%': nothing to classify)
    $option_before_pct = '' unless (defined $option_before_pct);
    $IntegerOpt = is_integer_option($option_before_pct);

    if ($opt_t || $opt_c) {
        # Evaluate on test:
        #   1) Make sure we store the model after training
        #   2) Call vw again, loading model and evaluating on test-set
        #
        # The user may already store a model, using either the short (-f)
        # or the long (--final_regressor) spelling of the option.
        if ($vw_command_line =~
                /(?:^|\s)(?:-f\s+|--final_regressor(?:\s+|=))(\S+)/) {
            $ModelFile = $1;
        } else {
            # my $tmpdir = mkdtemp('vw-hypersearch-XXXXXX');
            # Model file doesn't exist in user provided command line
            # so we need to add its generation
            $ModelFile = mktemp("vw-hypersearch.model-XXXXXX");
            $ScratchModel = 1;
            v("\$ModelFile: %s\n", $ModelFile);
            push(@ARGV, '-f', $ModelFile);
        }
        # Add the test command after the train command
        push(@ARGV, '&&', $ARGV[0], '-t', '-i', $ModelFile);

        # Preserve the user-specified loss function, if any
        if ($vw_command_line =~ /(--loss_function)\s+(\S+)/) {
            push(@ARGV, $1, $2);
        }
        # Add the test file
        if ($opt_t) {
            push(@ARGV, $opt_t);
        }
        # Add test-cache if -c is in effect
        if ($opt_c) {
            push(@ARGV, '--cache_file', $opt_c);
        }
        v("New -t command line: %s\n", "@ARGV");
    }

    if ((defined $opt_L) and $opt_L) {
        if ($IntegerOpt) {
            warn "$0: -L: incompatible with integer option $option_before_pct" .
                    " - ignoring -L\n";
            $opt_L = 0;
        } else {
            warn "$0: -L: using log-space search\n";
        }
    } else {
        $opt_L = 0;
        # Hint at -L in the cases where it is known to help
        if (! $IntegerOpt
                and
            ($vw_command_line =~ /--l1\s+%/  ||
             $vw_command_line =~ /--loss_function\s+hinge/  ||
             $lower_bound =~ /e/i  ||
             $upper_bound =~ /e/i)
        ) {
            warn "$0: you may get better results with -L (log-space search)\n" .
                "\t\twhen any of --l1/hinge-loss/small-param-values are used\n";
        }
    }

    return ($lower_bound, $upper_bound, $tolerance);
}

# -- main
#   Parse the command line, run the search and report the result on
#   STDOUT as: <best parameter value> TAB <loss at that value>
my ($search_from, $search_to, $stop_fraction) = process_args();
my @optimum = best_hyperparam($search_from, $search_to, $stop_fraction);

printf "%g\t%g\n", @optimum[0, 1];

