L'algorithme de clustering K-Means en Java

1. Vue d'ensemble

Le clustering est un terme générique désignant une classe d'algorithmes non supervisés permettant de découvrir des groupes de choses, de personnes ou d'idées étroitement liés les uns aux autres .

Dans cette définition en une seule ligne apparemment simple, nous avons vu quelques mots à la mode. Qu'est-ce que le clustering exactement? Qu'est-ce qu'un algorithme non supervisé?

Dans ce tutoriel, nous allons tout d'abord éclairer ces concepts. Ensuite, nous verrons comment ils peuvent se manifester en Java.

2. Algorithmes non supervisés

Avant d'utiliser la plupart des algorithmes d'apprentissage, nous devrions en quelque sorte leur fournir des exemples de données et permettre à l'algorithme d'apprendre à partir de ces données. Dans la terminologie Machine Learning, nous appelons cet exemple de données d'entraînement de jeu de données. En outre, l'ensemble du processus est connu sous le nom de processus de formation.

Quoi qu'il en soit, nous pouvons classer les algorithmes d'apprentissage en fonction de la quantité de supervision dont ils ont besoin pendant le processus de formation. Les deux principaux types d'algorithmes d'apprentissage de cette catégorie sont:

  • Apprentissage supervisé : dans les algorithmes supervisés, les données d'apprentissage doivent inclure la solution réelle pour chaque point. Par exemple, si nous sommes sur le point de former notre algorithme de filtrage anti-spam, nous transmettons à l'algorithme les exemples d'e-mails et leur libellé, c'est-à-dire spam ou non-spam. Mathématiquement parlant, nous allons déduire le f (x) d'un ensemble d'apprentissage comprenant à la fois xs et ys.
  • Apprentissage non supervisé : lorsqu'il n'y a pas d'étiquettes dans les données d'apprentissage, l'algorithme est un algorithme non supervisé. Par exemple, nous avons beaucoup de données sur les musiciens et nous allons découvrir des groupes de musiciens similaires dans les données.

3. Clustering

Le clustering est un algorithme non supervisé pour découvrir des groupes de choses, d'idées ou de personnes similaires. Contrairement aux algorithmes supervisés, nous n'entraînons pas les algorithmes de clustering avec des exemples d'étiquettes connues. Au lieu de cela, le clustering tente de trouver des structures dans un ensemble d'apprentissage où aucun point des données n'est l'étiquette.

3.1. Clustering K-Means

K-Means est un algorithme de clustering avec une propriété fondamentale: le nombre de clusters est défini à l'avance . En plus des K-Means, il existe d'autres types d'algorithmes de clustering tels que le clustering hiérarchique, la propagation d'affinité ou le clustering spectral.

3.2. Comment fonctionne K-Means

Supposons que notre objectif soit de trouver quelques groupes similaires dans un ensemble de données comme:

K-Means commence par k centres de gravité placés aléatoirement. Les centroïdes, comme leur nom l'indique, sont les points centraux des grappes . Par exemple, nous ajoutons ici quatre centres de gravité aléatoires:

Ensuite, nous affectons chaque point de données existant à son centre de gravité le plus proche:

Après l'affectation, nous déplaçons les centres de gravité vers l'emplacement moyen des points qui lui sont affectés. N'oubliez pas que les centres de gravité sont censés être les points centraux des clusters:

L'itération actuelle se termine chaque fois que nous avons terminé de déplacer les centres de gravité. Nous répétons ces itérations jusqu'à ce que l'affectation entre plusieurs itérations consécutives cesse de changer:

Lorsque l'algorithme se termine, ces quatre clusters sont trouvés comme prévu. Maintenant que nous savons comment fonctionne K-Means, implémentons-le en Java.

3.3. Représentation des caractéristiques

Lors de la modélisation de différents ensembles de données d'entraînement, nous avons besoin d'une structure de données pour représenter les attributs du modèle et leurs valeurs correspondantes. Par exemple, un musicien peut avoir un attribut de genre avec une valeur comme Rock . Nous utilisons généralement le terme fonctionnalité pour désigner la combinaison d'un attribut et de sa valeur.

Pour préparer un ensemble de données pour un algorithme d'apprentissage particulier, nous utilisons généralement un ensemble commun d'attributs numériques qui peuvent être utilisés pour comparer différents éléments. Par exemple, si nous permettons à nos utilisateurs de marquer chaque artiste avec un genre, alors à la fin de la journée, nous pouvons compter combien de fois chaque artiste est tagué avec un genre spécifique:

