Subject: | set_callback missing |
I noticed that this Perl module does not wrap the set_callback function provided by the library.
I've implemented it in the following branch in this GitHub repository:
https://github.com/Flimm/p5-AI-FANN/commits/set_callback
I've also attached a patch, if you don't like GitHub.
Subject: | set_callback.patch |
From 1d1d68cd9f806431a61d389f22b8fc4d388b3579 Mon Sep 17 00:00:00 2001
From: David D Lowe <flimm@cpan.org>
Date: Fri, 27 Feb 2015 15:48:08 +0000
Subject: [PATCH] Add new set_callback method
---
FANN.xs | 69 ++++++++++++++++++++++++++++++++++++++++++++
lib/AI/FANN.pm | 37 ++++++++++++++++++++++++
t/callback.t | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 196 insertions(+)
create mode 100644 t/callback.t
diff --git FANN.xs FANN.xs
index fa495209d..72350aff3 100644
--- FANN.xs
+++ FANN.xs
@@ -73,6 +73,49 @@ _fta2sv(pTHX_ fann_type *fta, unsigned int len) {
return newRV_noinc((SV*)av);
}
+/* FANN training callback: runs the Perl sub registered for ann in
+ * %AI::FANN::_callback_for_ann; returns its integer result (0 for undef). */
+int FANN_API
+call_perl_callback(struct fann *ann, struct fann_train_data *train,
+                   unsigned int max_epochs, unsigned int epochs_between_reports,
+                   float desired_error, unsigned int epochs)
+{
+    dTHX;
+    HV *callback_hash = get_hv("AI::FANN::_callback_for_ann", 0);
+    if (! callback_hash) {
+        Perl_croak(aTHX_ "Could not get callback hash");
+    }
+    char buffer[256];
+    /* the size_t cast also turns a negative (error) result into a failure */
+    if ((size_t)snprintf(buffer, sizeof(buffer), "%" IVdf, PTR2IV(ann)) >= sizeof(buffer)) {
+        Perl_croak(aTHX_ "Could not store key in buffer");
+    }
+    SV **callback = hv_fetch(callback_hash, buffer, strlen(buffer), 0);
+    if (! callback) {
+        Perl_croak(aTHX_ "Could not get callback");
+    }
+    dSP;
+    ENTER;
+    SAVETMPS;
+    PUSHMARK(SP);
+    XPUSHs(&PL_sv_undef); /* placeholder for ann */
+    XPUSHs(&PL_sv_undef); /* placeholder for train */
+    XPUSHs(sv_2mortal(newSVuv(max_epochs)));
+    XPUSHs(sv_2mortal(newSVuv(epochs_between_reports)));
+    XPUSHs(sv_2mortal(newSVnv(desired_error)));
+    XPUSHs(sv_2mortal(newSVuv(epochs)));
+    PUTBACK;
+    call_sv(*callback, G_SCALAR); /* G_SCALAR always leaves exactly one value */
+    SPAGAIN;
+    SV *return_sv = POPs;
+    int return_value = SvOK(return_sv) ? SvIV(return_sv) : 0;
+    PUTBACK;
+    FREETMPS;
+    LEAVE;
+    return return_value;
+}
+
+
static AV*
_srv2av(pTHX_ SV* sv, unsigned int len, char * const name) {
if (SvROK(sv)) {
@@ -211,6 +254,20 @@ fann_new_standard(klass, ...)
CLEANUP:
_check_error(aTHX_ (struct fann_error *)RETVAL);
+void
+fann__enable_perl_callback(self)
+ struct fann * self;
+ CODE:
+ fann_set_callback(self, call_perl_callback); /* install the C trampoline that forwards to the Perl-level callback */
+
+IV
+fann__get_struct_addr(self)
+    struct fann * self;
+  CODE:
+    RETVAL = PTR2IV(self); /* returned as IV: an int would truncate 64-bit addresses */
+  OUTPUT:
+    RETVAL
+
struct fann *
fann_new_sparse(klass, connection_rate, ...)
SV *klass;
@@ -267,6 +324,18 @@ void
fann_DESTROY(self)
struct fann * self;
CODE:
+ /* Drop any Perl callback registered for this ann, keyed by the decimal
+  * struct address. Nothing here may croak: dying in DESTROY would skip the
+  * fann_destroy() call that follows and leak the network. */
+ HV *callback_hash = get_hv("AI::FANN::_callback_for_ann", 0);
+ if (callback_hash) {
+     char buffer[256];
+     int key_len = snprintf(buffer, sizeof(buffer), "%" IVdf, PTR2IV(self));
+     if (key_len > 0 && (size_t)key_len < sizeof(buffer)) {
+         /* hv_delete() is a no-op when the key is absent, so no hv_exists() */
+         hv_delete(callback_hash, buffer, key_len, G_DISCARD);
+     }
+ }
fann_destroy(self);
sv_unmagic(SvRV(ST(0)), '~');
diff --git lib/AI/FANN.pm lib/AI/FANN.pm
index 387c5bf8a..db7b9b01d 100644
--- lib/AI/FANN.pm
+++ lib/AI/FANN.pm
@@ -10,6 +10,12 @@ require XSLoader;
XSLoader::load('AI::FANN', $VERSION);
use Exporter qw(import);
+use Scalar::Util qw(refaddr);
+
+# %_callback_for_ann works like inside-out object storage: it holds each
+# network's Perl callback, which cannot be stored in the C struct. It is a
+# package variable (the underscore marks it private) so that tests can inspect it.
+our %_callback_for_ann;
{
my @constants = _constants();
@@ -36,6 +42,18 @@ sub num_neurons {
}
}
+sub set_callback {
+    # Register a code ref as the training callback for this network.
+    croak "Usage: set_callback(self, callback)" unless @_ == 2;
+    my ($self, $callback) = @_;
+    ref($self) or croak "self is not a reference";
+    croak "callback is not a code reference" unless defined($callback) && ref($callback) eq "CODE";
+    # The XS callback looks the sub up by the address of the C struct.
+    $_callback_for_ann{ $self->_get_struct_addr } = $callback;
+    $self->_enable_perl_callback;
+    return;
+}
+
1;
__END__
@@ -532,6 +550,25 @@ return the number of neurons on layer C<$layer_index>.
return a list with the number of neurons on every layer
+=item $ann->set_callback($callback)
+
+Sets the callback for use during training. If this is not set, the default
+callback function simply prints out some status information. $callback may not
+be undefined. Here's an example of a callback:
+
+ $callback = sub {
+ my ($unused1, $unused2, $max_epochs, $epochs_between_reports,
+ $desired_error, $epochs) = @_;
+ printf("Epochs: %d\n", $epochs);
+ return 0;
+ };
+
+The callback is called in scalar context. It should return an integer or undef;
+if it returns -1, training terminates.
+
+Note that the first two arguments are currently unused and always undef; future
+versions may pass the AI::FANN and AI::FANN::TrainData objects in their place.
+
=back
=head2 AI::FANN::TrainData
diff --git t/callback.t t/callback.t
new file mode 100644
index 000000000..be44f0029
--- /dev/null
+++ t/callback.t
@@ -0,0 +1,90 @@
+use strict;
+use warnings;
+
+use Test::More;
+
+use AI::FANN qw(:all);
+
+my @data = ([-1, -1], [-1],
+ [-1, 1], [1],
+ [1, -1], [1],
+ [1, 1], [-1]);
+
+is(scalar(keys %AI::FANN::_callback_for_ann), 0, "Zero keys in %callback to start with");
+
+{
+ my $ann = AI::FANN->new_standard(2, 3, 1);
+
+ $ann->hidden_activation_function(FANN_SIGMOID_SYMMETRIC);
+ $ann->output_activation_function(FANN_SIGMOID_SYMMETRIC);
+
+ my $xor_train = AI::FANN::TrainData->new(@data);
+
+ cmp_ok($ann->_get_struct_addr(), '>=', 0, "_get_struct_addr returns positive number");
+ is($ann->_get_struct_addr(), $ann->_get_struct_addr(), "Consecutive calls of _get_struct_addr consistent");
+
+ my $num_called = 0;
+ my $last_epoch = undef;
+
+ # Training callback under test: counts its calls and checks all six arguments.
+ my $rc_callback = sub {
+     $num_called++;
+     my ($c_ann, $train_data, $max_epochs, $epoch_between_reports, $desired_error, $epochs) = @_;
+     is(scalar(@_), 6, "Callback got 6 arguments");
+     $desired_error = sprintf("%.3f", $desired_error);
+     is($c_ann,                 undef,  "Callback received ann argument as expected");
+     is($train_data,            undef,  "Callback received train_data argument as expected");
+     is($max_epochs,            500000, "Callback received max_epochs argument as expected");
+     is($epoch_between_reports, 1000,   "Callback received epoch_between_reports as expected");
+     is($desired_error,         0.001,  "Callback received desired_error as expected");
+     if (defined $last_epoch) {
+         cmp_ok($epochs, '>', $last_epoch, "Callback received epochs greater than last recorded");
+     }
+     else {
+         cmp_ok($epochs, '>', 0, "Callback received epochs argument greater than 0");
+     }
+     # Plain assignment (not //=) so each call is compared with the previous one.
+     $last_epoch = $epochs;
+     return;
+ };
+
+ $ann->set_callback($rc_callback);
+
+ is($num_called, 0, "Callback still hasn't been called");
+
+ is_deeply([values %AI::FANN::_callback_for_ann], [$rc_callback], "Callback registered");
+
+ $ann->train_on_data($xor_train, 500000, 1000, 0.001);
+
+ cmp_ok($num_called, '>=', 1, "Callback called at least once");
+}
+
+{
+    # A callback that returns -1 must stop training after its first report.
+    my $ann = AI::FANN->new_standard(2, 3, 1);
+
+    $ann->hidden_activation_function(FANN_SIGMOID_SYMMETRIC);
+    $ann->output_activation_function(FANN_SIGMOID_SYMMETRIC);
+
+    my $xor_train = AI::FANN::TrainData->new(@data);
+
+    my $num_called = 0;
+
+    my $rc_callback = sub {
+        $num_called++;
+        return -1;
+    };
+
+    $ann->set_callback($rc_callback);
+
+    is($num_called, 0, "Callback still hasn't been called");
+
+    $ann->train_on_data($xor_train, 500000, 1000, 0.001);
+
+    is($num_called, 1, "Callback called exactly once");
+}
+
+is(scalar(keys %AI::FANN::_callback_for_ann), 0, "Zero keys in %callback at the end");
+
+
+done_testing;
--
1.9.1