The Perl Toolchain Summit needs more sponsors. If your company depends on Perl, please support this very important event.
package Redis::NaiveBayes;
# ABSTRACT: A generic Redis-backed NaiveBayes implementation
$Redis::NaiveBayes::VERSION = '0.0.4';

use strict;
use warnings;
use List::Util qw(sum reduce);

use Redis;

use constant {
    LABELS => 'labels',
};

# Lua scripts
#
# Every *_FMT string below is a sprintf() template. _redis_script_load()
# fills the first '%s' with the instance namespace and the second '%s' with
# the LABELS constant before registering the script with Redis.

# flush: for every member of the labels set, deletes the key holding that
# label's token counts and its 'tally_for:' key, then deletes the labels set
# itself. Takes no ARGV.
my $LUA_FLUSH_FMT = q{
    local namespace  = '%s'
    local labels_key = namespace .. '%s'
    for _, member in ipairs(redis.call('smembers', labels_key)) do
        redis.call('del', namespace .. member)
        redis.call('del', namespace .. 'tally_for:' .. member)
    end
    redis.call('del', labels_key);
};

# train: adds the label to the labels set, HINCRBYs every token's counter in
# the label's hash by the matching value, and stores the previous tally plus
# the sum of those increments under the label's 'tally_for:' key.
#
# sprintf() placeholders: namespace, then the suffix of the labels-set key.
# The ARGV layout is described in the comment at the top of the script; the
# values arrive as strings, so the index arithmetic relies on Lua coercing
# numeric strings to numbers.
my $LUA_TRAIN_FMT = q{
    -- ARGV:
    --   1: raw label name being trained
    --   2: number of tokens being updated
    --   3-X: token being updated
    --   X+1-N: value to increment corresponding token

    local namespace  = '%s'
    local labels_key = namespace .. '%s'
    local label      = namespace .. ARGV[1]
    local tally_key  = namespace .. 'tally_for:' .. ARGV[1]
    local num_tokens = ARGV[2]
    local tot_added  = 0

    redis.call('sadd', labels_key, ARGV[1])

    for index, token in ipairs(ARGV) do
        if index > num_tokens + 2 then
            break
        end
        if index > 2 then
            redis.call('hincrby', label, token, ARGV[index + num_tokens])
            tot_added = tot_added + ARGV[index + num_tokens]
        end
    end

    local old_tally = redis.call('get', tally_key);
    if (not old_tally) then
        old_tally = 0
    end

    redis.call('set', tally_key, old_tally + tot_added)
};

# untrain: the reverse of the train script. For every token it subtracts the
# matching value from the token's counter in the label's hash; when the
# counter is missing or the result would not be positive, the field is
# removed instead. The tally is then recomputed as the sum of the remaining
# hash values: if it is not positive, the label's hash, its entry in the
# labels set and its 'tally_for:' key are all deleted; otherwise the new
# tally is stored.
#
# sprintf() placeholders: namespace, then the suffix of the labels-set key.
# NOTE: the in-script ARGV comment says "value to increment", but this
# script subtracts those values.
my $LUA_UNTRAIN_FMT = q{
    -- ARGV:
    --   1: raw label name being untrained
    --   2: number of tokens being updated
    --   3-X: token being updated
    --   X+1-N: value to increment corresponding token

    local namespace  = '%s'
    local labels_key = namespace .. '%s'
    local label      = namespace .. ARGV[1]
    local tally_key  = namespace .. 'tally_for:' .. ARGV[1]
    local num_tokens = ARGV[2]

    for index, token in ipairs(ARGV) do
        if index > num_tokens + 2 then
            break
        end
        if index > 2 then
            local current = redis.call('hget', label, token);

            if (current and current - ARGV[index + num_tokens] > 0) then
                redis.call('hincrby', label, token, -1 * ARGV[index + num_tokens])
            else
                redis.call('hdel', label, token)
            end
        end
    end

    local tally = 0
    for _, value in ipairs(redis.call('hvals', label)) do
        tally = tally + value
    end

    if tally <= 0 then
        redis.call('del', label)
        redis.call('srem', labels_key, ARGV[1])
        redis.call('del', tally_key)
    else
        redis.call('set', tally_key, tally)
    end
};

