Decision Tree, Random Forest ve XGBoost Sınıflarını Arduino ile Kullanmak

Yalnızca birkaç kilobaytlık kaynakla ne kadar doğruluk elde edebileceğinize şaşıracaksınız: Karar Ağacı(Decision Tree), Random Forest ve XGBoost (Aşırı Gradyan Artırma) mikrodenetleyicilerinizde artık kullanılabilir: gömülü cihazlarda süper hızlı sınıflandırma için son derece RAM optimize edilmiş uygulamalar geliştirebilirsiniz.

Decision Tree, Random Forest ve XGBoost Sınıflarını Arduino ile Kullanmak decision tree,random forest,xgboost
Karar Ağacı

Karar Ağacı

Karar Ağacı, şüphesiz en iyi bilinen sınıflandırma algoritmalarından biridir. Muhtemelen herhangi bir Makine Öğrenimi eğiticisinde karşılaştığınız ilk sınıflandırıcı olduğunu anlamak çok kolaydır.

Bir Karar Ağacı sınıflandırıcısının girdi özellikleri için bölmeleri nasıl eğittiğini ve seçtiğinin ayrıntılarını anlatmayacağız: burada böyle bir sınıflandırıcının RAM’i nasıl verimli bir şekilde kullandığını açıklayacağız.

Basit bir giriş için Wikipedia’yı ziyaret edin; daha ayrıntılı bir kılavuz için KDNuggets adresini ziyaret edin.

Program alanını (diğer adıyla flash) bellek (diğer adıyla RAM) lehine feda etmeye istekli olduğumuzdan ve RAM mikro denetleyicilerin büyük çoğunluğunda en kıt kaynak olduğundan, Karar Ağacı sınıflandırıcısını Python’dan C’ye taşımanın akıllı yolu değişkenlere herhangi bir referans vermeden koddaki bölmeleri sabit kodlama yaparak saklamaktır.

Hiçbir değişken tahsis edilmediğinden sınıflandırma sonucunu almak için 0 bayt RAM kullanıyoruz. Öte yandan, program alanı bölme sayısı ile neredeyse doğrusal olarak büyüyecektir.

Program alanı genellikle mikro denetleyicilerde RAM’den çok daha büyük olduğundan, bu uygulama daha büyük modelleri çalıştırabilmek için bu bolluktan yararlanır. Ne kadar büyük? Mevcut flaş boyutuna bağlı olacaktır: birçok yeni nesil kartta (Arduino Nano 33 BLE Sense, ESP32, ST Nucleus…) on binlerce bölmeyi tutacak 1 Mb flaş bulunur.

Random Forest

Random Forest, bir oylama düzeninde bir araya getirilmiş çok sayıda Karar Ağacından ibarettir. Temel fikir, “korumanın bilgeliği”dir, öyle ki, belirli bir sınıf için birçok ağaç oy verirse (eğitim setinin farklı alt kümeleri üzerinde eğitilmişse), bu sınıf muhtemelen gerçek sınıftır.

Towards Data Science, Random Forest ve torbalama tekniği ile nasıl dengelendiği hakkında daha detaylı bir rehbere sahiptir.

Karar Ağaçları kadar kolay olan Random Forest, gereken 0 bayt RAM ile aynı uygulamayı alır (aslında oyları depolamak için sınıf sayısı kadar bayt gerekir, ancak bu gerçekten önemsizdir): sadece tüm oluşturan ağaçları kodlar.

XGBoost (Aşırı Gradyan Artırma)

Extreme Gradient Boosting, “Steroidler üzerinde Gradyan Artırma” dır ve birçok veri yarışmasında en iyi sonuçları nedeniyle Makine öğrenimi topluluğundan büyük ilgi görmüştür.

  1. “Degrade artırma”, her ağacın bir öncekinin hatalarından öğrenmeye çalışması için bir dizi ağacı zincirleme sürecini ifade eder.
  2. “Aşırı(X)”, modeli eğitmek için gereken süreyi büyük ölçüde azaltan birçok yazılım ve donanım optimizasyonunu ifade eder.

XGBoost hakkındaki orijinal makaleyi buradan okuyabilirsiniz.

Düz C’ye Taşıma

Gaussian Naive BayesSEFRVektör Makinesi ve Destek Vektör Makineleri hakkındaki önceki yazılarımızı takip ettiyseniz, bu yeni sınıflandırıcıları nasıl taşıyacağınızı zaten biliyorsunuzdur.

Yeniyseniz, birkaç şeye ihtiyacınız olacak:

  1. micromlgen paketini şununla kurun:
pip install micromlgen

İsteğe bağlı olarak Extreme Gradient Boosting’i kullanmak istiyorsanız xgboost paketini şununla kurun:

pip install xgboost

micromlgen.portfonksiyonunu kullanarak düz C kodunu oluşturabiliriz:

from micromlgen import port
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

clf = DecisionTreeClassifier()
X, y = load_iris(return_X_y=True)
clf.fit(X, y)
print(port(clf))

Daha sonra C kodunu kopyalayıp yapıştırabilir ve projenize aktarabilirsiniz.

Arduino Taslağında Kullanmak

Sınıflandırıcı kodunu aldıktan sonra, örneğin TreeClassifierExample adında yeni bir proje oluşturun ve sınıflandırıcı kodunu DecisionTree.h adlı bir dosyaya kopyalayın (veya RandomForest.h veya XGBoost.h seçtiğiniz modele bağlı olarak).

Aşağıdakileri ana .ino dosyasına kopyalayın.

#include "DecisionTree.h"

