BAEL-3881
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
package com.baeldung.algorithms.gradientdescent;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
public class GradientDescent {
|
||||
|
||||
private final double precision = 0.000001;
|
||||
|
||||
public double findLocalMinimum(Function<Double, Double> f, double initialX) {
|
||||
double stepCoefficient = 0.1;
|
||||
double previousStep = 1.0;
|
||||
double currentX = initialX;
|
||||
double previousX = initialX;
|
||||
double previousY = f.apply(previousX);
|
||||
int iter = 100;
|
||||
|
||||
currentX += stepCoefficient * previousY;
|
||||
|
||||
while (previousStep > precision && iter > 0) {
|
||||
iter--;
|
||||
double currentY = f.apply(currentX);
|
||||
if (currentY > previousY) {
|
||||
stepCoefficient = -stepCoefficient / 2;
|
||||
}
|
||||
previousX = currentX;
|
||||
currentX += stepCoefficient * previousY;
|
||||
previousY = currentY;
|
||||
previousStep = StrictMath.abs(currentX - previousX);
|
||||
}
|
||||
return currentX;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.baeldung.algorithms.gradientdescent;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class GradientDescentUnitTest {
|
||||
|
||||
@Test
|
||||
public void givenFunction_whenStartingPointIsOne_thenLocalMinimumIsFound() {
|
||||
Function<Double, Double> df = x ->
|
||||
StrictMath.abs(StrictMath.pow(x, 3)) - (3 * StrictMath.pow(x, 2)) + x;
|
||||
GradientDescent gd = new GradientDescent();
|
||||
double res = gd.findLocalMinimum(df, 1);
|
||||
assertTrue(res > 1.78);
|
||||
assertTrue(res < 1.84);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user