# Shared Lua prelude, interpolated into both the 'scores' and the 'classify'
# scripts. It leaves a Lua table named `scores` in scope, mapping each raw
# label with a positive tally to the sum over the given tokens of
# log(count / tally). A token whose count is missing or not positive
# contributes log(correction / tally) instead.
#
# sprintf() placeholders: namespace, then the suffix of the labels-set key.
# The text must not contain any other '%' because it is passed to sprintf().
my $_LUA_CALCULATE_SCORES = q{
    -- ARGV
    --   1: correction
    --   2: number of tokens
    --   3-X: tokens
    --   X+1-N: values for each token
    -- FIXME: I'm ignoring the scores per token on purpose for now

    local namespace  = '%s'
    local labels_key = namespace .. '%s'
    local correction = ARGV[1]
    local num_tokens = ARGV[2]

    local scores = {}

    for index, raw_label in ipairs(redis.call('smembers', labels_key)) do
        local label = namespace .. raw_label

        local tally = tonumber(redis.call('get', namespace .. 'tally_for:' .. raw_label))

        if (tally and tally > 0) then
            scores[raw_label] = 0.0

            for idx, token in ipairs(ARGV) do
                if idx > num_tokens + 2 then
                    break
                end

                if idx > 2 then
                    -- HGET replies with a string, or false when the field is
                    -- missing. Lua never treats a string as equal to a
                    -- number, so convert before testing for an empty count.
                    local score = tonumber(redis.call('hget', label, token))

                    if (not score or score <= 0) then
                        score = correction
                    end

                    scores[raw_label] = scores[raw_label] + math.log(score / tally)
                end
            end
        end
    end
};

# scores: runs the shared score calculation and replies with a flat list of
# label, score, label, score, ... which the Perl side assigns to a hash.
#
# Each score is sent as a string: Redis converts a Lua number in a script
# reply to an integer, which would silently drop the fractional part of
# every log-score.
#
# qq{} is used so that $_LUA_CALCULATE_SCORES is interpolated; the two '%s'
# sprintf() placeholders (namespace, labels-set key suffix) come from that
# prelude. Nothing added here may contain '%', '$' or '@'.
my $LUA_SCORES_FMT = qq{
    $_LUA_CALCULATE_SCORES

    -- Numbers in a reply are truncated to integers by Redis, so the
    -- scores are converted to strings to keep their fractional part.
    local flat_reply = {}
    local index = 1
    for key, value in pairs(scores) do
        flat_reply[index] = key
        flat_reply[index+1] = tostring(value)
        index = index + 2
    end

    return flat_reply;
};

# classify: runs the shared score calculation and returns the label with the
# highest score, or nil when no label produced a score. The comparison is
# strict, so on a tie the label visited first by pairs() is kept.
#
# qq{} is used so that $_LUA_CALCULATE_SCORES is interpolated; the two '%s'
# sprintf() placeholders (namespace, labels-set key suffix) come from that
# prelude.
my $LUA_CLASSIFY_FMT = qq{
    $_LUA_CALCULATE_SCORES

    local best_label = nil
    local best_score = nil
    for label, score in pairs(scores) do
        if (best_score == nil or best_score < score) then
            best_label = label
            best_score = score
        end
    end

    return best_label
};


# Constructor.
#
# Named arguments:
#   redis      - optional client object; when false, Redis->new(%args) is
#                called with every argument passed to this constructor
#   correction - optional; 0.001 is used when the given value is false
#   namespace  - required; dies with "Missing namespace" when false
#   tokenizer  - required; dies with "Missing tokenizer" when false
#
# Calls _load_scripts() on the new object before returning it.
sub new {
    my $class = shift;
    my %args  = @_;

    my $self = bless {}, $class;

    # The client is resolved first, so any error raised by Redis->new
    # surfaces before the required-argument checks below.
    $self->{redis}      = $args{redis}      ? $args{redis}      : Redis->new(%args);
    $self->{correction} = $args{correction} ? $args{correction} : 0.001;

    $self->{namespace} = $args{namespace};
    die "Missing namespace" unless $self->{namespace};

    $self->{tokenizer} = $args{tokenizer};
    die "Missing tokenizer" unless $self->{tokenizer};

    $self->_load_scripts;

    return $self;
}

# Renders a Lua template and registers it through the client's script_load().
#
# $script_fmt is a sprintf() template; it is given the instance namespace,
# the LABELS constant and then any extra arguments, in that order.
# Returns the first value of the script_load() reply, which callers keep as
# the handle later passed to evalsha().
sub _redis_script_load {
    my $self       = shift;
    my $script_fmt = shift;

    my $lua   = sprintf($script_fmt, $self->{namespace}, LABELS, @_);
    my @reply = $self->{redis}->script_load($lua);

    return $reply[0];
}

