diff --git a/core-java/src/main/java/com/baeldung/algorithms/slope_one/SlopeOne.java b/core-java/src/main/java/com/baeldung/algorithms/slope_one/SlopeOne.java index 800a86884f..f11538356a 100644 --- a/core-java/src/main/java/com/baeldung/algorithms/slope_one/SlopeOne.java +++ b/core-java/src/main/java/com/baeldung/algorithms/slope_one/SlopeOne.java @@ -11,8 +11,8 @@ import java.util.Map.Entry; */ public class SlopeOne { - private static Map> differencesMatrix = new HashMap<>(); - private static Map> frequenciesMatrix = new HashMap<>(); + private static Map> diff = new HashMap<>(); + private static Map> freq = new HashMap<>(); private static Map> inputData; private static Map> outputData = new HashMap<>(); @@ -28,33 +28,36 @@ public class SlopeOne { * Based on the available data, calculate the relationships between the * items and number of occurences * - * @param data existing user data and their items' ratings + * @param data + * existing user data and their items' ratings */ private static void buildDifferencesMatrix(Map> data) { for (HashMap user : data.values()) { - for (Entry entry : user.entrySet()) { - if (!differencesMatrix.containsKey(entry.getKey())) { - differencesMatrix.put(entry.getKey(), new HashMap()); - frequenciesMatrix.put(entry.getKey(), new HashMap()); + for (Entry e : user.entrySet()) { + if (!diff.containsKey(e.getKey())) { + diff.put(e.getKey(), new HashMap()); + freq.put(e.getKey(), new HashMap()); } - for (Entry entry2 : user.entrySet()) { + for (Entry e2 : user.entrySet()) { int oldCount = 0; - if (frequenciesMatrix.get(entry.getKey()).containsKey(entry2.getKey())) - oldCount = frequenciesMatrix.get(entry.getKey()).get(entry2.getKey()).intValue(); + if (freq.get(e.getKey()).containsKey(e2.getKey())) { + oldCount = freq.get(e.getKey()).get(e2.getKey()).intValue(); + } double oldDiff = 0.0; - if (differencesMatrix.get(entry.getKey()).containsKey(entry2.getKey())) - oldDiff = differencesMatrix.get(entry.getKey()).get(entry2.getKey()).doubleValue(); - double observedDiff = entry.getValue() - entry2.getValue(); - frequenciesMatrix.get(entry.getKey()).put(entry2.getKey(), oldCount + 1); - differencesMatrix.get(entry.getKey()).put(entry2.getKey(), oldDiff + observedDiff); + if (diff.get(e.getKey()).containsKey(e2.getKey())) { + oldDiff = diff.get(e.getKey()).get(e2.getKey()).doubleValue(); + } + double observedDiff = e.getValue() - e2.getValue(); + freq.get(e.getKey()).put(e2.getKey(), oldCount + 1); + diff.get(e.getKey()).put(e2.getKey(), oldDiff + observedDiff); } } } - for (Item j : differencesMatrix.keySet()) { - for (Item i : differencesMatrix.get(j).keySet()) { - double oldvalue = differencesMatrix.get(j).get(i).doubleValue(); - int count = frequenciesMatrix.get(j).get(i).intValue(); - differencesMatrix.get(j).put(i, oldvalue / count); + for (Item j : diff.keySet()) { + for (Item i : diff.get(j).keySet()) { + double oldValue = diff.get(j).get(i).doubleValue(); + int count = freq.get(j).get(i).intValue(); + diff.get(j).put(i, oldValue / count); } } printData(data); @@ -64,41 +67,42 @@ public class SlopeOne { * Based on existing data predict all missing ratings. If prediction is not * possible, the value will be equal to -1 * - * @param data existing user data and their items' ratings + * @param data + * existing user data and their items' ratings */ private static void predict(Map> data) { - HashMap predictions = new HashMap(); - HashMap frequencies = new HashMap(); - for (Item j : differencesMatrix.keySet()) { - frequencies.put(j, 0); - predictions.put(j, 0.0); + HashMap uPred = new HashMap(); + HashMap uFreq = new HashMap(); + for (Item j : diff.keySet()) { + uFreq.put(j, 0); + uPred.put(j, 0.0); } - for (Entry> entry : data.entrySet()) { - for (Item j : entry.getValue().keySet()) { - for (Item k : differencesMatrix.keySet()) { + for (Entry> e : data.entrySet()) { + for (Item j : e.getValue().keySet()) { + for (Item k : diff.keySet()) { try { - double newValue = (differencesMatrix.get(k).get(j).doubleValue() - + entry.getValue().get(j).doubleValue()) * frequenciesMatrix.get(k).get(j).intValue(); - predictions.put(k, predictions.get(k) + newValue); - frequencies.put(k, frequencies.get(k) + frequenciesMatrix.get(k).get(j).intValue()); - } catch (NullPointerException e) { + double predictedValue = diff.get(k).get(j).doubleValue() + e.getValue().get(j).doubleValue(); + double finalValue = predictedValue * freq.get(k).get(j).intValue(); + uPred.put(k, uPred.get(k) + finalValue); + uFreq.put(k, uFreq.get(k) + freq.get(k).get(j).intValue()); + } catch (NullPointerException e1) { } } } - HashMap cleanPredictions = new HashMap(); - for (Item j : predictions.keySet()) { - if (frequencies.get(j) > 0) { - cleanPredictions.put(j, predictions.get(j).doubleValue() / frequencies.get(j).intValue()); + HashMap clean = new HashMap(); + for (Item j : uPred.keySet()) { + if (uFreq.get(j) > 0) { + clean.put(j, uPred.get(j).doubleValue() / uFreq.get(j).intValue()); } } for (Item j : InputData.items) { - if (entry.getValue().containsKey(j)) { - cleanPredictions.put(j, entry.getValue().get(j)); + if (e.getValue().containsKey(j)) { + clean.put(j, e.getValue().get(j)); } else { - cleanPredictions.put(j, -1.0); + clean.put(j, -1.0); } } - outputData.put(entry.getKey(), cleanPredictions); + outputData.put(e.getKey(), clean); } printData(outputData); }