summaryrefslogtreecommitdiff
path: root/src/support/learning.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/support/learning.h')
-rw-r--r--src/support/learning.h113
1 files changed, 113 insertions, 0 deletions
diff --git a/src/support/learning.h b/src/support/learning.h
new file mode 100644
index 000000000..2c251d87b
--- /dev/null
+++ b/src/support/learning.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_learning_h
+#define wasm_learning_h
+
#include <algorithm>
#include <memory>
#include <random>
#include <vector>
+
+namespace wasm {
+
+//
+// Machine learning using a genetic algorithm.
+//
+// The Genome class is the element on which to learn. It must
+// implement the following:
+//
+// * Fitness getFitness(); - calculate how good this item is.
+//
+// The Generator must implement the following:
+//
+// * Genome* makeRandom(); - make a random element
+// * Genome* makeMixture(Genome* one, Genome* two); - make a new element by mixing two
+//
+// Fitness is the type of the fitness values, e.g. uint32_t. More is better.
+//
+// Typical usage of this class is to call runGeneration(), check the best
+// quality using getBest()->getFitness(), and do that repeatedly until the
+// fitness is good enough. Then acquireBest() to get ownership of the best,
+// and the learner can be discarded (with all the rest of the population
+// cleaned up).
+
template<typename Genome, typename Fitness, typename Generator>
class GeneticLearner {
  // Factory for new/mixed genomes; must outlive this learner.
  Generator& generator;

  using unique_ptr = std::unique_ptr<Genome>;

  // The population, kept sorted best-first (see sort()).
  std::vector<unique_ptr> population;

  // Sort so that the fittest genome is at index 0.
  void sort() {
    std::sort(population.begin(), population.end(),
              [](const unique_ptr& left, const unique_ptr& right) {
      return left->getFitness() > right->getFitness();
    });
  }

  // Fixed-seed PRNG so that learning runs are reproducible.
  std::mt19937 noise;

  // Pick a random index, biased towards low (fitter) indexes by
  // taking the minimum of two uniform draws. TODO tweak
  size_t randomIndex() {
    return std::min(noise() % population.size(), noise() % population.size());
  }

public:
  // Create a population of |size| random genomes; |size| must be > 0.
  GeneticLearner(Generator& generator, size_t size) : generator(generator), noise(1337) {
    population.resize(size);
    for (size_t i = 0; i < size; i++) {
      population[i] = unique_ptr(generator.makeRandom());
    }
    sort();
  }

  // The current best genome. Ownership remains with the learner.
  Genome* getBest() {
    return population[0].get();
  }

  // Take ownership of the best genome. The learner should be discarded
  // afterwards, as its best slot is left null.
  // Note: this must move, not copy — unique_ptr is move-only, so the
  // previous `return population[0];` did not compile.
  unique_ptr acquireBest() {
    return std::move(population[0]);
  }

  // Run a single generation: the fittest 25% survive unchanged, the
  // next 50% are replaced by mixtures of (fitness-biased) random pairs,
  // and the remainder is filled with fresh random genomes.
  void runGeneration() {
    size_t size = population.size();

    // we have a mix of promoted from the last generation, mixed from the last generation, and random
    const size_t promoted = (25 * size) / 100;
    const size_t mixed = (50 * size) / 100;

    // promoted genomes just stay in place.
    // mixtures are computed first into a side buffer, then swapped in,
    // since the parents are still needed while we generate them.
    std::vector<unique_ptr> mixtures;
    mixtures.resize(mixed);
    for (size_t i = 0; i < mixed; i++) {
      mixtures[i] = unique_ptr(generator.makeMixture(population[randomIndex()].get(), population[randomIndex()].get()));
    }
    for (size_t i = 0; i < mixed; i++) {
      population[promoted + i].swap(mixtures[i]);
    }
    // TODO: de-duplicate at this point
    // random genomes fill in the rest
    for (size_t i = promoted + mixed; i < size; i++) {
      population[i] = unique_ptr(generator.makeRandom());
    }

    sort();
  }
};
+
+} // namespace wasm
+
+#endif // wasm_learning_h
+