The Perl Toolchain Summit needs more sponsors. If your company depends on Perl, please support this very important event.

NAME

AI::MXNet::Module::Bucketing

SYNOPSIS

    my $buckets = [10, 20, 30, 40, 50, 60];
    my $start_label   = 1;
    my $invalid_label = 0;

    my ($train_sentences, $vocabulary) = tokenize_text(
        './data/sherlockholmes.train.txt', start_label => $start_label,
        invalid_label => $invalid_label
    );
    my ($validation_sentences) = tokenize_text(
        './data/sherlockholmes.test.txt', vocab => $vocabulary,
        start_label => $start_label, invalid_label => $invalid_label
    );
    my $data_train  = mx->rnn->BucketSentenceIter(
        $train_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );
    my $data_val    = mx->rnn->BucketSentenceIter(
        $validation_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );

    my $stack = mx->rnn->SequentialRNNCell();
    for my $i (0..$num_layers-1)
    {
        $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
    }

    my $sym_gen = sub {
        my $seq_len = shift;
        my $data  = mx->sym->Variable('data');
        my $label = mx->sym->Variable('softmax_label');
        my $embed = mx->sym->Embedding(
            data => $data, input_dim => scalar(keys %$vocabulary),
            output_dim => $num_embed, name => 'embed'
        );
        $stack->reset;
        my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
        my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
        $pred    = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
        $label   = mx->sym->Reshape($label, shape => [-1]);
        $pred    = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
        return ($pred, ['data'], ['softmax_label']);
    };

    my $contexts;
    if(defined $gpus)
    {
        $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
    }
    else
    {
        $contexts = mx->cpu(0);
    }

    my $model = mx->mod->BucketingModule(
        sym_gen             => $sym_gen,
        default_bucket_key  => $data_train->default_bucket_key,
        context             => $contexts
    );

    $model->fit(
        $data_train,
        eval_data           => $data_val,
        eval_metric         => mx->metric->Perplexity($invalid_label),
        kvstore             => $kv_store,
        optimizer           => $optimizer,
        optimizer_params    => {
                                    learning_rate => $lr,
                                    momentum      => $mom,
                                    wd            => $wd,
                            },
        initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
        num_epoch           => $num_epoch,
        batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
        ($chkp_epoch ? (epoch_end_callback  => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
    );

DESCRIPTION

    Implements the AI::MXNet::Module::Base API and allows multiple
    symbols to be used, selected by the `bucket_key` provided by each
    mini-batch of data.

new

    Parameters
    ----------
    $sym_gen : subref or any perl object that overloads &{} op
        A sub that, when called with a bucket key, returns a list of three
        elements: ($symbol, $data_names, $label_names).
    $default_bucket_key : str or anything else
        The key for the default bucket.
    $logger : Logger
    $context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
        Default is cpu(0)
    $work_load_list : array ref of Num
        Default is undef, indicating uniform workload.
    $fixed_param_names : arrayref of str
        Default is undef, indicating no network parameters are fixed.
    $state_names : arrayref of str
        States are similar to data and labels, but are not provided by the data iterator.
        Instead, they are initialized to 0 and can be set by set_states().

bind

    Binding an AI::MXNet::Module::Bucketing means setting up the buckets and binding the
    executor for the default bucket key. Executors corresponding to other keys are
    bound later via switch_bucket.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
        This should correspond to the symbol for the default bucket.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        This should correspond to the symbol for the default bucket.
    :$for_training : Bool
        Default is 1.
    :$inputs_need_grad : Bool
        Default is 0.
    :$force_rebind : Bool
        Default is 0.
    :$shared_module : AI::MXNet::Module::Bucketing
        Default is undef. This value is currently not used.
    :$grad_req : str, array ref of str, hash ref of str to str
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (str) or for each argument (array ref, hash ref).
    :$bucket_key : str
        The bucket key to bind with. Defaults to ->default_bucket_key.

switch_bucket

    Switch to a different bucket. This will change $self->_curr_module.

    Parameters
    ----------
    :$bucket_key : str (or any perl object that overloads "" op)
        The key of the target bucket.
    :$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_label.

save_checkpoint

    Save current progress to a checkpoint.
    Use mx->callback->module_checkpoint as epoch_end_callback to save during training.

    Parameters
    ----------
    prefix : str
        The file prefix to save the checkpoint under.
    epoch : int
        The current epoch number.
    save_optimizer_states : bool
        Whether to save the optimizer states for later training.