Le vecteur caractéristique pour un artiste comme Linkin Park est [rock -> 7890, nu-metal -> 700, alternatif -> 520, pop -> 3]. Donc, si nous pouvons trouver un moyen de représenter les attributs sous forme de valeurs numériques, alors nous pouvons simplement comparer deux éléments différents, par exemple des artistes, en comparant leurs entrées vectorielles correspondantes.

Étant donné que les vecteurs numériques sont des structures de données si polyvalentes, nous allons représenter les fonctionnalités qui les utilisent . Voici comment nous implémentons les vecteurs de fonctionnalités en Java:

public class Record { private final String description; private final Map features; // constructor, getter, toString, equals and hashcode }

3.4. Recherche d'articles similaires

Dans chaque itération de K-Means, nous avons besoin d'un moyen de trouver le centre de gravité le plus proche de chaque élément de l'ensemble de données. L'un des moyens les plus simples de calculer la distance entre deux vecteurs d'entités consiste à utiliser la distance euclidienne. La distance euclidienne entre deux vecteurs comme [p1, q1] et [p2, q2] est égale à:

Implémentons cette fonction en Java. Tout d'abord, l'abstraction:

public interface Distance { double calculate(Map f1, Map f2); }

En plus de la distance euclidienne, il existe d'autres approches pour calculer la distance ou la similitude entre différents éléments, comme le coefficient de corrélation de Pearson . Cette abstraction permet de basculer facilement entre différentes mesures de distance.

Voyons l'implémentation de la distance euclidienne:

public class EuclideanDistance implements Distance { @Override public double calculate(Map f1, Map f2) { double sum = 0; for (String key : f1.keySet()) { Double v1 = f1.get(key); Double v2 = f2.get(key); if (v1 != null && v2 != null) { sum += Math.pow(v1 - v2, 2); } } return Math.sqrt(sum); } }

First, we calculate the sum of squared differences between corresponding entries. Then, by applying the sqrt function, we compute the actual Euclidean distance.

3.5. Centroid Representation

Centroids are in the same space as normal features, so we can represent them similar to features:

public class Centroid { private final Map coordinates; // constructors, getter, toString, equals and hashcode }

Now that we have a few necessary abstractions in place, it's time to write our K-Means implementation. Here's a quick look at our method signature:

public class KMeans { private static final Random random = new Random(); public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { // omitted } }
    

Let's break down this method signature:

