1. Einleitung

Machine Learning (ML) ist in den letzten Jahren Bestandteil von zahlreichen neuen Apps. Deshalb stellt sich die Frage wie KI in (Java)-Anwendungen integriert werden können. Kann ML in Java funktionieren? Ist Python überlegen? Fakt ist das Training von sogenannten Modellen wird häufig mit Python durchgeführt, aber auch Java bietet Möglichkeiten, ML-Modelle zu erstellen und in Anwendungen zu integrieren.

1.1. KI und Java – geht das?

Obwohl Python in der Machine Learning-Community dominiert, gibt es zahlreiche Libraries, welche nicht nur gut instandgehalten (maintained) werden, sondern auch vergleichbare Funktionen wie Python-Libraries zur verfügung stellen.

Da viele Backends in Java bereits vorhanden sind, würde eine Integration so erleichtert werden.

Was ist möglich mit Java?

  • Training von ML-Modellen direkt in Java.

  • Integration von bestehenden Modellen (z.B.: aus Python) in Java-Anwendungen.

1.2. Warum (kein) Java?

Hier eine Gegenüberstellung der Vor- und Nachteile, wenn man sich für oder gegen Java entscheidet.

✅ Pro

❌ Kontra

Große Anzahl stabiler¹ Bibliotheken (z. B. Deeplearning4j, Weka, Tribuo)

Weniger Ressourcen & Community-Beiträge im Vergleich zu Python

Einfach in bestehende Java-Backends integrierbar

Weniger Beispiele & Tutorials verfügbar

Gute Performance und Skalierbarkeit

Entwicklung ist teils aufwendiger als in Python

ChatGPT ist keine Hilfe

¹ Libraries, welche aktuell gehalten werden und nicht nur von einer einzelnen Person entwickelt werden.

2. Der Iris-Datensatz

Der Iris-Datensatz ist ein klassisches Beispiel für das Training von ML-Modellen. Er hat Daten über mehrere Blumenmerkmale, welche es ermöglicht die Blumenart zu bestimmen.

2.1. Species

Hier sehen wir eine Übersicht der verschiedenen Iris-Arten, die im Datensatz enthalten sind.

Iris Species

2.2. Sepal und Petal

Ein wichtiges Merkmal der Iris-Blume ist die Größe des Sepals und Petals. Diese Attribute sind entscheidend für die Klassifizierung der Blume.

Sepal und Petal

2.3. CSV

Der Iris-Datensatz ist in unserem Projekt als CSV gespeichert. Hier ein paar Beispieldatensätze:

CSV Sample

3. Decision Tree

Ein Decision Tree ist ein weit verbreitetes Modell für Klassifizierungsprobleme. In den folgenden Abschnitten erklären wir die Funktionsweise von "Entscheidungsbäumen".

3.1. Baumstruktur

Ein Decision Tree ist eine if-Bedingung welche weitere enthält und so eine Aussage trifft.

Decision Tree Struktur

3.2. Splits

Ein Split teilt den Datensatz in zwei Teile basierend auf einer Bedingung. Die Qualität des Splits ist entscheidend für die Genauigkeit des Modells.

Decision Tree Split

3.3. Gute und schlechte Splits

Nicht alle Splits sind gleich gut. Ein schlechter Split führt zu einer niedrigen Modellgenauigkeit. Hier sehen wir den Unterschied zwischen guten und schlechten Splits.

Guter vs. schlechter Split

3.4. Umsetzung mit der Smile-Library für Java

Die Smile-Library ist eine leistungsstarke Java-Bibliothek für maschinelles Lernen, die auch Decision Trees unterstützt.

@ApplicationScoped
public class DecisionTreeRepository {
    private DecisionTree model;
    private Map<String, Integer> speciesMap = new HashMap<>();

    public void setSpeciesMap(Map<String, Integer> map) {
        this.speciesMap = map;
    }

    public String getSpeciesLabel(int code) {
        return speciesMap.entrySet()
                .stream()
                .filter(entry -> entry.getValue() == code)
                .map(Map.Entry::getKey)
                .findFirst()
                .orElse("Unknown");
    }