# Registers every Lua script with Redis and records the handle returned for
# each under $self->{scripts}{<name>}. The scripts are loaded in a fixed
# order: flush, train, untrain, scores, classify.
# Returns the handle stored for 'classify' (the last script loaded).
sub _load_scripts {
    my ($self) = @_;

    # An array of pairs rather than a hash, so the load order is stable.
    my @templates = (
        [ flush    => $LUA_FLUSH_FMT ],
        [ train    => $LUA_TRAIN_FMT ],
        [ untrain  => $LUA_UNTRAIN_FMT ],
        [ scores   => $LUA_SCORES_FMT ],
        [ classify => $LUA_CLASSIFY_FMT ],
    );

    my $handle_for = $self->{scripts} = {};
    for my $entry (@templates) {
        my ($name, $template) = @$entry;
        $handle_for->{$name} = $self->_redis_script_load($template);
    }

    return $handle_for->{classify};
}

# Invokes the client method named $command on the namespaced key.
#
# The instance namespace is prepended to $key; any remaining arguments are
# passed through unchanged. Returns whatever the client method returns, in
# the caller's context.
sub _exec {
    my $self    = shift;
    my $command = shift;
    my $key     = shift;
    my @extra   = @_;

    my $full_key = $self->{namespace} . $key;

    return $self->{redis}->$command($full_key, @extra);
}

# Runs a previously loaded script through the client's evalsha().
#
#   $script  - name under which the script was stored in $self->{scripts}
#   $numkeys - forwarded to evalsha; a false value is sent as 0
#   @rest    - forwarded to evalsha unchanged
#
# Dies with "Script wasn't loaded: '<name>'" when no handle is stored for
# $script. Returns the evalsha() result in the caller's context.
sub _run_script {
    my ($self, $script, $numkeys, @rest) = @_;

    my $sha1 = $self->{scripts}->{$script};
    die "Script wasn't loaded: '$script'" unless $sha1;

    return $self->{redis}->evalsha($sha1, $numkeys || 0, @rest);
}


# Runs the 'flush' script, which removes the keys of every label recorded in
# the labels set and then the labels set itself.
# Returns whatever _run_script() returns.
sub flush {
    my $self = shift;

    return $self->_run_script('flush');
}

# Deletes every key matching "<namespace>*", whether or not this classifier
# created it. DEL is skipped when KEYS finds nothing.
sub _mrproper {
    my ($self) = @_;

    my $redis   = $self->{redis};
    my $pattern = $self->{namespace} . '*';

    my @doomed = $redis->keys($pattern);
    $redis->del(@doomed) if @doomed;
}

# Shared implementation behind train() and untrain().
#
# Tokenizes $item with the configured tokenizer and runs the script named
# $script with ARGV = label, token count, the tokens, then the value of each
# token in the same order.
# Dies with "tokenizer() didn't return a HASHREF" otherwise.
# Returns the hashref produced by the tokenizer.
sub _train {
    my ($self, $label, $item, $script) = @_;

    my $occurrences = $self->{tokenizer}->($item);
    ref $occurrences eq 'HASH'
        or die "tokenizer() didn't return a HASHREF";

    # A hash slice keeps each value aligned with its token.
    my @tokens = keys %$occurrences;
    my @counts = @{$occurrences}{@tokens};

    $self->_run_script($script, 0, $label, scalar(@tokens), @tokens, @counts);

    return $occurrences;
}


# Trains the model with $item under $label by delegating to _train() with
# the 'train' script. Returns what _train() returns.
sub train {
    my $self = shift;
    my ($label, $item) = @_;

    return $self->_train($label, $item, 'train');
}


# Reverses a previous train() of $item under $label by delegating to
# _train() with the 'untrain' script. Returns what _train() returns.
sub untrain {
    my $self = shift;
    my ($label, $item) = @_;

    return $self->_train($label, $item, 'untrain');
}


# Returns the best label for $item as reported by the 'classify' script.
#
# Tokenizes $item with the configured tokenizer and runs the script with
# ARGV = correction, token count, the tokens, then the value of each token
# in the same order. The script is invoked in scalar context.
# Dies with "tokenizer() didn't return a HASHREF" otherwise.
sub classify {
    my ($self, $item) = @_;

    my $occurrences = $self->{tokenizer}->($item);
    ref $occurrences eq 'HASH'
        or die "tokenizer() didn't return a HASHREF";

    # A hash slice keeps each value aligned with its token.
    my @tokens = keys %$occurrences;
    my @counts = @{$occurrences}{@tokens};

    return scalar $self->_run_script(
        'classify', 0, $self->{correction}, scalar(@tokens), @tokens, @counts,
    );
}