  • The dataset is a set of feature vectors. Since each feature vector is a Record, then the dataset type is List
  • The k parameter determines the number of clusters, which we should provide in advance
  • distance encapsulates the way we're going to calculate the difference between two features
  • K-Means terminates when the assignment stops changing for a few consecutive iterations. In addition to this termination condition, we can place an upper bound for the number of iterations, too. The maxIterations argument determines that upper bound
  • When K-Means terminates, each centroid should have a few assigned features, hence we're using a Map as the return type. Basically, each map entry corresponds to a cluster

3.6. Centroid Generation

The first step is to generate k randomly placed centroids.

Although each centroid can contain totally random coordinates, it's a good practice to generate random coordinates between the minimum and maximum possible values for each attribute. Generating random centroids without considering the range of possible values would cause the algorithm to converge more slowly.

First, we should compute the minimum and maximum value for each attribute, and then, generate the random values between each pair of them:

private static List randomCentroids(List records, int k) { List centroids = new ArrayList(); Map maxs = new HashMap(); Map mins = new HashMap(); for (Record record : records) { record.getFeatures().forEach((key, value) -> ); } Set attributes = records.stream() .flatMap(e -> e.getFeatures().keySet().stream()) .collect(toSet()); for (int i = 0; i < k; i++) { Map coordinates = new HashMap(); for (String attribute : attributes) { double max = maxs.get(attribute); double min = mins.get(attribute); coordinates.put(attribute, random.nextDouble() * (max - min) + min); } centroids.add(new Centroid(coordinates)); } return centroids; }

Now, we can assign each record to one of these random centroids.

3.7. Assignment

First off, given a Record, we should find the centroid nearest to it:

private static Centroid nearestCentroid(Record record, List centroids, Distance distance) { double minimumDistance = Double.MAX_VALUE; Centroid nearest = null; for (Centroid centroid : centroids) { double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates()); if (currentDistance < minimumDistance) { minimumDistance = currentDistance; nearest = centroid; } } return nearest; }

Each record belongs to its nearest centroid cluster:

private static void assignToCluster(Map
    
      clusters, Record record, Centroid centroid) { clusters.compute(centroid, (key, list) -> { if (list == null) { list = new ArrayList(); } list.add(record); return list; }); }
    

3.8. Centroid Relocation

If, after one iteration, a centroid does not contain any assignments, then we won't relocate it. Otherwise, we should relocate the centroid coordinate for each attribute to the average location of all assigned records:

private static Centroid average(Centroid centroid, List records) { if (records == null || records.isEmpty()) { return centroid; } Map average = centroid.getCoordinates(); records.stream().flatMap(e -> e.getFeatures().keySet().stream()) .forEach(k -> average.put(k, 0.0)); for (Record record : records) { record.getFeatures().forEach( (k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue) ); } average.forEach((k, v) -> average.put(k, v / records.size())); return new Centroid(average); }

Since we can relocate a single centroid, now it's possible to implement the relocateCentroids method:

private static List relocateCentroids(Map
    
      clusters) { return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList()); }
    

This simple one-liner iterates through all centroids, relocates them, and returns the new centroids.

3.9. Putting It All Together

In each iteration, after assigning all records to their nearest centroid, first, we should compare the current assignments with the last iteration.

If the assignments were identical, then the algorithm terminates. Otherwise, before jumping to the next iteration, we should relocate the centroids:

public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { List centroids = randomCentroids(records, k); Map
     
       clusters = new HashMap(); Map
      
        lastState = new HashMap(); // iterate for a pre-defined number of times for (int i = 0; i < maxIterations; i++) { boolean isLastIteration = i == maxIterations - 1; // in each iteration we should find the nearest centroid for each record for (Record record : records) { Centroid centroid = nearestCentroid(record, centroids, distance); assignToCluster(clusters, record, centroid); } // if the assignments do not change, then the algorithm terminates boolean shouldTerminate = isLastIteration || clusters.equals(lastState); lastState = clusters; if (shouldTerminate) { break; } // at the end of each iteration we should relocate the centroids centroids = relocateCentroids(clusters); clusters = new HashMap(); } return lastState; }
      
     
    

4. Example: Discovering Similar Artists on Last.fm

Last.fm builds a detailed profile of each user's musical taste by recording details of what the user listens to. In this section, we're going to find clusters of similar artists. To build a dataset appropriate for this task, we'll use three APIs from Last.fm:

  1. API to get a collection of top artists on Last.fm.
  2. Another API to find popular tags. Each user can tag an artist with something, e.g. rock. So, Last.fm maintains a database of those tags and their frequencies.
  3. And an API to get the top tags for an artist, ordered by popularity. Since there are many such tags, we'll only keep those tags that are among the top global tags.

4.1. Last.fm's API

To use these APIs, we should get an API Key from Last.fm and send it in every HTTP request. We're going to use the following Retrofit service for calling those APIs:

public interface LastFmService { @GET("/2.0/?method=chart.gettopartists&format=json&limit=50") Call topArtists(@Query("page") int page); @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1") Call topTagsFor(@Query("artist") String artist); @GET("/2.0/?method=chart.gettoptags&format=json&limit=100") Call topTags(); // A few DTOs and one interceptor }

So, let's find the most popular artists on Last.fm:

// setting up the Retrofit service private static List getTop100Artists() throws IOException { List artists = new ArrayList(); // Fetching the first two pages, each containing 50 records. for (int i = 1; i <= 2; i++) { artists.addAll(lastFm.topArtists(i).execute().body().all()); } return artists; }

Similarly, we can fetch the top tags:

private static Set getTop100Tags() throws IOException { return lastFm.topTags().execute().body().all(); }

Finally, we can build a dataset of artists along with their tag frequencies:

private static List datasetWithTaggedArtists(List artists, Set topTags) throws IOException { List records = new ArrayList(); for (String artist : artists) { Map tags = lastFm.topTagsFor(artist).execute().body().all(); // Only keep popular tags. tags.entrySet().removeIf(e -> !topTags.contains(e.getKey())); records.add(new Record(artist, tags)); } return records; }

4.2. Forming Artist Clusters

Now, we can feed the prepared dataset to our K-Means implementation:

List artists = getTop100Artists(); Set topTags = getTop100Tags(); List records = datasetWithTaggedArtists(artists, topTags); Map
    
      clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000); // Printing the cluster configuration clusters.forEach((key, value) -> { System.out.println("-------------------------- CLUSTER ----------------------------"); // Sorting the coordinates to see the most significant tags first. System.out.println(sortedCentroid(key)); String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet())); System.out.print(members); System.out.println(); System.out.println(); });
    

If we run this code, then it would visualize the clusters as text output:

------------------------------ CLUSTER ----------------------------------- Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... } David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica, Fleetwood Mac, The Beatles, Elton John, The Clash ------------------------------ CLUSTER ----------------------------------- Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... } Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion, Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake ------------------------------ CLUSTER ----------------------------------- Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... } Tame Impala, The Black Keys ------------------------------ CLUSTER ----------------------------------- Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... } Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars, Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes, Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna, Adele, Lady Gaga, Jonas Brothers ------------------------------ CLUSTER ----------------------------------- Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... } Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons, The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire, Arctic Monkeys ------------------------------ CLUSTER ----------------------------------- Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... } Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers, Avicii, Kygo, Marshmello, David Guetta, Major Lazer ------------------------------ CLUSTER ----------------------------------- Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... } Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz, Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park, Red Hot Chili Peppers, Muse

