BAEL-3881

This commit is contained in:
pazis
2020-03-23 02:05:05 +00:00
parent 1dcfc639f2
commit 7a5e9c3216
4 changed files with 70 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
package com.baeldung.algorithms.gradientdescent;
import java.util.function.Function;
public class GradientDescent {
private final double precision = 0.000001;
public double findLocalMinimum(Function<Double, Double> f, double initialX) {
double stepCoefficient = 0.1;
double previousStep = 1.0;
double currentX = initialX;
double previousX = initialX;
double previousY = f.apply(previousX);
int iter = 100;
currentX += stepCoefficient * previousY;
while (previousStep > precision && iter > 0) {
iter--;
double currentY = f.apply(currentX);
if (currentY > previousY) {
stepCoefficient = -stepCoefficient / 2;
}
previousX = currentX;
currentX += stepCoefficient * previousY;
previousY = currentY;
previousStep = StrictMath.abs(currentX - previousX);
}
return currentX;
}
}

View File

@@ -0,0 +1,20 @@
package com.baeldung.algorithms.gradientdescent;
import static org.junit.Assert.assertTrue;
import java.util.function.Function;
import org.junit.Test;
public class GradientDescentUnitTest {
@Test
public void givenFunction_whenStartingPointIsOne_thenLocalMinimumIsFound() {
Function<Double, Double> df = x ->
StrictMath.abs(StrictMath.pow(x, 3)) - (3 * StrictMath.pow(x, 2)) + x;
GradientDescent gd = new GradientDescent();
double res = gd.findLocalMinimum(df, 1);
assertTrue(res > 1.78);
assertTrue(res < 1.84);
}
}