Eloquent::ML::Port::DecisionTree clf;

void setup() {
    Serial.begin(115200);
    Serial.println("Begin");
}

void loop() {
    float irisSample[4] = {6.2, 2.8, 4.8, 1.8};

    Serial.print("Tahmin edilen: ('2' görmelisiniz: ");
    Serial.println(clf.predict(irisSample));
    delay(1000);
}

Benchmark (Sınamalar)

Üç sınıflandırıcıyı karşılaştırmak için, bu kilit noktaları göz önüne alacağız:

  • Öğrenme Vakti
  • Kesinlik
  • gerekli RAM
  • gerekli Flaş boyutu

Çeşitli veri kümelerindeki her sınıflandırıcı için. Eski nesil Arduino Nano’da RAM ve Flash için sonuçları rapor edeceğiz, bu yüzden mutlak rakamlardan daha göreceli rakamları düşünmelisiniz.

Veri kümesiSınıflandırıcıEğitim
süresi
KesinlikRAM
(bayt)
Flaş
(bayt)
Gaz Sensörü Veri KümesiKarar Ağacı1,60.781 ± 0.122905722
13910 örnek x 128 özellikRandom Forest30.865 ± 0.0832906438
6 sınıfXGBoost18,80.878 ± 0.0742906506
Hareket Segmentasyonu Veri KümesiKarar Ağacı0,10,943 ± 0,0052905638
10000 örnek x 19 özellikRandom Forest0,70,970 ± 0,0043066466
5 sınıfXGBoost18,90,699 ± 0,0033066536
Sürücü Tanılama Veri KümesiKarar Ağacı0,60,946 ± 0,0053065850
10000 örnek x 48 özellikRandom Forest2,60,983 ± 0,0033066526
11 sınıfXGBoost68,90,977 ± 0,0053066698

Tüm veri kümeleri UCI Machine Learning veri kümeleri arşivinden alınmıştır.

Tam bir kıyaslama için daha fazla veri toplayabiliriz, ancak bu arada hem Rastgele Ormanın hem de XGBoost’un eşit olduğunu görebilirsiniz: eğer değilse, XGBoost’un eğitilmesi 5 ila 25 kat daha uzun sürüyor.

Sorun Giderme

micromlgen kullanırken TemplateNotFound hatası alabilirsiniz, böyle bir durumda kütüphaneyi kaldırıp tekrar kurarak sorunu gidebilirsiniz:

pip uninstall micromlgen

Ardından Github’a gidin, paketi zip olarak indirin ve micromlgenklasörünü projenize çıkarın.

Program Kodları

int predict(float *x) {
  if (x[3] <= 0.800000011920929) {
      return 0;
  }
  else {
      if (x[3] <= 1.75) {
          if (x[2] <= 4.950000047683716) {
              if (x[0] <= 5.049999952316284) {
                  return 1;
              }
              else {
                  return 1;
              }
          }
          else {
              return 2;
          }
      }
      else {
          if (x[2] <= 4.950000047683716) {
              return 2;
          }
          else {
              return 2;
          }
      }
  }
}
int predict(float *x) {
  uint16_t votes[3] = { 0 };

  // tree #1
  if (x[0] <= 5.450000047683716) {
      if (x[1] <= 2.950000047683716) {
          votes[1] += 1;
      }
      else {
          votes[0] += 1;
      }
  }
  else {
      if (x[0] <= 6.049999952316284) {
          if (x[3] <= 1.699999988079071) {
              if (x[2] <= 3.549999952316284) {
                  votes[0] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[2] += 1;
          }
      }
      else {
          if (x[3] <= 1.699999988079071) {
              if (x[3] <= 1.449999988079071) {
                  if (x[0] <= 6.1499998569488525) {
                      votes[1] += 1;
                  }
                  else {
                      votes[1] += 1;
                  }
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[2] += 1;
          }
      }
  }

  // tree #2
  if (x[0] <= 5.549999952316284) {
      if (x[2] <= 2.449999988079071) {
          votes[0] += 1;
      }
      else {
          if (x[2] <= 3.950000047683716) {
              votes[1] += 1;
          }
          else {
              votes[1] += 1;
          }
      }
  }
  else {
      if (x[3] <= 1.699999988079071) {
          if (x[1] <= 2.649999976158142) {
              if (x[3] <= 1.25) {
                  votes[1] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              if (x[2] <= 4.1499998569488525) {
                  votes[1] += 1;
              }
              else {
                  if (x[0] <= 6.75) {
                      votes[1] += 1;
                  }
                  else {
                      votes[1] += 1;
                  }
              }
          }
      }
      else {
          if (x[0] <= 6.0) {
              votes[2] += 1;
          }
          else {
              votes[2] += 1;
          }
      }
  }

  // tree #3
  if (x[3] <= 1.75) {
      if (x[2] <= 2.449999988079071) {
          votes[0] += 1;
      }
      else {
          if (x[2] <= 4.8500001430511475) {
              if (x[0] <= 5.299999952316284) {
                  votes[1] += 1;
              }
              else {
                  votes[1] += 1;
              }
          }
          else {
              votes[1] += 1;
          }
      }
  }
  else {
      if (x[0] <= 5.950000047683716) {
          votes[2] += 1;
      }
      else {
          votes[2] += 1;
      }
  }

  // return argmax of votes
  uint8_t classIdx = 0;
  float maxVotes = votes[0];

  for (uint8_t i = 1; i < 3; i++) {
      if (votes[i] > maxVotes) {
          classIdx = i;
          maxVotes = votes[i];
      }
  }

  return classIdx;
}