# Returns a hashref mapping each label to its score for $item, built from
# the flat key/value list the 'scores' script replies with.
#
# Tokenizes $item with the configured tokenizer and runs the script with
# ARGV = correction, token count, the tokens, then the value of each token
# in the same order. The script is invoked in list context.
# Dies with "tokenizer() didn't return a HASHREF" otherwise.
sub scores {
    my ($self, $item) = @_;

    my $occurrences = $self->{tokenizer}->($item);
    ref $occurrences eq 'HASH'
        or die "tokenizer() didn't return a HASHREF";

    # A hash slice keeps each value aligned with its token.
    my @tokens = keys %$occurrences;
    my @counts = @{$occurrences}{@tokens};

    my %score_for = $self->_run_script(
        'scores', 0, $self->{correction}, scalar(@tokens), @tokens, @counts,
    );

    return \%score_for;
}

# Returns the members of the namespaced labels set (SMEMBERS on the key
# built from the LABELS constant), in the caller's context.
sub _labels {
    my $self = shift;

    return $self->_exec(smembers => LABELS);
}

# Returns a fresh hashref built from the HGETALL reply for the namespaced
# key of $label, i.e. the stored field/value pairs for that label.
sub _priors {
    my ($self, $label) = @_;

    my %stored = $self->_exec('hgetall', $label);
    my %copy   = %stored;

    return \%copy;
}


1;

__END__

=pod

=encoding UTF-8

=head1 NAME

Redis::NaiveBayes - A generic Redis-backed NaiveBayes implementation

=head1 VERSION

version 0.0.4

=head1 SYNOPSIS

    my $tokenizer = sub {
        my $input = shift;

        my %occurs;
        $occurs{$_}++ for split(/\s/, lc $input);

        return \%occurs;
    };

    my $bayes = Redis::NaiveBayes->new(
        namespace => 'playground:',
        tokenizer => $tokenizer,
    );

=head1 DESCRIPTION

This distribution provides a very simple NaiveBayes classifier
backed by a Redis instance. It uses the evalsha functionality
available since Redis 2.6.0 to try to speed things up while
avoiding some obvious race conditions during the untrain() phase.

The goal of Redis::NaiveBayes is to keep dependencies at
minimum while being as generic as possible to allow any sort
of usage. By design, it doesn't provide any sort of tokenization
nor filtering out of the box.

=head1 METHODS

=head2 new

    my $bayes = Redis::NaiveBayes->new(
        namespace  => 'playground:',
        tokenizer  => \&tokenizer,
        correction => 0.1,
        redis      => $redis_instance,
    );

Instantiates a L<Redis::NaiveBayes> instance using the provided
C<correction>, C<namespace> and C<tokenizer>.

If provided, it also uses a L<Redis> instance (C<redis> parameter)
instead of instantiating one by itself.

A tokenizer is any subroutine that returns a HASHREF of occurrences
in the item provided for train()ing or classify()ing.

=head2 flush

    $bayes->flush;

Cleanup all the possible keys this classifier instance could've
touched. If you want to clean everything under the provided namespace,
call _mrproper() instead, but beware that it will delete all the
keys that match C<namespace*>.

=head2 train

    $bayes->train("ham", "this is a good message");
    $bayes->train("spam", "price from Nigeria needs your help");

Trains the model to associate the given item with a label ("ham"). The item
can be any arbitrary structure as long as the provided C<tokenizer>
understands it.

=head2 untrain

    $bayes->untrain("ham", "I don't think this message is good anymore");

The opposite of train().

=head2 classify

    my $label = $bayes->classify("Nigeria needs help");
    >>> "spam"

Returns the most probable label for the provided item.

=head2 scores

    my $scores = $bayes->scores("any sort of message");

Returns a HASHREF with the scores for each of the labels known by the model

=head1 NOTES

This module is heavily inspired by the Python implementation
available at https://github.com/jart/redisbayes - the main
difference, besides the obvious language choice, is that
Redis::NaiveBayes focuses on being generic and minimizing
the number of roundtrips to Redis.

=head1 TODO

=over

=item Add support for additive smoothing

=back

=head1 SEE ALSO

L<Redis>, L<Redis::Bayes>, L<Algorithm::NaiveBayes>

=head1 AUTHORS

=over 4

=item *

Caio Romão <cpan@caioromao.com>

=item *

Stanislaw Pusep <stas@sysd.org>

=back

=head1 COPYRIGHT AND LICENSE

This software is Copyright (c) 2013 by Caio Romão.

This is free software, licensed under:

  The MIT (X11) License

=cut