Descente de dégradé à Java

1. Introduction

Dans ce didacticiel, nous en apprendrons davantage sur l'algorithme Gradient Descent. Nous allons implémenter l'algorithme en Java et l'illustrer étape par étape.

2. Qu'est-ce que la descente de gradient?

Gradient Descent est un algorithme d'optimisation utilisé pour trouver un minimum local d'une fonction donnée. Il est largement utilisé dans les algorithmes d'apprentissage automatique de haut niveau pour minimiser les fonctions de perte.

Gradient est un autre mot pour la pente, et la descente signifie descendre. Comme son nom l'indique, Gradient Descent descend la pente d'une fonction jusqu'à ce qu'elle atteigne la fin.

3. Propriétés de la descente de gradient

Gradient Descent trouve un minimum local, qui peut être différent du minimum global. Le point local de départ est donné en paramètre de l'algorithme.

C'est un algorithme itératif , et à chaque étape, il essaie de descendre la pente et de se rapprocher du minimum local.

En pratique, l'algorithme fait marche arrière . Nous illustrerons et implémenterons le retour en arrière de Gradient Descent dans ce didacticiel.

4. Illustration étape par étape

Gradient Descent a besoin d'une fonction et d'un point de départ comme entrée. Définissons et traçons une fonction:

Nous pouvons commencer à tout moment souhaité. Commençons par x = 1:

Dans la première étape, Gradient Descent descend la pente avec une taille de pas prédéfinie:

Ensuite, il va plus loin avec la même taille de pas. Cependant, cette fois, il se termine à un y supérieur à la dernière étape:

Cela indique que l'algorithme a dépassé le minimum local, il recule donc avec une taille de pas réduite:

Par la suite, chaque fois que le y actuel est supérieur au y précédent , la taille du pas est réduite et annulée. L'itération se poursuit jusqu'à ce que la précision souhaitée soit atteinte.

Comme on peut le voir, Gradient Descent a trouvé ici un minimum local, mais ce n'est pas le minimum global. Si nous commençons à x = -1 au lieu de x = 1, le minimum global sera trouvé.

5. Implémentation en Java

Il existe plusieurs façons d'implémenter la descente de gradient. Ici, nous ne calculons pas la dérivée de la fonction pour trouver la direction de la pente, donc notre implémentation fonctionne également pour les fonctions non différentiables.

Définissons precision et stepCoefficient et donnons-leur des valeurs initiales:

double precision = 0.000001; double stepCoefficient = 0.1;

Dans la première étape, nous n'avons pas de y précédent pour comparaison. Nous pouvons augmenter ou diminuer la valeur de x pour voir si y diminue ou augmente. Un stepCoefficient positif signifie que nous augmentons la valeur de x .

Maintenant, effectuons la première étape:

double previousX = initialX; double previousY = f.apply(previousX); currentX += stepCoefficient * previousY;

Dans le code ci-dessus, f est une fonction et initialX est un double , les deux étant fournis en entrée.

Un autre point clé à considérer est que la descente de gradient n'est pas garantie de converger. Pour éviter de rester coincé dans la boucle, limitons le nombre d'itérations:

int iter = 100;

Plus tard, nous décrémentons iter par un à chaque itération. Par conséquent, nous sortirons de la boucle à un maximum de 100 itérations.

Maintenant que nous avons un previousX , nous pouvons configurer notre boucle:

while (previousStep > precision && iter > 0) { iter--; double currentY = f.apply(currentX); if (currentY > previousY) { stepCoefficient = -stepCoefficient/2; } previousX = currentX; currentX += stepCoefficient * previousY; previousY = currentY; previousStep = StrictMath.abs(currentX - previousX); }

À chaque itération, nous calculons le nouveau y et le comparons avec le y précédent . Si currentY est supérieur à previousY , nous changeons de direction et diminuons la taille du pas.

La boucle continue jusqu'à ce que la taille de notre pas soit inférieure à la précision souhaitée . Enfin, nous pouvons renvoyer currentX comme le minimum local:

return currentX;

6. Conclusion

Dans cet article, nous avons parcouru l'algorithme de descente de dégradé avec une illustration étape par étape.

Nous avons également implémenté Gradient Descent en Java. Le code est disponible à l'adresse over sur GitHub.