Since centroid coordinations are sorted by the average tag frequency, we can easily spot the dominant genre in each cluster. For example, the last cluster is a cluster of a good old rock-bands, or the second one is filled with rap stars.

Although this clustering makes sense, for the most part, it's not perfect since the data is merely collected from user behavior.

5. Visualization

A few moments ago, our algorithm visualized the cluster of artists in a terminal-friendly way. If we convert our cluster configuration to JSON and feed it to D3.js, then with a few lines of JavaScript, we'll have a nice human-friendly Radial Tidy-Tree:

We have to convert our Map to a JSON with a similar schema like this d3.js example.

6. Number of Clusters

One of the fundamental properties of K-Means is the fact that we should define the number of clusters in advance. So far, we used a static value for k, but determining this value can be a challenging problem. There are two common ways to calculate the number of clusters:

  1. Domain Knowledge
  2. Mathematical Heuristics

If we're lucky enough that we know so much about the domain, then we might be able to simply guess the right number. Otherwise, we can apply a few heuristics like Elbow Method or Silhouette Method to get a sense on the number of clusters.

Before going any further, we should know that these heuristics, although useful, are just heuristics and may not provide clear-cut answers.

6.1. Elbow Method

To use the elbow method, we should first calculate the difference between each cluster centroid and all its members. As we group more unrelated members in a cluster, the distance between the centroid and its members goes up, hence the cluster quality decreases.

One way to perform this distance calculation is to use the Sum of Squared Errors. Sum of squared errors or SSE is equal to the sum of squared differences between a centroid and all its members:

public static double sse(Map
    
      clustered, Distance distance) { double sum = 0; for (Map.Entry
     
       entry : clustered.entrySet()) { Centroid centroid = entry.getKey(); for (Record record : entry.getValue()) { double d = distance.calculate(centroid.getCoordinates(), record.getFeatures()); sum += Math.pow(d, 2); } } return sum; }
     
    

Then, we can run the K-Means algorithm for different values of kand calculate the SSE for each of them:

List records = // the dataset; Distance distance = new EuclideanDistance(); List sumOfSquaredErrors = new ArrayList(); for (int k = 2; k <= 16; k++) { Map
    
      clusters = KMeans.fit(records, k, distance, 1000); double sse = Errors.sse(clusters, distance); sumOfSquaredErrors.add(sse); }
    

At the end of the day, it's possible to find an appropriate k by plotting the number of clusters against the SSE:

Usually, as the number of clusters increases, the distance between cluster members decreases. However, we can't choose any arbitrary large values for k, since having multiple clusters with just one member defeats the whole purpose of clustering.

L'idée derrière la méthode du coude est de trouver une valeur appropriée pour k de telle sorte que l'ESS diminue considérablement autour de cette valeur. Par exemple, k = 9 peut être un bon candidat ici.

7. Conclusion

Dans ce didacticiel, nous avons d'abord abordé quelques concepts importants du Machine Learning. Ensuite, nous nous sommes familiarisés avec la mécanique de l'algorithme de clustering K-Means. Enfin, nous avons écrit une implémentation simple pour K-Means, testé notre algorithme avec un ensemble de données réel de Last.fm et visualisé le résultat du clustering d'une manière graphique agréable.

Comme d'habitude, l'exemple de code est disponible sur notre projet GitHub, alors assurez-vous de le vérifier!