Using Decision Tree, Random Forest, and XGBoost Classifiers with Arduino

You'll be surprised how accurate you can get with just a few kilobytes of resources: Decision Tree, Random Forest, and XGBoost (Extreme Gradient Boosting) are now available on your microcontrollers, as highly RAM-optimized implementations for super-fast classification on embedded devices.

Decision Tree

Decision Tree is undoubtedly one of the best known classification algorithms. It is so easy to understand that it is probably the first classifier you encounter in any Machine Learning tutorial.

We won't go into the details of how a Decision Tree classifier is trained and how it selects the splits for the input features: here we will explain how such a classifier can make efficient use of RAM.

Visit Wikipedia for a simple introduction; for a more detailed guide, visit KDNuggets.

Since RAM is the scarcest resource on the vast majority of microcontrollers, we are willing to sacrifice program space (a.k.a. flash) in its favor: the smart way to port a Decision Tree classifier from Python to C is to hard-code the splits directly in the code, without any reference to variables.

Since no variables are allocated, we need 0 bytes of RAM to compute the classification result. On the other hand, the program space will grow roughly linearly with the number of splits.

Because program space is often much larger than RAM on microcontrollers, this implementation leverages that abundance to host larger models. How large? It will depend on the available flash size: many new-generation boards (Arduino Nano 33 BLE Sense, ESP32, ST Nucleo…) have 1 MB of flash, which can hold tens of thousands of splits.

Random Forest

Random Forest consists of a number of Decision Trees combined in a voting scheme. The basic idea is the "wisdom of the crowd": if many trees, each trained on a different subset of the training set, vote for a given class, that class is probably the true one.

Towards Data Science has a more detailed guide on Random Forest and how it balances the trees with the bagging technique.

Porting Random Forest is as easy as porting a Decision Tree, and the implementation still needs essentially 0 bytes of RAM (strictly speaking, it needs as many bytes as there are classes to store the votes, but that is really negligible): it simply hard-codes all of its constituent trees.
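To make the voting scheme concrete, here is a minimal Python sketch. Note that scikit-learn actually averages the trees' probability estimates rather than counting hard votes, but the hard majority vote shown below is what the ported C code later in this article does:

from collections import Counter

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=10).fit(X, y)

sample = X[100].reshape(1, -1)
# each tree in the forest casts one vote for a class (for Iris the labels are 0, 1, 2)
votes = [int(tree.predict(sample)[0]) for tree in forest.estimators_]

print("votes per class:  ", Counter(votes))
print("forest prediction:", forest.predict(sample)[0])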

XGBoost (Extreme Gradient Boosting)

Extreme Gradient Boosting is "gradient boosting on steroids" and has received a lot of attention from the Machine Learning community thanks to its top results in many data competitions.

  1. "Gradient enhancement" refers to the process of chaining a series of trees so that each tree tries to learn from the mistakes of the previous one.
  2. "Extreme(X)" refers to many software and hardware optimizations that greatly reduce the time required to train the model.

You can read the original article about XGBoost here.
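To make the first point concrete, here is a toy Python sketch of the boosting loop for squared-error regression, where the "mistakes" are simply the residuals of the ensemble built so far. It only illustrates the idea; it is not how the xgboost library is actually implemented:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# toy data: learn y = sin(x) with a chain of small trees
X = np.linspace(0, 6, 200).reshape(-1, 1)
y = np.sin(X).ravel()

prediction = np.zeros_like(y)   # start from a constant (zero) model
learning_rate = 0.3

for _ in range(10):
    residuals = y - prediction                        # what the ensemble still gets wrong
    tree = DecisionTreeRegressor(max_depth=2).fit(X, residuals)
    prediction += learning_rate * tree.predict(X)     # each new tree corrects its predecessors

print("mean absolute error:", np.abs(y - prediction).mean())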

Porting to Plain C

If you have followed our previous articles on Gaussian Naive Bayes, SEFR, Relevance Vector Machines, and Support Vector Machines, you already know how to port these new classifiers.

If you're new, you'll need a few things:

  1. Install the micromlgen package with:
pip install micromlgen

  2. If you want to use Extreme Gradient Boosting, also install the xgboost package with:

pip install xgboost

We can generate the plain C code using the micromlgen.port function:

from micromlgen import port
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

clf = DecisionTreeClassifier()
X, y = load_iris(return_X_y=True)
clf.fit(X, y)
print(port(clf))

You can then copy the generated C code and paste it into your project.
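The same port function also handles the other two classifiers; here is a minimal sketch, assuming your micromlgen version supports them (the Iris dataset and 10 estimators are just placeholders, and the last part requires the xgboost package installed above):

from micromlgen import port
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

X, y = load_iris(return_X_y=True)

forest = RandomForestClassifier(n_estimators=10).fit(X, y)
print(port(forest))    # plain C implementation of the Random Forest

booster = XGBClassifier(n_estimators=10).fit(X, y)
print(port(booster))   # plain C implementation of the XGBoost model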

Using It in an Arduino Sketch

Once you have the classifier code, create a new project named TreeClassifierExample, for example, and copy the classifier code into a file named DecisionTree.h (or RandomForest.h or XGBoost.h, depending on the model you chose).
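If you prefer, you can write that header file straight from Python instead of copy-pasting by hand; a minimal sketch (the path is just an example, point it at your own project folder):

from micromlgen import port
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier().fit(X, y)

# write the generated C code into the Arduino project folder
with open('TreeClassifierExample/DecisionTree.h', 'w') as f:
    f.write(port(clf))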

Copy the following to the main .ino file.

#include "DecisionTree.h"

Eloquent::ML::Port::DecisionTree clf;

void setup() {
    Serial.begin(115200);
    Serial.println("Begin");
}

void loop() {
    float irisSample[4] = {6.2, 2.8, 4.8, 1.8};

Serial.print("Predicted: ('2' you should see: ");
    Serial.println(clf.predict(irisSample));
    delay(1000);
}

Benchmarks

To compare the three classifiers, we will consider these key points:

  • Training time
  • Accuracy
  • Required RAM
  • Required flash size

for each classifier on a few different datasets. The RAM and flash figures refer to the old-generation Arduino Nano, so read them as relative rather than absolute numbers.

Dataset                                    | Classifier    | Training time (s) | Accuracy      | RAM (bytes) | Flash (bytes)
Gas Sensor Dataset                         | Decision Tree | 1.6               | 0.781 ± 0.12  | 290         | 5722
(13910 samples x 128 features, 6 classes)  | Random Forest | 3                 | 0.865 ± 0.083 | 290         | 6438
                                           | XGBoost       | 18.8              | 0.878 ± 0.074 | 290         | 6506
Transaction Segmentation Dataset           | Decision Tree | 0.1               | 0.943 ± 0.005 | 290         | 5638
(10000 samples x 19 features, 5 classes)   | Random Forest | 0.7               | 0.970 ± 0.004 | 306         | 6466
                                           | XGBoost       | 18.9              | 0.699 ± 0.003 | 306         | 6536
Driver Diagnostics Dataset                 | Decision Tree | 0.6               | 0.946 ± 0.005 | 306         | 5850
(10000 samples x 48 features, 11 classes)  | Random Forest | 2.6               | 0.983 ± 0.003 | 306         | 6526
                                           | XGBoost       | 68.9              | 0.977 ± 0.005 | 306         | 6698

All datasets are retrieved from the UCI Machine Learning datasets archive.

We could collect more data for a thorough comparison, but you can already see that Random Forest and XGBoost perform on par with each other, except that XGBoost takes 5 to 25 times longer to train.

Troubleshooting

You may receive a TemplateNotFound error when using micromlgen; in that case you can work around the problem by uninstalling and reinstalling the library:

pip uninstall micromlgen

Then go to GitHub, download the package as a zip, and copy the micromlgen folder into your project.

Generated Code
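
Decision Tree classifier (as generated for the Iris example above):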

int predict(float *x) {
  if (x[3] <=  0.800000011920929) {
      return 0;
  }
  else {
      if (x[3] <=  1.75) {
          if (x[2] <=  4.950000047683716) {
              if (x[0] <=  5.049999952316284) {
                  return 1;
              }
              else {
                  return 1;
              }
          }
          else {
              return 2;
          }
      }
      else {
          if (x[2] <=  4.950000047683716) {
              return 2;
          }
          else {
              return 2;
          }
      }
  }
}
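
Random Forest classifier (a three-tree forest on the same Iris dataset):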
int predict(float *x) {
  uint16_t votes[3] = { 0 };

  // tree #1
  if (x[0] <=  5.450000047683716) {
      if (x[1] <=  2.950000047683716) {
          votes[1] += 1;
      }
      else {
          votes[0] += 1;
      }
  }
  else {
      if (x[0] <=  6.049999952316284) {
          if (x[3] <=  1.6999999880790707 ){
              if (x[2] <=  3.549999952316284) {
                  votes[0] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[2] += 1;
          }
      }
      else {
          if (x[3] <=  1.6999999880790707 ){
              if (x[3] <=  1.4499999880790707 ){
                  if (x[0] <=  6.1499998569488525) {
                      votes[1] += 1;
                  }
                  else {
                      votes[1] += 1;
                  }
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[2] += 1;
          }
      }
  }

  // tree #2
  if (x[0] <=  5.549999952316284) {
      if (x[2] <=  2.44999998807907907 ){
          votes[0] += 1;
      }
      else {
          if (x[2] <=  3.950000047683716) {
              votes[1] += 1;
          }
          else {
              votes[1] += 1;
          }
      }
  }
  else {
      if (x[3] <=  1.6999999880790707 ){
          if (x[1] <=  2.649999976158142) {
              if (x[3] <=  1.25) {
                  votes[1] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              if (x[2] <=  4.1499998569488525) {
                  votes[1] += 1;
              }
              else {
                  if (x[0] <=  6.75) {
                      votes[1] += 1;
                  }
                  else {
                      votes[1] += 1;
                  }
              }
          }
      }
      else {
          if (x[0] <=  6.0) {
              votes[2] += 1;
          }
          else {
              votes[2] += 1;
          }
      }
  }

  // tree #3
  if (x[3] <=  1.75) {
      if (x[2] <=  2.44999998807907907 ){
          votes[0] += 1;
      }
      else {
          if (x[2] <=  4.8500001430511475) {
              if (x[0] <=  5.299999952316284) {
                  votes[1] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[1] += 1;
          }
      }
  }
  else {
      if (x[0] <=  5.950000047683716) {
          votes[2] += 1;
      }
      else {
          votes[2] += 1;
      }
  }

  // return argmax of votes
  uint8_t classIdx = 0;
  float maxVotes = votes[0];

  for (uint8_t i = 1; i <  3; i++) {
      if (votes[i] > maxVotes) {
          classIdx = i;
          maxVotes = votes[i];
      }
  }

  return classIdx;
}