    public void train(List<Iris> irisList) {
        Formula formula = Formula.lhs("species");
        DataFrame data = DataFrame.of(Iris.class, irisList);
        DecisionTree.Options options = new DecisionTree.Options();
        model = DecisionTree.fit(formula, data, options);
    }

    public int[] predict(DataFrame input) {
        if (model == null) {
            throw new IllegalStateException("Model is not trained yet.");
        }
        return model.predict(input);
    }
}

4. Random Forest

Random Forest ist eine Weiterentwicklung des Decision Trees und verbessert die Leistung, indem mehrere Bäume kombiniert werden.

4.1. Theorie

Random Forest nutzt eine Vielzahl von Entscheidungsbäumen und aggregiert deren Ergebnisse. Dies verbessert die Vorhersagegenauigkeit.

Random Forest Illustration

4.2. Umsetzung

Um ein Random Forest zu implementieren, tauschen wir einfach den verwendeten Baumtyp aus.

private DecisionTree model; // von
private RandomForest model; // durch

5. Neural Network

Neurale Netze sind eine der fortschrittlichsten Methoden im maschinellen Lernen. Sie eignen sich hervorragend für komplexe Klassifizierungsaufgaben.

5.1. Aufbau

Neurale Netze bestehen aus vielen Schichten, die zusammenarbeiten, um Muster in den Daten zu erkennen.

Neuronales Netzwerk Illustration

5.2. Klassifizierung

Hier ein Beispiel für die Klassifizierung von Iris-Blumen mit einem Neuronalen Netzwerk.

Klassifizierung mit NN

5.3. Umsetzung mit DeepLearning4J

DeepLearning4J ist eine populäre Java-Bibliothek für Neuronale Netzwerke.

private MultiLayerConfiguration getModelConfiguration() {
    return new NeuralNetConfiguration.Builder()
            .seed(123)
            .l2(0.0005)
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.RELU)
            .updater(new Nesterovs(0.02, 0.9))
            .list()
            .layer(new DenseLayer.Builder().nIn(4).nOut(10).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(10).nOut(3)  // 3 output classes: Setosa, Versicolor, Virginica
                    .activation(Activation.SOFTMAX)
                    .build())
            .setInputType(InputType.feedForward(4))  // 4 input features
            .build();
}

6. Import von Modellen

Ein häufiger Use Case ist das Importieren von Modellen, die in anderen Programmiersprachen wie Python erstellt wurden.

6.1. Idee

In diesem Fall erstellen wir ein Modell in Python und importieren es in eine Java-Anwendung. Dies ermöglicht es, das Modell ohne Neutraining zu verwenden.

6.2. Was ist PMML?

PMML (Predictive Model Markup Language) ist ein standardisiertes Format zum Austausch von ML-Modellen zwischen verschiedenen Programmiersprachen.

  • Eine PMML-Datei beschreibt das trainierte Modell und kann von vielen Programmiersprachen gelesen werden, einschließlich Java, C++, Go.

  • PMML ermöglicht die Trennung zwischen dem Trainieren des Modells (z. B. in Python) und der Verwendung des Modells (z. B. in Java).

6.3. Wie funktioniert’s?

  1. Modell in Python trainieren (z. B. Decision Tree).

  2. Modell in eine PMML-Datei umwandeln (z. B. mit sklearn2pmml).

  3. PMML-Datei in Java-Projekt einfügen.

  4. Java-Programm liest die PMML-Datei und nutzt das Modell zur Vorhersage.

6.4. Vorteile

  • Kein Modell-Neuaufbau in Java nötig.

  • Funktioniert mit vielen anderen Programmiersprachen.

  • Klare Trennung zwischen Modelltraining und Modellnutzung.

7. Fazit

Machine Learning mit Java ist eine mächtige, wenn auch anspruchsvollere Methode im Vergleich zu Python. Es bietet hohe Performance und Skalierbarkeit, ist jedoch nicht so einfach zu implementieren.

7.1. Performance

Performance

7.2. Wann nimmt man was?

  • Java: Wenn Performance und Trainingszeit wichtig sind, ist Java eine gute Wahl.

  • Python: Wenn Einfachheit und schnelle Implementierung bevorzugt werden, ist Python ideal.

8. Slides

9. Quellen