diff --git a/io.openems.edge.application/EdgeApp.bndrun b/io.openems.edge.application/EdgeApp.bndrun index 4dfeab59531..8dddf82bfa1 100644 --- a/io.openems.edge.application/EdgeApp.bndrun +++ b/io.openems.edge.application/EdgeApp.bndrun @@ -167,6 +167,7 @@ bnd.identity;id='io.openems.edge.meter.weidmueller',\ bnd.identity;id='io.openems.edge.meter.ziehl',\ bnd.identity;id='io.openems.edge.onewire.thermometer',\ + bnd.identity;id='io.openems.edge.predictor.lstmmodel',\ bnd.identity;id='io.openems.edge.predictor.persistencemodel',\ bnd.identity;id='io.openems.edge.predictor.similardaymodel',\ bnd.identity;id='io.openems.edge.pvinverter.cluster',\ @@ -345,6 +346,7 @@ io.openems.edge.meter.ziehl;version=snapshot,\ io.openems.edge.onewire.thermometer;version=snapshot,\ io.openems.edge.predictor.api;version=snapshot,\ + io.openems.edge.predictor.lstmmodel;version=snapshot,\ io.openems.edge.predictor.persistencemodel;version=snapshot,\ io.openems.edge.predictor.similardaymodel;version=snapshot,\ io.openems.edge.pvinverter.api;version=snapshot,\ diff --git a/io.openems.edge.predictor.lstmmodel/.classpath b/io.openems.edge.predictor.lstmmodel/.classpath new file mode 100644 index 00000000000..bbfbdbe40e7 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/.classpath @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/io.openems.edge.predictor.lstmmodel/.gitignore b/io.openems.edge.predictor.lstmmodel/.gitignore new file mode 100644 index 00000000000..c2b941a96de --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/.gitignore @@ -0,0 +1,2 @@ +/bin_test/ +/generated/ diff --git a/io.openems.edge.predictor.lstmmodel/.project b/io.openems.edge.predictor.lstmmodel/.project new file mode 100644 index 00000000000..8fe907a680b --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/.project @@ -0,0 +1,23 @@ + + + io.openems.edge.predictor.lstmmodel + + + + + + org.eclipse.jdt.core.javabuilder + + + + + bndtools.core.bndbuilder + + + + + + org.eclipse.jdt.core.javanature + 
bndtools.core.bndnature + + diff --git a/io.openems.edge.predictor.lstmmodel/.settings/org.eclipse.core.resources.prefs b/io.openems.edge.predictor.lstmmodel/.settings/org.eclipse.core.resources.prefs new file mode 100644 index 00000000000..99f26c0203a --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/.settings/org.eclipse.core.resources.prefs @@ -0,0 +1,2 @@ +eclipse.preferences.version=1 +encoding/=UTF-8 diff --git a/io.openems.edge.predictor.lstmmodel/bnd.bnd b/io.openems.edge.predictor.lstmmodel/bnd.bnd new file mode 100644 index 00000000000..055ee9ad704 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/bnd.bnd @@ -0,0 +1,16 @@ +Bundle-Name: OpenEMS Edge Predictor Lstm-Model +Bundle-Vendor: OpenEMS Association e.V. +Bundle-License: https://opensource.org/licenses/EPL-2.0 +Bundle-Version: 1.0.0.${tstamp} + +-buildpath: \ + ${buildpath},\ + io.openems.common,\ + io.openems.edge.common,\ + io.openems.edge.controller.api,\ + io.openems.edge.predictor.api,\ + io.openems.edge.timedata.api,\ + org.apache.commons.math3,\ + +-testpath: \ + ${testpath} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/readme.adoc b/io.openems.edge.predictor.lstmmodel/readme.adoc new file mode 100644 index 00000000000..0791f8580b3 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/readme.adoc @@ -0,0 +1,28 @@ += Long short term model predictor + +The Long Short-Term Memory (LSTM) model is a type of recurrent neural network (RNN) that is particularly well-suited for time series prediction tasks, including consumption and production power predictions, due to its ability to capture dependencies and patterns over time. https://en.wikipedia.org/wiki/Long_short-term_memory[More details of LSTM] + +This application is used for predicting power (consumption and production) values. 
+Here, For power prediction, LSTM models can analyze historical power data to learn patterns and trends that occur over time, such as: + +* Daily and Seasonal Variations: example, Consumption power often follows cyclic patterns (e.g., higher usage during the day, lower at night). Production power often higher during the day and none during the nights. +* External Factors: LSTM can incorporate external factors like weather, day of the week, or holidays to improve prediction accuracy. + +== Training LSTM for Power Predictions: + +* Input Data (Channels address "_sum/ConsumptionActivePower"): Time series data of past consumption levels. +* Pre-processing: Data needs to be scaled and sometimes transformed to remove seasonality or noise. +* Training: The LSTM is trained on historical data using techniques like backpropagation through time (BPTT), where it learns to minimize the error between predicted and actual consumption. +* Prediction: Once trained, the model can predict future power consumption for various time steps ahead (e.g., hours, days, or even weeks). + +In practice, LSTMs are favored for their ability to learn complex time-related patterns, making them effective in forecasting energy demand patterns that can inform Energy management system (EMS), energy distribution, and cost optimization strategies. + +== Note for activating the predictor + +To run this predictor, please create a folder named "models" in the OpenEMS data directory (openems/data/). + +Initially, a generic model will be used for predictions, which may not yield optimal results. However, a training process is scheduled to occur every 45 days, during which the models in this directory will be updated. The 45-day interval consists of 30 days for training and 15 days for validation. + +As a result of this process, a new model will be trained and will automatically replace the previous one. 
+ +https://github.com/OpenEMS/openems/tree/develop/io.openems.edge.predictor.lstmmodel[Source Code icon:github[]] \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/Config.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/Config.java new file mode 100644 index 00000000000..400d302a3fc --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/Config.java @@ -0,0 +1,31 @@ +package io.openems.edge.predictor.lstmmodel; + +import org.osgi.service.metatype.annotations.AttributeDefinition; +import org.osgi.service.metatype.annotations.ObjectClassDefinition; + +import io.openems.edge.predictor.api.prediction.LogVerbosity; + +@ObjectClassDefinition(// + name = "Predictor Lstm-Model", // + description = "Implements Long Short-Term Memory (LSTM) model, which is a type of recurrent neural network (RNN) designed to capture long-range dependencies in sequential data, such as time series. " + + "This makes LSTMs particularly effective for time series prediction, " + + "as they can learn patterns and trends over time, handling long-term dependencies while filtering out irrelevant information.") +@interface Config { + + @AttributeDefinition(name = "Component-ID", description = "Unique ID of this Component") + String id() default "predictor0"; + + @AttributeDefinition(name = "Alias", description = "Human-readable name of this Component; defaults to Component-ID") + String alias() default ""; + + @AttributeDefinition(name = "Is enabled?", description = "Is this Component enabled?") + boolean enabled() default true; + + @AttributeDefinition(name = "Channel-Address", description = "Channel-Address this Predictor is used for, e.g. 
'_sum/UnmanagedConsumptionActivePower'") + String channelAddress(); + + @AttributeDefinition(name = "Log-Verbosity", description = "The log verbosity.") + LogVerbosity logVerbosity() default LogVerbosity.NONE; + + String webconsole_configurationFactory_nameHint() default "Predictor Lstm-Model [{id}]"; +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModel.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModel.java new file mode 100644 index 00000000000..534cd1661bc --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModel.java @@ -0,0 +1,129 @@ +package io.openems.edge.predictor.lstmmodel; + +import io.openems.common.types.OpenemsType; +import io.openems.edge.common.channel.BooleanReadChannel; +import io.openems.edge.common.channel.Doc; +import io.openems.edge.common.channel.DoubleReadChannel; +import io.openems.edge.common.channel.StringReadChannel; +import io.openems.edge.common.channel.value.Value; +import io.openems.edge.common.component.OpenemsComponent; + +public interface LstmModel extends OpenemsComponent { + + public enum ChannelId implements io.openems.edge.common.channel.ChannelId { + LAST_TRAINED_TIME(Doc.of(OpenemsType.STRING) // + .text("Last trained time in Unixstimestamp")), // + MODEL_ERROR(Doc.of(OpenemsType.DOUBLE) // + .text("Error in the Model")), // + CANNOT_TRAIN_CONDITON(Doc.of(OpenemsType.BOOLEAN) // + .text("When the data set is empty, entirely null, or contains 50% null values.")); + + private final Doc doc; + + private ChannelId(Doc doc) { + this.doc = doc; + } + + @Override + public Doc doc() { + return this.doc; + } + } + + /** + * Gets the Channel for {@link ChannelId#CANNOT_TRAIN_CONDITON}. 
+ * + * @return the Channel + */ + public default BooleanReadChannel getCannotTrainConditionChannel() { + return this.channel(ChannelId.CANNOT_TRAIN_CONDITON); + } + + /** + * Gets the Cannot train condition in boolean. See + * {@link ChannelId#CANNOT_TRAIN_CONDITON}. + * + * @return the Channel {@link Value} + */ + public default Value getCannotTrainCondition() { + return this.getCannotTrainConditionChannel().value(); + } + + /** + * Internal method to set the 'nextValue' on + * {@link ChannelId#CANNOT_TRAIN_CONDITON} Channel. + * + * @param value the next value + */ + public default void _setCannotTrainCondition(boolean value) { + this.getCannotTrainConditionChannel().setNextValue(value); + } + + /** + * Internal method to set the 'nextValue' on + * {@link ChannelId#CANNOT_TRAIN_CONDITON} Channel. + * + * @param value the next value + */ + public default void _setCannotTrainCondition(Boolean value) { + this.getCannotTrainConditionChannel().setNextValue(value); + } + + /** + * Gets the Channel for {@link ChannelId#LAST_TRAINED_TIME}. + * + * @return the Channel + */ + public default StringReadChannel getLastTrainedTimeChannel() { + return this.channel(ChannelId.LAST_TRAINED_TIME); + } + + /** + * Gets the Last time trained time in Unix time stamp. See + * {@link ChannelId#LAST_TRAINED_TIME}. + * + * @return the Channel {@link Value} + */ + public default Value getLastTrainedTime() { + return this.getLastTrainedTimeChannel().value(); + } + + /** + * Internal method to set the 'nextValue' on {@link ChannelId#LAST_TRAINED_TIME} + * Channel. + * + * @param value the next value + */ + public default void _setLastTrainedTime(String value) { + this.getLastTrainedTimeChannel().setNextValue(value); + } + + /** + * Gets the Channel for {@link ChannelId#MODEL_ERROR}. + * + * @return the Channel + */ + public default DoubleReadChannel getModelErrorChannel() { + return this.channel(ChannelId.MODEL_ERROR); + } + + /** + * Gets the Model error. See {@link ChannelId#MODEL_ERROR}. 
+ * + * @return the Channel {@link Value} + */ + public default Value getModelError() { + return this.getModelErrorChannel().value(); + } + + /** + * Internal method to set the 'nextValue' on {@link ChannelId#LAST_TRAINED_TIME} + * Channel. + * + * @param value the next value + */ + public default void _setModelError(Double value) { + this.getModelErrorChannel().setNextValue(value); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java new file mode 100644 index 00000000000..d32764bf3a9 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java @@ -0,0 +1,301 @@ +package io.openems.edge.predictor.lstmmodel; + +import static io.openems.common.utils.ThreadPoolUtils.shutdownAndAwaitTermination; +import static io.openems.edge.predictor.lstmmodel.utilities.DataUtility.combine; +import static io.openems.edge.predictor.lstmmodel.utilities.DataUtility.concatenateList; +import static io.openems.edge.predictor.lstmmodel.utilities.DataUtility.getData; +import static io.openems.edge.predictor.lstmmodel.utilities.DataUtility.getDate; +import static io.openems.edge.predictor.lstmmodel.utilities.DataUtility.getMinute; + +import java.time.ZonedDateTime; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.SortedMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.osgi.service.component.ComponentContext; +import org.osgi.service.component.annotations.Activate; +import org.osgi.service.component.annotations.Component; +import org.osgi.service.component.annotations.ConfigurationPolicy; +import org.osgi.service.component.annotations.Deactivate; +import 
org.osgi.service.component.annotations.Reference; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +import org.osgi.service.metatype.annotations.Designate; + +import com.google.common.collect.Sets; +import com.google.gson.JsonElement; + +import io.openems.common.exceptions.OpenemsError.OpenemsNamedException; +import io.openems.common.session.Role; +import io.openems.common.timedata.Resolution; +import io.openems.common.types.ChannelAddress; +import io.openems.edge.common.component.ClockProvider; +import io.openems.edge.common.component.ComponentManager; +import io.openems.edge.common.component.OpenemsComponent; +import io.openems.edge.common.jsonapi.ComponentJsonApi; +import io.openems.edge.common.jsonapi.EdgeGuards; +import io.openems.edge.common.jsonapi.JsonApiBuilder; +import io.openems.edge.common.sum.Sum; +import io.openems.edge.controller.api.Controller; +import io.openems.edge.predictor.api.manager.PredictorManager; +import io.openems.edge.predictor.api.prediction.AbstractPredictor; +import io.openems.edge.predictor.api.prediction.Prediction; +import io.openems.edge.predictor.api.prediction.Predictor; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.common.ReadAndSaveModels; +import io.openems.edge.predictor.lstmmodel.jsonrpc.GetPredictionRequest; +import io.openems.edge.predictor.lstmmodel.jsonrpc.PredictionRequestHandler; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; +import io.openems.edge.predictor.lstmmodel.train.LstmTrain; +import io.openems.edge.timedata.api.Timedata; + +@Designate(ocd = Config.class, factory = true) +@Component(// + name = "Predictor.LstmModel", // + immediate = true, // + configurationPolicy = ConfigurationPolicy.REQUIRE // +) +public class LstmModelImpl extends AbstractPredictor + implements Predictor, OpenemsComponent, ComponentJsonApi, LstmModel { + + // private final Logger log = 
LoggerFactory.getLogger(LstmModelImpl.class); + + /** 45 days. */ + private static final long DAYS_45 = 45; + + /** 45 days in minutes. */ + private static final long PERIOD = DAYS_45 * 24 * 60; + + @Reference + private Sum sum; + + @Reference + private Timedata timedata; + + @Reference + private ComponentManager componentManager; + + @Reference + private PredictorManager predictorManager; + + @Override + protected ClockProvider getClockProvider() { + return this.componentManager; + } + + public LstmModelImpl() throws OpenemsNamedException { + super(// + OpenemsComponent.ChannelId.values(), // + Controller.ChannelId.values(), // + LstmModel.ChannelId.values()// + ); + } + + private ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); + private ChannelAddress channelForPrediction; + + @Activate + private void activate(ComponentContext context, Config config) throws OpenemsNamedException { + super.activate(context, config.id(), config.alias(), config.enabled(), // + new String[] { config.channelAddress() }, config.logVerbosity()); + + var channelAddress = ChannelAddress.fromString(config.channelAddress()); + this.channelForPrediction = channelAddress; + + /* + * Avoid training for the new FEMs due to lack of data. Set a fixed 45-day + * period: 30 days for training and 15 days for validation. 
+ */ + this.scheduler.scheduleAtFixedRate(// + new LstmTrain(this.timedata, channelAddress, this, DAYS_45), // + 0, // + PERIOD, // + TimeUnit.MINUTES// + ); + } + + @Override + @Deactivate + protected void deactivate() { + shutdownAndAwaitTermination(this.scheduler, 0); + super.deactivate(); + } + + @Override + protected Prediction createNewPrediction(ChannelAddress channelAddress) { + + var hyperParameters = ReadAndSaveModels.read(channelAddress.getChannelId()); + var nowDate = ZonedDateTime.now(); + + var seasonalityFuture = CompletableFuture + .supplyAsync(() -> this.predictSeasonality(channelAddress, nowDate, hyperParameters)); + + var trendFuture = CompletableFuture + .supplyAsync(() -> this.predictTrend(channelAddress, nowDate, hyperParameters)); + + var dayPlus1SeasonalityFuture = CompletableFuture + .supplyAsync(() -> this.predictSeasonality(channelAddress, nowDate.plusDays(1), hyperParameters)); + + var combinePrerequisites = CompletableFuture.allOf(seasonalityFuture, trendFuture); + + try { + combinePrerequisites.get(); + + // Current day prediction + var currentDayPredicted = combine(trendFuture.get(), seasonalityFuture.get()); + + // Next Day prediction + var plus1DaySeasonalityPrediction = dayPlus1SeasonalityFuture.get(); + + // Concat current and Nextday + var actualPredicted = concatenateList(currentDayPredicted, plus1DaySeasonalityPrediction); + + var baseTimeOfPrediction = nowDate.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0) + .withNano(0); + + return Prediction.from(// + Prediction.getValueRange(this.sum, channelAddress), // + baseTimeOfPrediction, // + averageInChunks(actualPredicted)); + } catch (Exception e) { + throw new RuntimeException("Error in getting prediction execution", e); + } + } + + /** + * Averages the elements of an integer array in chunks of a specified size. + * + *

+ * This method takes an input array of integers and divides it into chunks of a + * fixed size. For each chunk, it calculates the average of the integers and + * stores the result in a new array. The size of the result array is determined + * by the total number of elements in the input array divided by the chunk size. + *

+ * + * @param inputList an arrayList of Doubles to be processed. The array length + * must be a multiple of the chunk size for correct processing. + * @return an array of integers containing the averages of each chunk. + * + */ + private static Integer[] averageInChunks(ArrayList inputList) { + final int chunkSize = 3; + int resultSize = inputList.size() / chunkSize; + Integer[] result = new Integer[resultSize]; + + for (int i = 0; i < inputList.size(); i += chunkSize) { + double sum = IntStream.range(i, Math.min(i + chunkSize, inputList.size())) + .mapToDouble(j -> inputList.get(j))// + .sum(); + result[i / chunkSize] = (int) (sum / chunkSize); + } + return result; + } + + /** + * Queries historic data for a specified time range and channel address with + * given {@link ChannelAddress}. + * + * @param from the start of the time range + * @param until the end of the time range + * @param channelAddress the {@link ChannelAddress} for the query + * @param hyperParameters the {@link HyperParameters} that include the interval + * for data resolution + * @return a SortedMap where the key is a ZonedDateTime representing the + * timestamp of the data point, and the value is another SortedMap where + * the key is the ChannelAddress and the value is the data point as a + * JsonElement. and null if error + */ + private SortedMap> queryHistoricData(ZonedDateTime from, + ZonedDateTime until, ChannelAddress channelAddress, HyperParameters hyperParameters) { + try { + return this.timedata.queryHistoricData(null, from, until, Sets.newHashSet(channelAddress), + new Resolution(hyperParameters.getInterval(), ChronoUnit.MINUTES)); + } catch (OpenemsNamedException e) { + e.printStackTrace(); + } + return null; + } + + /** + * Predicts trend values for a specified channel at the current date using LSTM + * models. + * + * @param channelAddress The {@link ChannelAddress} for which trend values are + * predicted. 
+ * @param nowDate The current date and time for which trend values are + * predicted. + * @param hyperParameters The {@link HyperParameters} for the prediction model. + * @return A list of predicted trend values for the specified channel at the + * current date. + * @throws SomeException If there's any specific exception that might be thrown + * during the process. + */ + public ArrayList predictTrend(ChannelAddress channelAddress, ZonedDateTime nowDate, + HyperParameters hyperParameters) { + + var till = nowDate.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0); + var from = till.minusMinutes(hyperParameters.getInterval() * hyperParameters.getWindowSizeTrend()); + + var trendQueryResult = this.queryHistoricData(// + from, // + till, // + channelAddress, // + hyperParameters); + + return LstmPredictor.predictTrend(// + getData(trendQueryResult), // + getDate(trendQueryResult), // + till, // + hyperParameters); + } + + /** + * Predicts Seasonality values for a specified channel at the current date using + * LSTM models. + * + * @param channelAddress The address of the channel for which seasonality + * values are predicted. + * @param nowDate The current date and time for which seasonality values + * are predicted. + * @param hyperParameters The {@link ChannelAddress} for the prediction model. + * @return A list of predicted seasonality values for the specified channel at + * the current date. + * @throws SomeException If there's any specific exception that might be thrown + * during the process. 
+ */ + public ArrayList predictSeasonality(ChannelAddress channelAddress, ZonedDateTime nowDate, + HyperParameters hyperParameters) { + + var till = nowDate.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0); + var temp = till.minusDays(hyperParameters.getWindowSizeSeasonality() - 1); + + var from = temp// + .withMinute(getMinute(nowDate, hyperParameters))// + .withSecond(0)// + .withNano(0); + + var targetFrom = till.plusMinutes(hyperParameters.getInterval()); + var queryResult = this.queryHistoricData(from, till, channelAddress, hyperParameters); + + return LstmPredictor.getArranged( + LstmPredictor.getIndex(targetFrom.getHour(), targetFrom.getMinute(), hyperParameters), // + LstmPredictor.predictSeasonality(DataModification.removeNegatives(getData(queryResult)), + getDate(queryResult), // + hyperParameters)); + } + + @Override + public void buildJsonApiRoutes(JsonApiBuilder builder) { + builder.handleRequest(GetPredictionRequest.METHOD, endpoint -> { + endpoint.setGuards(EdgeGuards.roleIsAtleast(Role.OWNER)); + }, call -> { + return PredictionRequestHandler.handlerGetPredictionRequest(call.getRequest().id, this.predictorManager, + this.channelForPrediction); + }); + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmPredictor.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmPredictor.java new file mode 100644 index 00000000000..0081e2efddc --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmPredictor.java @@ -0,0 +1,416 @@ +package io.openems.edge.predictor.lstmmodel; + +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArray; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArrayList; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to2DArrayList; +import static 
io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to2DList; + +import java.time.OffsetDateTime; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.List; + +import io.openems.edge.predictor.lstmmodel.common.DataStatistics; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.preprocessingpipeline.PreprocessingPipeImpl; +import io.openems.edge.predictor.lstmmodel.utilities.MathUtils; + +public class LstmPredictor { + + /** + * Predicts seasonality based on the provided data and models. + * + * @param data The input data to predict seasonality for. + * @param date The corresponding date and time information for the + * data points. + * @param hyperParameters The hyperparameters for the prediction model. + * @return A list of predicted values for the seasonality. + * @throws SomeException If there's any specific exception that might be thrown + * during the process. + */ + public static ArrayList predictSeasonality(ArrayList data, ArrayList date, + HyperParameters hyperParameters) { + + var preprocessing = new PreprocessingPipeImpl(hyperParameters); + preprocessing.setData(to1DArray(data)).setDates(date); + var resized = to2DList((double[][][]) preprocessing.interpolate()// + .scale()// + .filterOutliers() // + .groupByHoursAndMinutes()// + .execute()); + preprocessing.setData(resized); + var normalized = (double[][]) preprocessing// + .normalize()// + .execute(); + var allModel = hyperParameters.getBestModelSeasonality(); + var predicted = predictPre(to2DArrayList(normalized), allModel, hyperParameters); + preprocessing.setData(to1DArray(predicted))// + .setMean(DataStatistics.getMean(resized)) + .setStandardDeviation(DataStatistics.getStandardDeviation(resized)); + var seasonalityPrediction = (double[]) preprocessing.reverseNormalize()// + .reverseScale()// + .execute(); + return to1DArrayList(seasonalityPrediction); + } + + /** + * Predicts trend values for a given 
time period using LSTM models. + * + * @param data The historical data for trend prediction. + * @param date The corresponding date and time information for the + * historical data points. + * @param until The target time until which trend values will be + * predicted. + * @param hyperParameters The hyperparameters for the prediction model. + * @return A list of predicted trend values. + * @throws SomeException If there's any specific exception that might be thrown + * during the process. + */ + public static ArrayList predictTrend(ArrayList data, ArrayList date, + ZonedDateTime until, HyperParameters hyperParameters) { + + var preprocessing = new PreprocessingPipeImpl(hyperParameters); + preprocessing.setData(to1DArray(data)).setDates(date); + + var scaled = (double[]) preprocessing// + .interpolate()// + .scale()// + .execute(); + // normalize + var trendPrediction = new double[hyperParameters.getTrendPoint()]; + var mean = DataStatistics.getMean(scaled); + var standerDev = DataStatistics.getStandardDeviation(scaled); + preprocessing.setData(scaled); + var normData = to1DArrayList((double[]) preprocessing// + .normalize()// + .execute()); + + var predictionFor = until.plusMinutes(hyperParameters.getInterval()); + var val = hyperParameters.getBestModelTrend(); + for (int i = 0; i < hyperParameters.getTrendPoint(); i++) { + var temp = predictionFor.plusMinutes(i * hyperParameters.getInterval()); + + var modlelindex = (int) decodeDateToColumnIndex(temp, hyperParameters); + double predTemp = LstmPredictor.predict(// + normData, // + val.get(modlelindex).get(0), val.get(modlelindex).get(1), // + val.get(modlelindex).get(2), val.get(modlelindex).get(3), // + val.get(modlelindex).get(4), val.get(modlelindex).get(5), // + val.get(modlelindex).get(7), val.get(modlelindex).get(6), // + hyperParameters); + normData.add(predTemp); + normData.remove(0); + trendPrediction[i] = (predTemp); + } + + 
preprocessing.setData(trendPrediction).setMean(mean).setStandardDeviation(standerDev); + + return to1DArrayList((double[]) preprocessing// + .reverseNormalize()// + .reverseScale()// + .execute()); + } + + /** + * Decodes a ZonedDateTime to its corresponding column index based on prediction + * interval and window size. + * + * @param predictionFor The ZonedDateTime for which the column index is to be + * decoded. + * @param hyperParameters The hyperparameters for the prediction model. + * @return The decoded column index for the given ZonedDateTime. If the index is + * negative, it is adjusted to the corresponding positive index for a + * 24-hour period. + */ + public static double decodeDateToColumnIndex(ZonedDateTime predictionFor, HyperParameters hyperParameters) { + var hour = predictionFor.getHour(); + var minute = predictionFor.getMinute(); + var index = (Integer) hour * (60 / hyperParameters.getInterval()) + minute / hyperParameters.getInterval(); + var modifiedIndex = index - hyperParameters.getWindowSizeTrend(); + if (modifiedIndex >= 0) { + return modifiedIndex; + } else { + return modifiedIndex + 60 / hyperParameters.getInterval() * 24; + } + } + + /** + * Re-arranges an ArrayList of Double values by splitting it at the specified + * index and moving the second part to the front. + * + * @param splitIndex The index at which the ArrayList will be split. + * @param singleArray An ArrayList of Double values to be re-arranged. + * @return A new ArrayList containing the Double values after re-arrangement. 
+ */ + public static ArrayList getArranged(int splitIndex, ArrayList singleArray) { + var arranged = new ArrayList(); + var firstGroup = new ArrayList(); + var secondGroup = new ArrayList(); + + for (var i = 0; i < singleArray.size(); i++) { + if (i < splitIndex) { + firstGroup.add(singleArray.get(i)); + } else { + secondGroup.add(singleArray.get(i)); + } + } + + arranged.addAll(secondGroup); + arranged.addAll(firstGroup); + + return arranged; + } + + /** + * Calculates the index of a specific hour and minute combination within a + * 24-hour period, divided into 15-minute intervals. + * + * @param hour The hour component (0-23) to be used for the + * calculation. + * @param minute The minute component (0, 5, 10, ..., 55) to be used + * for the + * @param hyperParameters is the object of class HyperParameters, calculation. + * @return The index representing the specified hour and minute combination. + */ + public static Integer getIndex(Integer hour, Integer minute, HyperParameters hyperParameters) { + var k = 0; + for (var i = 0; i < 24; i++) { + for (var j = 0; j < (int) 60 / hyperParameters.getInterval(); j++) { + var h = i; + var m = j * hyperParameters.getInterval(); + if (hour == h && minute == m) { + return k; + } else { + k = k + 1; + } + } + } + return k; + } + + /** + * Predict output values based on input data and a list of model parameters for + * multiple instances. This method takes a list of input data instances and a + * list of model parameters and predicts output values for each instance using + * the model. + * + * @param inputData An ArrayList of ArrayLists of Doubles, where each + * inner ArrayList represents input data for one + * instance. + * @param val An ArrayList of ArrayLists of ArrayLists of Doubles + * representing the model parameters for each instance. 
+ * Each innermost ArrayList should contain model + * parameters in the following order: 0: Input weight + * vector (wi) 1: Output weight vector (wo) 2: Recurrent + * weight vector (wz) 3: Recurrent input activations (rI) + * 4: Recurrent output activations (rO) 5: Recurrent + * update activations (rZ) 6: Current cell state (ct) 7: + * Current output (yt) + * @param hyperParameters instance of class HyperParamters data + * @return An ArrayList of Double values representing the predicted output for + * each input data instance. + */ + public static ArrayList predictPre(ArrayList> inputData, + ArrayList>> val, HyperParameters hyperParameters) { + + var result = new ArrayList(); + for (var i = 0; i < inputData.size(); i++) { + + var wi = val.get(i).get(0); + var wo = val.get(i).get(1); + var wz = val.get(i).get(2); + var rI = val.get(i).get(3); + var rO = val.get(i).get(4); + var rZ = val.get(i).get(5); + var ct = val.get(i).get(7); + var yt = val.get(i).get(6); + + result.add(predict(inputData.get(i), wi, wo, wz, rI, rO, rZ, ct, yt, hyperParameters)); + } + return result; + } + + /** + * Predict the output values based on input data and model parameters. This + * method takes input data and a set of model parameters and predicts output + * values for each data point using the model. + * + * @param data A 2D array representing the input data where each row + * is a data point. + * @param val An ArrayList containing model parameters, including + * weight vectors and activation values. 
The ArrayList + * should contain the following sublists in this order: + * 0: Input weight vector (wi) 1: Output weight vector + * (wo) 2: Recurrent weight vector (wz) 3: Recurrent + * input activations (rI) 4: Recurrent output activations + * (rO) 5: Recurrent update activations (rZ) 6: Current + * output (yt) 7: Current cell state (ct) + * + * @param hyperParameters instance of class HyperParamters data + * + * @return An ArrayList of Double values representing the predicted output for + * each input data point. + * + */ + public static ArrayList predictPre(double[][] data, List> val, + HyperParameters hyperParameters) { + + var result = new ArrayList(); + + var wi = val.get(0); + var wo = val.get(1); + var wz = val.get(2); + var rI = val.get(3); + var rO = val.get(4); + var rZ = val.get(5); + var yt = val.get(6); + var ct = val.get(7); + + for (var i = 0; i < data.length; i++) { + result.add(predict(data[i], wi, wo, wz, rI, rO, rZ, yt, ct, hyperParameters)); + } + return result; + } + + /** + * Predict an output value based on input data and model parameters. This method + * predicts a single output value based on input data and a set of model + * parameters for a LSTM model. + * + * @param inputData An ArrayList of Doubles representing the input data + * for prediction. + * @param wi An ArrayList of Doubles representing the input weight + * vector (wi) for the RNN model. + * @param wo An ArrayList of Doubles representing the output weight + * vector (wo) for the RNN model. + * @param wz An ArrayList of Doubles representing the recurrent + * weight vector (wz) for the RNN model. + * @param rI An ArrayList of Doubles representing the recurrent + * input activations (rI) for the RNN model. + * @param rO An ArrayList of Doubles representing the recurrent + * output activations (rO) for the RNN model. + * @param rZ An ArrayList of Doubles representing the recurrent + * update activations (rZ) for the RNN model. 
+ * @param cta An ArrayList of Doubles representing the current cell + * state (ct) for the RNN model. + * @param yta An ArrayList of Doubles representing the current + * output (yt) for the RNN model. + * @param hyperParameters instance of class HyperParamters data + * @return A double representing the predicted output value based on the input + * data and model parameters. + */ + public static double predict(ArrayList inputData, ArrayList wi, ArrayList wo, + ArrayList wz, ArrayList rI, ArrayList rO, ArrayList rZ, + ArrayList cta, ArrayList yta, HyperParameters hyperParameters) { + var ct = hyperParameters.getCtInit(); + var yt = hyperParameters.getYtInit(); + var standData = inputData;// DataModification.standardize(inputData, hyperParameters); + + for (var i = 0; i < standData.size(); i++) { + var ctMinusOne = ct; + var yTMinusOne = yt; + var xt = standData.get(i); + var it = MathUtils.sigmoid(wi.get(i) * xt + rI.get(i) * yTMinusOne); + var ot = MathUtils.sigmoid(wo.get(i) * xt + rO.get(i) * yTMinusOne); + var zt = MathUtils.tanh(wz.get(i) * xt + rZ.get(i) * yTMinusOne); + ct = ctMinusOne + it * zt; + yt = ot * MathUtils.tanh(ct); + } + return yt; + } + + /** + * Predict an output value based on input data and model parameters. This method + * predicts a single output value based on input data and a set of model + * parameters for a LSTM model. + * + * @param inputData An ArrayList of Doubles representing the input data + * for prediction. + * @param wi An ArrayList of Doubles representing the input weight + * vector (wi) for the RNN model. + * @param wo An ArrayList of Doubles representing the output weight + * vector (wo) for the RNN model. + * @param wz An ArrayList of Doubles representing the recurrent + * weight vector (wz) for the RNN model. + * @param rI An ArrayList of Doubles representing the recurrent + * input activations (rI) for the RNN model. 
+ * @param rO An ArrayList of Doubles representing the recurrent + * output activations (rO) for the RNN model. + * @param rZ An ArrayList of Doubles representing the recurrent + * update activations (rZ) for the RNN model. + * @param cta An ArrayList of Doubles representing the current cell + * state (ct) for the RNN model. + * @param yta An ArrayList of Doubles representing the current + * output (yt) for the RNN model. + * @param hyperParameters instance of class HyperParamters data + * @return A double representing the predicted output value based on the input + * data and model parameters. + */ + public static double predict(double[] inputData, ArrayList wi, ArrayList wo, ArrayList wz, + ArrayList rI, ArrayList rO, ArrayList rZ, ArrayList cta, + ArrayList yta, HyperParameters hyperParameters) { + var ct = hyperParameters.getCtInit(); + var yt = hyperParameters.getYtInit(); + var standData = inputData;// DataModification.standardize(inputData, hyperParameters); + + for (var i = 0; i < standData.length; i++) { + var ctMinusOne = ct; + var yTMinusOne = yt; + var xt = standData.length; + var it = MathUtils.sigmoid(wi.get(i) * xt + rI.get(i) * yTMinusOne); + var ot = MathUtils.sigmoid(wo.get(i) * xt + rO.get(i) * yTMinusOne); + var zt = MathUtils.tanh(wz.get(i) * xt + rZ.get(i) * yTMinusOne); + ct = ctMinusOne + it * zt; + yt = ot * MathUtils.tanh(ct); + } + return yt; + } + + /** + * Predict a focused output value based on input data and model parameters. This + * method predicts a single focused output value based on input data and a set + * of model parameters for a LSTM model with a focus on specific activations. + * + * @param inputData An ArrayList of Doubles representing the input data + * for prediction. + * @param wi An ArrayList of Doubles representing the input weight + * vector (wi) for the RNN model. + * @param wo An ArrayList of Doubles representing the output weight + * vector (wo) for the RNN model. 
+ * @param wz An ArrayList of Doubles representing the recurrent + * weight vector (wz) for the RNN model. + * @param rI An ArrayList of Doubles representing the recurrent + * input activations (rI) for the RNN model. + * @param rO An ArrayList of Doubles representing the recurrent + * output activations (rO) for the RNN model. + * @param rZ An ArrayList of Doubles representing the recurrent + * update activations (rZ) for the RNN model. + * @param cta An ArrayList of Doubles representing the current cell + * state (ct) for the RNN model. + * @param yta An ArrayList of Doubles representing the current + * output (yt) for the RNN model. + * @param hyperParameters instance of class HyperParamters data + * @return A double representing the predicted focused output value based on the + * input data and model parameters. + */ + public static double predictFocoused(ArrayList inputData, ArrayList wi, ArrayList wo, + ArrayList wz, ArrayList rI, ArrayList rO, ArrayList rZ, + ArrayList cta, ArrayList yta, HyperParameters hyperParameters) { + var ct = hyperParameters.getCtInit(); + var yt = hyperParameters.getYtInit(); + + var standData = inputData; + + for (var i = 0; i < standData.size(); i++) { + var ctMinusOne = ct; + var ytMinusOne = yt; + var xt = standData.get(i); + var it = MathUtils.sigmoid(rI.get(i) * ytMinusOne); + var ot = MathUtils.sigmoid(rO.get(i) * ytMinusOne); + var zt = MathUtils.tanh(wz.get(i) * xt); + ct = ctMinusOne + it * zt; + yt = ot * MathUtils.tanh(ct); + } + return yt; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DataStatistics.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DataStatistics.java new file mode 100644 index 00000000000..a7a3f014f16 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DataStatistics.java @@ -0,0 +1,145 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import 
java.util.Arrays; +import java.util.Collection; +import java.util.stream.IntStream; + +public class DataStatistics { + + /** + * Get the mean of the array. + * + * @param data the data + * @return mean value + */ + public static double getMean(Collection data) { + return data.stream().mapToDouble(Number::doubleValue).average().orElse(0.0); + } + + /** + * Calculates the mean (average) of each row in a 2D array of doubles and + * returns an ArrayList containing the means of each row. + * + * @param data a 2D array of doubles containing the data from which to calculate + * means + * @return an ArrayList of Double containing the means of each row + */ + public static double[] getMean(double[][] data) { + return Arrays.stream(data).mapToDouble(row -> Arrays.stream(row).average().orElse(0.0)).toArray(); + } + + /** + * Computes the mean (average) of an array of double values. * + *

+ * This method calculates the mean by summing all the elements in the input + * array and dividing by the number of elements. If the array is empty, it + * throws a NoSuchElementException. + *

+ * + * @param data the array of double values for which the mean is to be computed + * @return the mean of the input array + * @throws java.util.NoSuchElementException if the array is empty + */ + public static double getMean(double[] data) { + return Arrays.stream(data).parallel().average().getAsDouble(); + } + + /** + * Calculates the standard deviation of a list of double values. This method + * computes the standard deviation of the provided list of double values. + * Standard deviation measures the amount of variation or dispersion in the + * data. It is calculated as the square root of the variance, which is the + * average of the squared differences between each data point and the mean. When + * stander deviation is 0, the method returns a value close to zero to avoid + * divisible by 0 error + * + * @param data An ArrayList of double values for which to calculate the standard + * deviation. + * @return The standard deviation of the provided data as a double value. + * @throws IllegalArgumentException if the input list is empty. + */ + public static double getStandardDeviation(Collection data) { + double mean = getMean(data); + double sumSquaredDeviations = data.stream().mapToDouble(x -> Math.pow(x.doubleValue() - mean, 2)).sum(); + double variance = sumSquaredDeviations / data.size(); + double stdDeviation = Math.sqrt(variance); + return (stdDeviation == 0) ? 0.000000000000001 : stdDeviation; + } + + /** + * * calculates the deviation of the data from the expected error. THis method + * computes the average deviation from the expected error. 
+ * + * @param data the data of type numbers + * @param expectedError the expected error + * @return stdDeviation the standard deviation + */ + public static double getStandardDeviation(Collection data, double expectedError) { + double mean = expectedError; + double sumSquaredDeviations = data.stream()// + .mapToDouble(x -> Math.pow(x.doubleValue() - mean, 2))// + .sum(); + double variance = sumSquaredDeviations / data.size(); + double stdDeviation = Math.sqrt(variance); + return (stdDeviation == 0) ? 0.000000000000001 : stdDeviation; + } + + /** + * Computes the standard deviation of an array of double values. + * + *

+ * This method calculates the mean of the input array, then computes the + * variance by finding the average of the squared differences from the mean. + * Finally, it returns the square root of the variance as the standard + * deviation. If the standard deviation is zero, a very small positive number + * (1e-15) is returned to avoid returning zero. + *

+ * + * @param data the array of double values for which the standard deviation is to + * be computed + * @return the standard deviation of the input array + */ + + public static double getStandardDeviation(double[] data) { + double mean = Arrays.stream(data).average().getAsDouble(); + double sumSquaredDeviations = Arrays.stream(data).map(x -> Math.pow(x - mean, 2)).sum(); + double variance = sumSquaredDeviations / data.length; + double stdDeviation = Math.sqrt(variance); + return (stdDeviation == 0) ? 0.000000000000001 : stdDeviation; + } + + /** + * Calculates the standard deviation of each row in a 2D array of doubles and + * returns an ArrayList containing the standard deviations of each row. + * + * @param data a 2D array of doubles containing the data from which to calculate + * standard deviations + * @return an ArrayList of Double containing the standard deviations of each row + */ + public static double[] getStandardDeviation(double[][] data) { + return Arrays.stream(data)// + .mapToDouble(row -> getStandardDeviation(row))// + .toArray(); + } + + /** + * Computes the root mean square (RMS) error between two arrays of double + * values. 
+ * + * @param original the original array of double values + * @param computed the computed array of double values + * @return the RMS error between the original and computed arrays + * @throws IllegalArgumentException if the arrays have different lengths + */ + public static double computeRms(double[] original, double[] computed) { + if (original.length != computed.length) { + throw new IllegalArgumentException("Arrays must have the same length"); + } + + var sumOfSquaredDifferences = IntStream.range(0, original.length) + .mapToDouble(i -> Math.pow(original[i] - computed[i], 2))// + .average(); + + return Math.sqrt(sumOfSquaredDifferences.getAsDouble()); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DynamicItterationValue.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DynamicItterationValue.java new file mode 100644 index 00000000000..7876753043b --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/DynamicItterationValue.java @@ -0,0 +1,25 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import java.util.ArrayList; +import java.util.Collections; + +public class DynamicItterationValue { + + public static int setIteration(ArrayList errors, int errorIndex, HyperParameters hyperParameters) { + + if (errors.isEmpty()) { + return 10; + } + + var minError = Collections.min(errors); + var maxError = Collections.max(errors); + var minIteration = 1; + var maxIteration = 10 * hyperParameters.getEpochTrack() + 1; + + var errorValue = errors.get(errorIndex); + var normalizedError = (errorValue - minError) / (maxError - minError); + var iterationValue = minIteration + (normalizedError * (maxIteration - minIteration)); + + return (int) Math.round(iterationValue); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/HyperParameters.java 
b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/HyperParameters.java new file mode 100644 index 00000000000..1ae7a819e3d --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/HyperParameters.java @@ -0,0 +1,932 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import java.io.Serializable; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Collections; + +public class HyperParameters implements Serializable { + + private OffsetDateTime lastTrainedDate; + + public OffsetDateTime getLastTrainedDate() { + return this.lastTrainedDate; + } + + /** + * Serializable class version number for ensuring compatibility during + * serialization. + */ + private static final long serialVersionUID = 1L; + + /** + * Maximum iteration factor. + * + *

+ * This value is used by DynamicItterationValue class to set the + * gdItterationValue DynamicItterationValue class changes the classes such that + * the gdItteration value is in between 1 and maxItterFactor*current Epoch value + * +1 When epoch increase, learning rate decreases and the gdItteration value + * increases. Set the value always to 10. + *

+ */ + private final int maxItterFactor = 10; + + /** + * Upper limit for the learning rate. + * + *

+ * This value is used by the ADGRAD optimizer as the initial learning rate. The + * optimizer dynamically adjusts the learning rate over epochs, starting with + * the value of learningRateUpperLimit. The adjustment is typically aimed at + * improving convergence by starting with a higher learning rate and gradually + * decreasing it. + * + *

+ * This variable can be set to any value between 0 and 1. It is important to + * ensure that the value of learningRateUpperLimit is always greater than + * learnignRateLowerLimit to allow proper functioning of the dynamic learning + * rate setup. Default value: 0.01 + */ + private double learningRateUpperLimit = 0.01; + + /** + * Lower limit for the learning rate. + * + *

+ * This value is used by the ADGRAD optimizer as the minimum learning rate. As + * the training progresses, the optimizer adjusts the learning rate and it + * converges to the value of learnignRateLowerLimit by the final epoch. This + * helps in fine-tuning the model parameters and achieving better accuracy by + * the end of the training. + *

+ * This variable can be set to any value between 0 and 1. It is crucial that the + * value of learnignRateLowerLimit is always less than learningRateUpperLimit to + * enable the proper decreasing trend of the learning rate throughout the + * training process. Default value: 0.0001 + */ + private double learnignRateLowerLimit = 0.0001; + + /** + * Proportion of data to be used for training. + * + *

+ * This variable determines the fraction of the entire dataset that will be + * allocated for training purposes. The remaining portion of the dataset will be + * used for validation. The value of this variable should be within the range of + * 0 to 1, where: + *

    + *
  • 0 means 0% of the dataset is used for training (i.e., no training + * data).
  • + *
  • 1 means 100% of the dataset is used for training (i.e., no validation + * data).
  • + *
+ *

+ * The program utilizes this variable to split the input dataset vector into two + * separate vectors. One vector contains the training data, and the other vector + * contains the validation data. The split is essential for assessing the + * performance of the model on unseen data, helping to prevent overfitting and + * to ensure the model's generalizability. + */ + private double dataSplitTrain = 0.7; + + /** + * Proportion of data to be used for validation. + */ + private double dataSplitValidate = 1 - this.dataSplitTrain; + + private double wiInit = 0.2; + private double woInit = 0.2; + private double wzInit = 0.2; + private double riInit = 0.2; + private double roInit = 0.2; + private double rzInit = 0.2; + private double ytInit = 0.2; + private double ctInit = 0.2; + + /** + * Interval for logging or updating parameters. + */ + private int interval = 5; + + /** + * Size of each batch for training. + * + *

+ * To manage the computational load on the CPU during training, the training + * data is divided into smaller subsets called batches. + *

+ * + *

+ * For our LSTM (Long Short-Term Memory) model, a general rule of thumb is that + * datasets consisting of 30 days of data with 5-minute intervals should not be + * divided into batches greater than 2. This helps to balance the computational + * load and the memory usage during training. + *

+ * + *

+ * Considerations for setting the batch size: + *

+ *
    + *
  • If the training data size is large, more batches should be created to + * avoid excessive memory usage, which could lead to heap memory errors.
  • + *
  • If the training data size is small, fewer batches should be created to + * ensure each batch contains a sufficient number of samples for meaningful + * updates. Creating too many batches with too few samples can lead to index out + * of range errors during training.
  • + *
+ */ + private int batchSize = 10; + + /** + * Counter for tracking batches. + * + *

+ * This counter keeps track of the number of batches that have passed through + * the training process. + *

+ * - It updates after each batch completes its training. - In case the training + * is interrupted, this counter allows the process to resume from the last + * completed batch, ensuring continuity and efficiency in the training process. + * + *

+ * This mechanism is crucial for maintaining the state of the training process, + * especially in scenarios where interruptions may occur. + *

+ */ + private int batchTrack = 0; + + /** + * Number of epochs for training. + * + *

+ * An epoch refers to one complete pass through the entire training dataset. + * During each epoch, the model processes all the training data in batches, + * updating the model parameters iteratively. After each epoch, the learning + * rate can be adjusted, and the training process continues on the same dataset. + *

+ * + *

+ * The number of epochs is a crucial hyperparameter in training neural networks. + * More epochs generally mean that the model has more opportunities to learn + * from the data, potentially improving its performance. However, more epochs + * also mean longer training times and a higher risk of overfitting, where the + * model learns the training data too well and performs poorly on new, unseen + * data. + *

+ * + *

+ * It is recommended to keep the number of epochs in the range of 30 to 50 for a + * balanced approach between training time and model performance. Adjusting the + * number of epochs can be necessary based on the specific characteristics of + * the dataset and the complexity of the model. + *

+ */ + private int epoch = 10; + + /** + * Counter for tracking epochs. The counter updates after every time all batches + * undergoes training. This value is searilized along with the weights. in case + * training stops, this record is used to resme the training from the last stop + * point. + */ + private int epochTrack = 0; + + /** + * Number of predictions using trend weights. + * + *

+ * This parameter determines the number of predictions made based on the trend + * weights derived from the most recent trend window data. The trend window is a + * specific period used to analyze the trend patterns of the data. + * + *

+ * + *

+ * + * By default, one prediction is made using the last trend window data if this + * value is set to 1. This means that the system will use the data from the last + * trend window to make a single prediction. + * + *

+ * It is advisable to set this value to 12 if the interval between data points + * is 5 minutes Similarly, set this value to 8 if the interval between data + * points is 15 minutes . The interval represents the time or sequence gap + * between consecutive data points being analyzed. + * + *

+ * Setting a higher value than recommended can lead to inaccuracies in the + * prediction. This is because too many trend points may cause the model + * misinterpret the trend patterns, resulting in errors. + *

+ */ + private int trendPoints = 12; + + /** + * Window size for analyzing seasonality. + * + *

+ * This parameter defines the window size used for analyzing seasonal patterns + * in the data. A window size of 7 means that the model will use data from the + * last 7 days to train at one instance. Additionally, it will utilize the data + * from the last 7 days to predict data points for the next 24 hours. + *

+ * + *

+ * The window size can be adjusted up to a maximum of 14. While increasing the + * window size can potentially provide more accurate seasonal insights, it also + * increases the computational load. + *

+ * + *

+ * Key points: - Set to 7 to use the last 7 days of data for training and for + * predicting the next 24 hours. - The value can be adjusted up to 14. - Be + * aware that higher values may be computationally intensive. + *

+ */ + private int windowSizeSeasonality = 7; + /** + * Window size for analyzing trend. + * + *

+ * This parameter specifies the window size used for analyzing trend patterns in + * the data. A window size of 5 means that the model will consider data from the + * last 5 time intervals to analyze the trend. This helps in identifying the + * direction and strength of the trend over recent time periods. Keep the value + * in between 5 to 7 + *

+ */ + private int windowSizeTrend = 5; + + /** + * Number of iterations for gradient descent. + * + *

+ * This parameter defines the number of iterations to be performed during the + * gradient descent optimization process. Gradient descent is used to minimize + * the cost function by iteratively updating the model parameters. + *

+ * + *

+ * The number of iterations can be set between 1 and 100. A higher number of + * iterations can potentially lead to models with improved accuracy as the + * optimization process has more opportunities to converge to a minimum. + * However, increasing the number of iterations also increases the computation + * time required for training the model. + *

+ * + *

+ * Key points: - Set to 10 to perform 10 iterations of gradient descent. - Can + * be adjusted between 1 and 100 based on the trade-off between accuracy and + * computation time. - Higher values may improve model accuracy but will also + * increase computation time. + *

+ */ + private int gdIterration = 10; + + /** + * Counter for general tracking purposes. + * + *

+ * This counter is used to determine whether the training process is being + * executed for the first time. + *

+ * + *

+ * - If the count is 0, the algorithm will use the initial weights and start a + * new training process. - If the count value is greater than 0, the algorithm + * will continue training the existing models. + *

+ * + *

+ * This mechanism ensures that the model can distinguish between initializing + * new training sessions and performing subsequent training iterations. + *

+ * + *

+ * Note: Just like in programming, remember that if you start counting from 0, + * you're a true computer scientist! + *

+ */ + private int count = 0; + + /** + * Threshold error value. + * + *

+ * This value represents the threshold error, typically measured in the same + * units as the training data. It can also be considered as the allowed error + * margin. The Root Mean Square (RMS) error computed during the model evaluation + * reflects the average deviation from this threshold value. + *

+ * + *

+ * Key points: - Measured in the same units as the training data. - Represents + * the acceptable error margin. - RMS error indicates the average deviation from + * this threshold. + *

+ */ + + private double targetError = 0; + + /** + * Minimum value for scaling data. + * + *

+ * This value defines the minimum threshold for scaling the data. It should + * always be less than the `scalingMax` value. The unit of this value is the + * same as that of the training data. This value can be either negative or + * positive. + *

+ * + *

+ * Once set, it is important not to change this value, as it could affect the + * consistency of the scaling process. + *

+ * + *

+ * Note: value once set should not be changed, as changing it is as risky as + * debugging a program on a Friday afternoon! + *

+ */ + private double scalingMin = 0; + + /** + * Maximum value for scaling data. + * + *

+ * This value defines the maximum threshold for scaling the data. It should + * always be greater than the `scalingMin` value. The unit of this value is the + * same as that of the training data. This value can be positive or negative, + * depending on the data range. + *

+ * + *

+ * Once set, it is important not to change this value, as it could affect the + * consistency of the scaling process. + *

+ * + *

+ * Note: Setting this value high is like aiming for the stars with your data! + * Just remember, changing it later could be as risky as giving a programmer, a + * cup of coffee after midnight! + *

+ * + */ + private double scalingMax = 20000; + + /** + * Model data structure for trend analysis. + * + *

+ * This is the brain of the model, responsible for storing updated weights and + * biases during the training process for trend analysis. + *

+ * + *

+ * The structure comprises nested arrays to store weights and biases: + *

+ * + *
+	 * [ [wi1,wi2, wi3, ..., wik], 
+	 *   [wo1, wo2, wo3, ..., wok], 
+	 *   [wz1, wz2, wz3, ..., wzk],
+	 *   [Ri1, Ri2, Ri3, ..., Rik], 
+	 *   [Ro1, Ro2, Ro3, ..., Rok], 
+	 *   [Rz1, Rz2, Rz3, ..., Rzk], 
+	 *   [Yt1, Yt2, Yt3, ..., Ytk], 
+	 *   [Ct1, Ct2, Ct3, ..., Ctk] ]
+	 * 
+	 * 
+ * + *

+ * Where Wi, Wo, Wz, Ri, Ro, Rz, Yt, and Ct are the weights and biases of the + * LSTM cells, and 1, 2, 3, ..., k represent the window size. + *

+ * + *

+ * The first two nested arrays ensure that the second nested array is available + * for every time depending on the interval. first element of second nested + * array is used for the prediction of the trend point for 00:05 (if the + * interval is 5) + *

+ * + *

+ * Fun Fact: This data structure holds the keys to predicting trends better than + * a psychic octopus predicting World Cup winners! + *

+ */ + private ArrayList>>> modelTrend = new ArrayList>>>(); + + /** + * Model data structure for seasonality analysis. + * + *

+ * This data structure serves as the backbone of the model, specifically + * designed to store updated weights and biases during the training process for + * seasonality analysis. + *

+ * + *

+ * The structure consists of nested ArrayLists to accommodate the weights and + * biases. + *

+ * + *
+	 * [ [wi1,wi2, wi3, ..., wik], 
+	 *   [wo1, wo2, wo3, ..., wok], 
+	 *   [wz1, wz2, wz3, ..., wzk],
+	 *   [Ri1, Ri2, Ri3, ..., Rik], 
+	 *   [Ro1, Ro2, Ro3, ..., Rok], 
+	 *   [Rz1, Rz2, Rz3, ..., Rzk], 
+	 *   [Yt1, Yt2, Yt3, ..., Ytk], 
+	 *   [Ct1, Ct2, Ct3, ..., Ctk] ]
+	 * 
+	 * 
+ * + *

+ * Where Wi, Wo, Wz are the weights of the LSTM cells, and 1, 2, 3, ..., k + * represent the window size. + *

+ * + *

+ * The first two nested arrays ensure that the second nested array is available + * for every time depending on the interval. first element of second nested + * array is used for the prediction of the trend point for 00:00 (if the + * interval is 5) + *

+ * + *

+ * Fun Fact: With this data structure, our model can predict seasonal pattern + * more accurately than a fortune-teller! + *

+ */ + private ArrayList>>> modelSeasonality = new ArrayList>>>(); + + /** + * List of all model errors related to trend analysis. + * + *

+ * This vector holds the Root Mean Square (RMS) errors of different models + * recorded during multiple training steps in modelTrend. + * + *

+ * + *

+ * Fun Fact: These errors are like the turn signals on a BMW - sometimes they're + * there, sometimes they're not, but they always keep us guessing and learning + * along the way! + *

+ */ + private ArrayList allModelErrorTrend = new ArrayList(); + + /** + * List of all model errors related to seasonality analysis. + * + *

+ * This vector contains the Root Mean Square (RMS) errors of different models + * recorded during multiple training steps in modelSeasonality. + *

+ * + *

+ * Fun Fact: These errors are like the various recipes for currywurst - some may + * be a bit spicier than others, but they all add flavor to our models, just + * like currywurst adds flavor to German cuisine! + *

+ */ + private ArrayList allModelErrorSeasonality = new ArrayList(); + + /** + * Mean value for normalization or scaling purposes. + * + *

+ * This value is crucial for ensuring proper normalization or scaling of the + * data. It acts as the central point around which the data is normalized or + * scaled. + * + *

+ * + *

+ * It's important to keep this value set to 0 so that the normalization stays + * centered consistently across training runs. + *

+ */ + private double mean = 0; + + /** + * Standard deviation for normalization or scaling purposes. + * + *

+ * This value plays a crucial role in determining the spread or dispersion of + * the data during normalization or scaling. + *

+ */ + private double standerDeviation = 1; + + /** + * Root Mean Square Error (RMSE) for trend analysis. + * + *

+ * This list contains RMSE values for trend analysis. Unlike + * 'allModelErrorTrend', this list is limited in size to accommodate 60 divided + * by the interval multiplied by 24, and each value represents the RMSE of the + * model predicting for a specific time interval. + * + *

+ * The error at index 0 corresponds to the model predicting for 00:05, with + * subsequent indices representing subsequent time intervals. + */ + private ArrayList rmsErrorTrend = new ArrayList(); + + /** + * Root Mean Square Error (RMSE) for seasonality analysis. + * + *

+ * This list contains RMSE values for seasonality analysis. Each value + * represents the RMSE of the model's predictions related to seasonality. + *

+ */ + private ArrayList rmsErrorSeasonality = new ArrayList(); + + /** + * Counter for outer loop iterations, possibly for nested loops. Note: only used + * in unit test case + */ + private int outerLoopCount = 0; + + /** + * Name of the model. + */ + private String modelName = ""; + + public HyperParameters() { + } + + public void setLearningRateUpperLimit(double rate) { + this.learningRateUpperLimit = rate; + } + + public double getLearningRateUpperLimit() { + return this.learningRateUpperLimit; + } + + public void setLearningRateLowerLimit(double val) { + this.learnignRateLowerLimit = val; + } + + public double getLearningRateLowerLimit() { + return this.learnignRateLowerLimit; + } + + public void setWiInit(double val) { + this.wiInit = val; + } + + public double getWiInit() { + return this.wiInit; + } + + public void setWoInit(double val) { + this.woInit = val; + } + + public double getWoInit() { + return this.woInit; + } + + public void setWzInit(double val) { + this.wzInit = val; + } + + public double getWzInit() { + return this.wzInit; + } + + public void setriInit(double rate) { + this.riInit = rate; + } + + public double getRiInit() { + return this.riInit; + } + + public void setRoInit(double val) { + this.roInit = val; + } + + public double getRoInit() { + return this.roInit; + } + + public void setRzInit(double val) { + this.rzInit = val; + } + + public double getRzInit() { + return this.rzInit; + } + + public void setYtInit(double val) { + this.ytInit = val; + } + + public double getYtInit() { + return this.ytInit; + } + + public void setCtInit(double val) { + this.ctInit = val; + } + + public double getCtInit() { + return this.ctInit; + } + + public int getWindowSizeSeasonality() { + return this.windowSizeSeasonality; + } + + public int getGdIterration() { + return this.gdIterration; + } + + public void setGdIterration(int val) { + this.gdIterration = val; + } + + public int getWindowSizeTrend() { + return this.windowSizeTrend; + } + + public double 
getScalingMin() { + return this.scalingMin; + } + + public double getScalingMax() { + return this.scalingMax; + } + + public void setCount(int val) { + this.count = val; + } + + public int getCount() { + return this.count; + } + + public void setDatasplitTrain(double val) { + this.dataSplitTrain = val; + } + + public double getDataSplitTrain() { + return this.dataSplitTrain; + } + + public void setDatasplitValidate(double val) { + this.dataSplitValidate = val; + } + + public double getDataSplitValidate() { + return this.dataSplitValidate; + } + + public int getTrendPoint() { + return this.trendPoints; + } + + public int getEpoch() { + + return this.epoch; + } + + public int getInterval() { + return this.interval; + } + + public void setRmsErrorTrend(double val) { + this.rmsErrorTrend.add(val); + } + + public void setRmsErrorSeasonality(double val) { + this.rmsErrorSeasonality.add(val); + } + + public ArrayList getRmsErrorSeasonality() { + return this.rmsErrorSeasonality; + } + + public ArrayList getRmsErrorTrend() { + return this.rmsErrorTrend; + } + + public void setEpochTrack(int val) { + this.epochTrack = val; + } + + public int getEpochTrack() { + return this.epochTrack; + } + + public int getMinimumErrorModelSeasonality() { + return this.rmsErrorSeasonality.indexOf(Collections.min(this.rmsErrorSeasonality)); + } + + public int getMinimumErrorModelTrend() { + return this.rmsErrorTrend.indexOf(Collections.min(this.rmsErrorTrend)); + } + + public int getOuterLoopCount() { + return this.outerLoopCount; + } + + public void setOuterLoopCount(int val) { + this.outerLoopCount = val; + } + + public int getBatchSize() { + return this.batchSize; + } + + public int getBatchTrack() { + return this.batchTrack; + } + + public void setBatchTrack(int val) { + this.batchTrack = val; + } + + public void setModelName(String val) { + this.modelName = val; + } + + public String getModelName() { + return this.modelName; + } + + public double getMean() { + return this.mean; + + } + + 
public double getStanderDeviation() { + return this.standerDeviation; + } + + public double getTargetError() { + return this.targetError; + } + + public void setTargetError(double val) { + this.targetError = val; + } + + public int getMaxItter() { + return this.maxItterFactor; + } + + /** + * Updates the model trend with new values. + * + * @param val ArrayList of ArrayLists of ArrayLists of Double containing the new + * values to add to the model trend + */ + public void updatModelTrend(ArrayList>> val) { + this.modelTrend.add(val); + } + + /** + * Retrieves the most recently recorded model trend from the list of model + * trends. + * + * @return The most recently recorded model trend, represented as an ArrayList + * of ArrayLists of ArrayLists of Double. + */ + public ArrayList>> getlastModelTrend() { + return this.modelTrend.get(this.modelTrend.size() - 1); + } + + public ArrayList>> getBestModelTrend() { + return this.modelTrend.get(this.getMinimumErrorModelTrend()); + } + + public ArrayList>> getBestModelSeasonality() { + return this.modelSeasonality.get(this.getMinimumErrorModelSeasonality()); + } + + public ArrayList>>> getAllModelsTrend() { + return this.modelTrend; + } + + public ArrayList>>> getAllModelSeasonality() { + return this.modelSeasonality; + } + + public void setAllModelErrorTrend(ArrayList val) { + this.allModelErrorTrend = val; + } + + public void setAllModelErrorSeason(ArrayList val) { + this.allModelErrorSeasonality = val; + } + + public ArrayList getAllModelErrorTrend() { + return this.allModelErrorTrend; + } + + public ArrayList getAllModelErrorSeason() { + return this.allModelErrorSeasonality; + } + + /** + * Retrieves the last model trend from the list of model trends. + * + * @return ArrayList of ArrayLists of ArrayLists of Double representing the last + * model trend + */ + public ArrayList>> getlastModelSeasonality() { + return this.modelSeasonality.get(this.modelSeasonality.size() - 1); + } + + /** + * reset the error in the model. 
+ */ + public void resetModelErrorValue() { + this.rmsErrorSeasonality = new ArrayList(); + this.rmsErrorTrend = new ArrayList(); + } + + /** + * Updates the model seasonality with new values. + * + * @param val The new model seasonality values to add, represented as an + * ArrayList of ArrayLists of ArrayLists of Double. + */ + public void updateModelSeasonality(ArrayList>> val) { + this.modelSeasonality.add(val); + } + + /** + * Prints the current values of hyperparameters and related attributes to the + * console. + */ + public void printHyperParameters() { + StringBuilder builder = new StringBuilder(); + + builder.append("learningRateUpperLimit = ").append(this.learningRateUpperLimit).append("\n"); + builder.append("learnignRateLowerLimit = ").append(this.learnignRateLowerLimit).append("\n"); + builder.append("wiInit = ").append(this.wiInit).append("\n"); + builder.append("woInit = ").append(this.woInit).append("\n"); + builder.append("wzInit = ").append(this.wzInit).append("\n"); + builder.append("riInit = ").append(this.riInit).append("\n"); + builder.append("roInit = ").append(this.roInit).append("\n"); + builder.append("rzInit = ").append(this.rzInit).append("\n"); + builder.append("ytInit = ").append(this.ytInit).append("\n"); + builder.append("ctInit = ").append(this.ctInit).append("\n"); + builder.append("Epoch = ").append(this.epoch).append("\n"); + builder.append("windowSizeSeasonality = ").append(this.windowSizeSeasonality).append("\n"); + builder.append("windowSizeTrend = ").append(this.windowSizeTrend).append("\n"); + builder.append("scalingMin = ").append(this.scalingMin).append("\n"); + builder.append("scalingMax = ").append(this.scalingMax).append("\n"); + builder.append("RMS error trend = ").append(this.getRmsErrorTrend()).append("\n"); + builder.append("RMS error Seasonlality =").append(this.getRmsErrorSeasonality()).append("\n"); + builder.append("Count value = ").append(this.count).append("\n"); + builder.append("Outer loop Count = 
").append(this.outerLoopCount).append("\n"); + builder.append("Epoch track = ").append(this.epochTrack).append("\n"); + + System.out.println(builder.toString()); + } + + /** + * Updates the models and their corresponding error indices based on the minimum + * error values obtained from model trends and model seasonality. This method + * first retrieves the indices of models with minimum errors for both trends and + * seasonality. Then it retrieves the corresponding models and clears the + * existing model trends, model seasonality, RMS errors for trend, and RMS + * errors for seasonality. After that, it adds the retrieved models to the + * respective model lists and updates the RMS errors with the minimum error + * values. + */ + public void update() { + int minErrorIndTrend = this.getMinimumErrorModelTrend(); + int minErrorIndSeasonlity = this.getMinimumErrorModelSeasonality(); + + // uipdating models + var modelTrendTemp = this.modelTrend.get(minErrorIndTrend); + final var modelTempSeasonality = this.modelSeasonality.get(minErrorIndSeasonlity); + this.modelTrend.clear(); + this.modelSeasonality.clear(); + this.modelTrend.add(modelTrendTemp); + this.modelSeasonality.add(modelTempSeasonality); + // updating index + double minErrorTrend = this.rmsErrorTrend.get(minErrorIndTrend); + final double minErrorSeasonality = this.rmsErrorSeasonality.get(minErrorIndSeasonlity); + this.rmsErrorTrend.clear(); + this.rmsErrorSeasonality.clear(); + this.rmsErrorTrend.add(minErrorTrend); + this.rmsErrorSeasonality.add(minErrorSeasonality); + this.count = 1; + this.lastTrainedDate = OffsetDateTime.now(); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/OffsetDateTimeAdapter.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/OffsetDateTimeAdapter.java new file mode 100644 index 00000000000..d9d0cf13689 --- /dev/null +++ 
b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/OffsetDateTimeAdapter.java @@ -0,0 +1,29 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import java.lang.reflect.Type; +import java.time.OffsetDateTime; +import java.time.format.DateTimeFormatter; + +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; + +public class OffsetDateTimeAdapter implements JsonSerializer, JsonDeserializer { + + private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ISO_OFFSET_DATE_TIME; + + @Override + public JsonElement serialize(OffsetDateTime src, Type typeOfSrc, JsonSerializationContext context) { + return new JsonPrimitive(src.format(FORMATTER)); + } + + @Override + public OffsetDateTime deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) + throws JsonParseException { + return OffsetDateTime.parse(json.getAsString(), FORMATTER); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveModels.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveModels.java new file mode 100644 index 00000000000..738e85ca769 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveModels.java @@ -0,0 +1,158 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Reader; +import java.nio.file.Paths; 
+import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Base64; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.InflaterInputStream; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +import io.openems.common.OpenemsConstants; +import io.openems.edge.predictor.lstmmodel.validator.ValidationSeasonalityModel; +import io.openems.edge.predictor.lstmmodel.validator.ValidationTrendModel; + +public class ReadAndSaveModels { + + private static final String MODEL_DIRECTORY = Paths.get(OpenemsConstants.getOpenemsDataDir())// + .toFile()// + .getAbsolutePath(); + + private static final String MODEL_FOLDER = File.separator + "models" + File.separator; + + /** + * Saves the {@link HyperParameters} object to a file in JSON format. This + * method serializes the provided {@link HyperParameters} object into JSON + * format and saves it to a file with the specified name in the "models" + * directory. The serialization process utilizes a custom Gson instance + * configured to handle the serialization of OffsetDateTime objects. The file is + * saved in the directory specified by the OpenEMS data directory. + * + * @param hyperParameters The {@link HyperParameters} object to be saved. + */ + public static void save(HyperParameters hyperParameters) { + String modelName = hyperParameters.getModelName(); + String filePath = Paths.get(MODEL_DIRECTORY, MODEL_FOLDER, modelName)// + .toString(); + + Gson gson = new GsonBuilder()// + .registerTypeAdapter(OffsetDateTime.class, new OffsetDateTimeAdapter())// + .create(); + + try { + var compressedData = compress(hyperParameters); + var compressedDataString = Base64.getEncoder().encodeToString(compressedData); + var json = gson.toJson(compressedDataString); + + try (FileWriter writer = new FileWriter(filePath)) { + writer.write(json); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Reads and de-serializes a {@link HyperParameters} object from a JSON file. 
+ * This method reads a HyperParameters object from the specified JSON file, + * de-serializing it into a {@link HyperParameters} instance. The + * de-serialization process utilizes a custom Gson instance configured to handle + * the de-serialization of {@link OffsetDateTime} objects. The file is expected + * to be located in the "models" directory within the OpenEMS data directory. + * + * @param fileName The name of the JSON file to read the HyperParameters from. + * @return The {@link HyperParameters} object read from the file. + * @throws FileNotFoundException If the specified file is not found. + * @throws IOException If an I/O error occurs while reading the file. + */ + public static HyperParameters read(String fileName) { + + String filePath = Paths.get(MODEL_DIRECTORY, MODEL_FOLDER, fileName)// + .toString(); + + try (Reader reader = new FileReader(filePath)) { + Gson gson = new GsonBuilder()// + .registerTypeAdapter(OffsetDateTime.class, new OffsetDateTimeAdapter())// + .create(); + var json = gson.fromJson(reader, String.class); + var deserializedData = Base64.getDecoder().decode(json); + return decompress(deserializedData); + } catch (IOException e) { + var hyperParameters = new HyperParameters(); + hyperParameters.setModelName(fileName); + return hyperParameters; + } + } + + /** + * Compress the data. + * + * @param hyp the Hyper parameter object + * @return compressend byte array + */ + public static byte[] compress(HyperParameters hyp) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DeflaterOutputStream dos = new DeflaterOutputStream(baos); + ObjectOutputStream oos = new ObjectOutputStream(dos)) { + + oos.writeObject(hyp); + dos.finish(); + return baos.toByteArray(); + + } catch (IOException e) { + e.printStackTrace(); + return null; + } + } + + /** + * DeCompress the data. 
+ * + * @param value the value array to decompress + * @return Hyper parameter + */ + public static HyperParameters decompress(byte[] value) { + HyperParameters hyperParameters = null; + try (ByteArrayInputStream bais = new ByteArrayInputStream(value); + InflaterInputStream iis = new InflaterInputStream(bais); + ObjectInputStream ois = new ObjectInputStream(iis)) { + hyperParameters = (HyperParameters) ois.readObject(); + } catch (IOException | ClassNotFoundException e) { + e.printStackTrace(); + } + return hyperParameters; + } + + /** + * Adapt it. + * + * @param hyperParameters the Hyperparameter + * @param data the data + * @param dates the dates + */ + public static void adapt(HyperParameters hyperParameters, ArrayList data, ArrayList dates) { + if (hyperParameters.getCount() == 0) { + return; + } + + var valSeas = new ValidationSeasonalityModel(); + var valTrend = new ValidationTrendModel(); + + hyperParameters.resetModelErrorValue(); + + valSeas.validateSeasonality(data, dates, hyperParameters.getAllModelSeasonality(), hyperParameters); + valTrend.validateTrend(data, dates, hyperParameters.getAllModelsTrend(), hyperParameters); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadCsv.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadCsv.java new file mode 100644 index 00000000000..43097788fd6 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/common/ReadCsv.java @@ -0,0 +1,75 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Paths; +import java.time.OffsetDateTime; +import java.util.ArrayList; + +import io.openems.common.OpenemsConstants; + +public class ReadCsv { + + private static final String MODEL_DIRECTORY = Paths.get(OpenemsConstants.getOpenemsDataDir())// + 
.toFile().getAbsolutePath(); + private static final String MODEL_FOLDER = File.separator + "models" + File.separator; + + private ArrayList data = new ArrayList(); + private ArrayList dates = new ArrayList(); + + public ReadCsv(String path) { + this.getDataFromCsv(path); + } + + /** + * Reads data from a CSV file and populates class fields with the data. This + * method reads data from a CSV file specified by the provided file name. Each + * line in the CSV file is expected to contain timestamped data points, where + * the first column represents timestamps in the ISO-8601 format and subsequent + * columns represent numeric data. The data is parsed, and the timestamps and + * numeric values are stored in class fields for further processing. + * + * @param fileName The name of the CSV file to read data from. + * @throws IOException if there are issues reading the file. + */ + public void getDataFromCsv(String fileName) { + + try { + var path = Paths.get(MODEL_DIRECTORY, MODEL_FOLDER, fileName)// + .toString(); + + var reader = new BufferedReader(new FileReader(path)); + var line = reader.readLine(); + + while (line != null) { + var parts = line.split(","); + var date = OffsetDateTime.parse(parts[0]); + var temp2 = 0.0; + + for (int i = 1; i < parts.length; i++) { + if (parts[i].equals("") || parts[i].equals("nan")) { + temp2 = Double.NaN; + } else { + temp2 = (Double.parseDouble(parts[i])); + } + } + this.dates.add(date); + this.data.add(temp2); + line = reader.readLine(); + } + reader.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public ArrayList getData() { + return this.data; + } + + public ArrayList getDates() { + return this.dates; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolation.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolation.java new file mode 100644 index 00000000000..41482c01e2e --- 
/dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolation.java @@ -0,0 +1,119 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import java.util.ArrayList; +import java.util.stream.IntStream; + +import org.apache.commons.math3.analysis.interpolation.SplineInterpolator; +import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction; + +public class CubicalInterpolation extends SplineInterpolator { + + private ArrayList data; + + public CubicalInterpolation(ArrayList data) { + this.data = data; + } + + public CubicalInterpolation() { + } + + /** + * Compute Cubical interpolation. + * + * @return interpolated results + */ + public ArrayList compute() { + var interpolation = new ArrayList>(); + var function = this.getFunctionForAllInterval(this.data); + var differences = this.firstOrderDiff(function); + + for (int i = 0; i < differences.length; i++) { + if (differences[i] != 1) { + int requiredPoints = (int) (differences[i] - 1); + interpolation.add(this.calculate(function.getPolynomials()[i].getCoefficients(), requiredPoints)); + } + } + this.generateCombineInstruction(interpolation, differences); + return this.data; + } + + private PolynomialSplineFunction getFunctionForAllInterval(ArrayList data) { + var nonNaNCount = data.stream().filter(d -> !Double.isNaN(d)).count(); + + var dataNew = new double[(int) nonNaNCount]; + var xVal = new double[(int) nonNaNCount]; + + int[] index = { 0 }; + IntStream.range(0, data.size())// + .filter(i -> !Double.isNaN(data.get(i)))// + .forEach(i -> { + dataNew[index[0]] = data.get(i); + xVal[index[0]] = i + 1; + index[0]++; + }); + + return interpolate(xVal, dataNew); + } + + private double[] firstOrderDiff(PolynomialSplineFunction function) { + double[] knots = function.getKnots(); + return IntStream.range(0, knots.length - 1)// + .mapToDouble(i -> knots[i + 1] - knots[i])// + .toArray(); + } + + private ArrayList calculate(double[] 
weight, int requiredPoints) { + + ArrayList result = new ArrayList<>(); + for (int j = 0; j < requiredPoints; j++) { + double sum = 0; + for (int i = 0; i < weight.length; i++) { + sum += weight[i] * Math.pow(j + 1, i); + } + result.add(sum); + } + return result; + } + + private void generateCombineInstruction(ArrayList> interPolatedValue, double[] firstOrderDiff) { + + int count = 0; + int startingPoint = 0; + int addedData = 0; + + for (int i = 0; i < firstOrderDiff.length; i++) { + + if (firstOrderDiff[i] != 1) { + startingPoint = i + 1 + addedData; + this.combineToData(startingPoint, (int) firstOrderDiff[i] - 1, interPolatedValue.get(count)); + addedData = (int) (firstOrderDiff[i] - 1 + addedData); + count = count + 1; + } + } + } + + private void combineToData(int startingPoint, int totalpointsRequired, ArrayList dataToAdd) { + for (int i = 0; i < totalpointsRequired; i++) { + this.data.set(i + startingPoint, dataToAdd.get(i)); + } + } + + /** + * Can interpolate ?. + * + * @return boolean yes or no. 
+ */ + public boolean canInterpolate() { + var nonNaNCount = this.data.stream().filter(d -> d != null && !Double.isNaN(d)).count(); + return this.data.size() > 4 && nonNaNCount > 2; + } + + public void setData(ArrayList val) { + this.data = val; + } + + public ArrayList getInterPolatedData() { + return this.data; + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationManager.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationManager.java new file mode 100644 index 00000000000..5981aca4141 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationManager.java @@ -0,0 +1,149 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalDouble; +import java.util.stream.Collectors; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class InterpolationManager { + + private ArrayList interpolated = new ArrayList(); + private ArrayList newDates = new ArrayList(); + + public ArrayList getInterpolatedData() { + return this.interpolated; + } + + public ArrayList getNewDates() { + return this.newDates; + } + + public InterpolationManager(double[] data, HyperParameters hyperParameters) { + var dataList = UtilityConversion.to1DArrayList(data); + this.makeInterpolation(dataList); + } + + public InterpolationManager(ArrayList data, HyperParameters hyperParameters) { + this.makeInterpolation(data); + } + + private void makeInterpolation(ArrayList data) { + ArrayList dataDouble = replaceNullWithNaN(data); + double mean = calculateMean(dataDouble); + + // TODO why 96 + int groupSize = 96; + + List> groupedData = group(dataDouble, groupSize); + + CubicalInterpolation inter = new 
CubicalInterpolation(); + + List> interpolatedGroupedData = groupedData.stream()// + .map(currentGroup -> { + if (this.interpolationDecision(currentGroup)) { + this.handleFirstAndLastDataPoint(currentGroup, mean); + inter.setData(currentGroup); + return inter.canInterpolate() ? inter.compute() : LinearInterpolation.interpolate(currentGroup); + } else { + return currentGroup; + } + }).collect(Collectors.toList()); + + this.interpolated = unGroup(interpolatedGroupedData); + + } + + private void handleFirstAndLastDataPoint(ArrayList currentGroup, double mean) { + int firstIndex = 0; + int lastIndex = currentGroup.size() - 1; + + if (Double.isNaN(currentGroup.get(firstIndex))) { + currentGroup.set(firstIndex, mean); + } + if (Double.isNaN(currentGroup.get(lastIndex))) { + currentGroup.set(lastIndex, mean); + } + } + + /** + * Checks whether interpolation is needed based on the presence of NaN values in + * the provided list. + * + * @param data The list of Double values to be checked. + * @return true if interpolation is needed (contains at least one NaN value), + * false otherwise. + */ + private boolean interpolationDecision(ArrayList data) { + return data.stream().anyMatch(value -> Double.isNaN(value)); + } + + /** + * Replaces null values with Double.NaN in the given ArrayList. + * + * @param data The ArrayList to be processed. + * @return A new ArrayList with null values replaced by Double.NaN. + */ + public static ArrayList replaceNullWithNaN(ArrayList data) { + return data.stream()// + .map(value -> (value == null) ? Double.NaN : value)// + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Calculates the mean (average) of a list of numeric values, excluding NaN + * values. + * + * @param data The list of numeric values from which to calculate the mean. + * @return The mean of the non-NaN numeric values in the input list. 
+ */ + public static double calculateMean(ArrayList data) { + if (data.isEmpty()) { + return Double.NaN; + } + + OptionalDouble meanOptional = data.stream()// + .filter(value -> !Double.isNaN(value))// + .mapToDouble(Double::doubleValue)// + .average(); + + return meanOptional.orElse(Double.NaN); + } + + /** + * Ungroups a list of sublists into a single list. + * + * @param data The list of sublists to be ungrouped. + * @return A single list containing all elements from the sublists. + */ + public static ArrayList unGroup(List> data) { + return data.stream()// + .flatMap(List::stream)// + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Groups a list of data into sublists of a specified size. This method takes a + * list of data and groups it into sublists of a specified size. Each sublist + * will contain up to {@code groupSize} elements, except for the last sublist, + * which may contain fewer elements if the total number of elements is not a + * multiple of {@code groupSize}. + * + * @param data The list of data to be grouped. + * @param groupSize The maximum number of elements in each sublist. + * @return A list of sublists, each containing up to {@code groupSize} elements. 
+ */ + public static ArrayList> group(ArrayList data, int groupSize) { + ArrayList> groupedData = new ArrayList<>(); + + for (int i = 0; i < data.size(); i += groupSize) { + ArrayList sublist = new ArrayList<>(data.subList(i, Math.min(i + groupSize, data.size()))); + groupedData.add(sublist); + } + return groupedData; + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolation.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolation.java new file mode 100644 index 00000000000..bd335074a52 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolation.java @@ -0,0 +1,102 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import java.util.ArrayList; + +public class LinearInterpolation { + + /** + * Interpolates NaN values in the provided data set. + * + * @param data The input data set with NaN values. + * @return The data set with NaN values replaced by interpolated values. + */ + public static ArrayList interpolate(ArrayList data) { + + ArrayList> coordinate = determineInterpolatingPoints(data); + for (int i = 0; i < coordinate.size(); i++) { + var xVal1 = coordinate.get(i).get(0); + var xVal2 = coordinate.get(i).get(1); + + var ineterPolationResult = computeInterpolation(xVal1, xVal2, data.get(xVal1), data.get((int) xVal2)); + data = combine(data, ineterPolationResult, xVal1, xVal2); + + } + return data; + } + + /** + * Determines the indices where NaN values are sandwiched between non-NaN values + * in a given data set. + * + * @param data The input data set. + * @return A list of coordinate pairs representing the indices where NaN values + * are sandwiched. 
+ */ + public static ArrayList> determineInterpolatingPoints(ArrayList data) { + + ArrayList> coordinates = new ArrayList<>(); + + var inNaNSequence = false; + var xVal1 = -1; + + for (int i = 0; i < data.size(); i++) { + var currentValue = data.get(i); + + if (Double.isNaN(currentValue)) { + if (!inNaNSequence) { + xVal1 = i - 1; + inNaNSequence = true; + } + } else { + if (inNaNSequence) { + var xVal2 = i; + ArrayList temp = new ArrayList<>(); + temp.add(xVal1); + temp.add(xVal2); + coordinates.add(temp); + inNaNSequence = false; + } + } + } + return coordinates; + } + + /** + * Computes linear interpolation between two values. + * + * @param xValue1 The x-value corresponding to the first data point. + * @param xValue2 The x-value corresponding to the second data point. + * @param yValue1 The y-value corresponding to the first data point. + * @param yValue2 The y-value corresponding to the second data point. + * @return A list of interpolated y-values between xValue1 and xValue2. + */ + public static ArrayList computeInterpolation(int xValue1, int xValue2, double yValue1, double yValue2) { + var interPolatedResults = new ArrayList(); + var xVal1 = (double) xValue1; + var xVal2 = (double) xValue2; + + for (int i = 1; i < (xValue2 - xValue1); i++) { + interPolatedResults + .add((yValue1 * ((xVal2 - (i + xVal1)) / (xVal2 - xVal1)) + yValue2 * ((i) / (xVal2 - xVal1)))); + } + return interPolatedResults; + } + + /** + * Combines the original data set with the interpolation result. + * + * @param orginalData The original data set. + * @param interpolatedResult The result of linear interpolation. + * @param xValue1 The first index used for interpolation. + * @param xValue2 The second index used for interpolation. + * @return The combined data set with interpolated values. 
+ */ + public static ArrayList combine(ArrayList orginalData, ArrayList interpolatedResult, + int xValue1, int xValue2) { + + for (int i = 0; i < (interpolatedResult.size()); i++) { + orginalData.set((i + xValue1 + 1), interpolatedResult.get(i)); + } + return orginalData; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionRequest.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionRequest.java new file mode 100644 index 00000000000..f501169bf0a --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionRequest.java @@ -0,0 +1,50 @@ +package io.openems.edge.predictor.lstmmodel.jsonrpc; + +import com.google.gson.JsonObject; + +import io.openems.common.exceptions.OpenemsException; +import io.openems.common.jsonrpc.base.JsonrpcRequest; + +/* + * url = http://localhost:8084/jsonrpc + * { + * "method": "componentJsonApi", + * "params": { + * "componentId": "predictor0", + * "payload": { + * "method": "getLstmPrediction", + * "params": { + * "id": "edge0" + * } + * } + * } +*} + */ +public class GetPredictionRequest extends JsonrpcRequest { + + public static final String METHOD = "getLstmPrediction"; + + /** + * get predictions. 
+ * + * @param r the request + * @return new prediction + * @throws on error + */ + public static GetPredictionRequest from(JsonrpcRequest r) throws OpenemsException { + return new GetPredictionRequest(r); + } + + public GetPredictionRequest() { + super(METHOD); + } + + private GetPredictionRequest(JsonrpcRequest request) { + super(request, METHOD); + } + + @Override + public JsonObject getParams() { + return new JsonObject(); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionResponse.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionResponse.java new file mode 100644 index 00000000000..2da8233535c --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/GetPredictionResponse.java @@ -0,0 +1,55 @@ +package io.openems.edge.predictor.lstmmodel.jsonrpc; + +import java.time.ZonedDateTime; +import java.util.SortedMap; +import java.util.UUID; + +import com.google.gson.JsonArray; +import com.google.gson.JsonNull; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; + +import io.openems.common.jsonrpc.base.JsonrpcResponseSuccess; +import io.openems.common.utils.JsonUtils; +import io.openems.common.utils.JsonUtils.JsonObjectBuilder; + +public class GetPredictionResponse extends JsonrpcResponseSuccess { + + private JsonArray prediction; + private SortedMap predictionResult; + + public GetPredictionResponse(JsonArray prediction) { + this(UUID.randomUUID(), prediction); + } + + public GetPredictionResponse(UUID id, JsonArray prediction) { + super(id); + this.prediction = prediction != null ? 
prediction : new JsonArray(); + this.predictionResult = null; + } + + public GetPredictionResponse(UUID id, SortedMap predictionResult) { + super(id); + this.predictionResult = predictionResult; + this.prediction = new JsonArray(); + if (predictionResult != null) { + predictionResult.values().forEach(value -> { + this.prediction.add(value != null ? new JsonPrimitive(value) : JsonNull.INSTANCE); + }); + } + } + + @Override + public JsonObject getResult() { + JsonObjectBuilder result = JsonUtils.buildJsonObject() // + .add("prediction", this.prediction) // + .add("size", new JsonPrimitive(this.prediction.size())); + + if (this.predictionResult != null) { + result.add("TimeValueMap", new JsonPrimitive(this.predictionResult.toString())); + } else { + result.add("timeValueMap", JsonNull.INSTANCE); + } + return result.build(); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/PredictionRequestHandler.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/PredictionRequestHandler.java new file mode 100644 index 00000000000..b94f838a7d7 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/jsonrpc/PredictionRequestHandler.java @@ -0,0 +1,24 @@ +package io.openems.edge.predictor.lstmmodel.jsonrpc; + +import java.util.UUID; + +import io.openems.common.types.ChannelAddress; +import io.openems.edge.predictor.api.manager.PredictorManager; + +public class PredictionRequestHandler { + + /** + * get predictionsReasponse. 
+ * + * @param requestId the id + * @param predictionManager the manager + * @param channelAddress the {@link ChannelAddress} + * @return the new prediction + */ + public static GetPredictionResponse handlerGetPredictionRequest(UUID requestId, PredictorManager predictionManager, + ChannelAddress channelAddress) { + + var sortedMap = predictionManager.getPrediction(channelAddress).valuePerQuarter; + return new GetPredictionResponse(requestId, sortedMap); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrix.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrix.java new file mode 100644 index 00000000000..b721a8a14fe --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrix.java @@ -0,0 +1,273 @@ +package io.openems.edge.predictor.lstmmodel.performance; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.distribution.TDistribution; +import org.apache.commons.math3.stat.StatUtils; + +import io.openems.edge.predictor.lstmmodel.common.DataStatistics; + +public class PerformanceMatrix { + private ArrayList target = new ArrayList(); + private ArrayList predicted = new ArrayList(); + private double allowedError = 0.0; + + public PerformanceMatrix(ArrayList tar, ArrayList predict, double allowedErr) { + this.target = tar; + this.predicted = predict; + this.allowedError = allowedErr; + } + + /** + * Calculates the mean absolute error between the target and predicted values. + * Mean absolute error (MAE) is a metric that measures the average absolute + * difference between corresponding elements of two lists. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @return The mean absolute error between the target and predicted values. 
+ * @throws IllegalArgumentException If the input lists have different sizes. + */ + public static double meanAbsoluteError(ArrayList target, ArrayList predicted) { + + if (predicted.size() != target.size()) { + throw new IllegalArgumentException("Input lists must have the same size"); + } + + double sumError = 0.0; + for (int i = 0; i < predicted.size(); i++) { + double error = Math.abs(predicted.get(i) - target.get(i)); + sumError += error; + } + + return sumError / predicted.size(); + } + + /** + * Calculates the Root Mean Square (RMS) error between the target and predicted + * values. RMS error is a measure of the average magnitude of the differences + * between corresponding elements of two lists. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @return The root mean square error between the target and predicted values. + * @throws IllegalArgumentException If the input lists have different sizes. + */ + public static double rmsError(ArrayList target, ArrayList predicted) { + if (predicted.size() != target.size()) { + throw new IllegalArgumentException("Input lists must have the same size"); + } + + double sumSquaredError = 0.0; + for (int i = 0; i < predicted.size(); i++) { + double error = predicted.get(i) - target.get(i); + sumSquaredError += error * error; + } + + double meanSquaredError = sumSquaredError / predicted.size(); + return Math.sqrt(meanSquaredError); + } + + /** + * Calculate the RmsError of two arrays. 
+ * + * @param target double array of target + * @param predicted double array of predicted + * @return rms Error + */ + public static double rmsError(double[] target, double[] predicted) { + if (predicted.length != target.length) { + throw new IllegalArgumentException("Input lists must have the same size"); + } + + double sumSquaredError = 0.0; + for (int i = 0; i < predicted.length; i++) { + double error = predicted[i] - target[i]; + sumSquaredError += error * error; + } + + double meanSquaredError = sumSquaredError / predicted.length; + return Math.sqrt(meanSquaredError); + } + + /** + * Calculates the Mean Squared Error (MSE) between the target and predicted + * values. MSE is a measure of the average squared differences between + * corresponding elements of two lists. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @return The mean squared error between the target and predicted values. + * @throws IllegalArgumentException If the input lists have different sizes. + */ + public static double meanSquaredError(ArrayList target, ArrayList predicted) { + if (predicted.size() != target.size()) { + throw new IllegalArgumentException("Input lists must have the same size"); + } + + double sumSquaredError = 0.0; + for (int i = 0; i < predicted.size(); i++) { + double error = predicted.get(i) - target.get(i); + sumSquaredError += error * error; + } + + return sumSquaredError / predicted.size(); + } + + /** + * Calculates the accuracy between the target and predicted values within a + * specified allowed percentage difference. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @param allowedPercentage The maximum allowed percentage difference for + * accuracy. + * @return The accuracy between the target and predicted values. 
+ */ + public static double accuracy(ArrayList target, ArrayList predicted, double allowedPercentage) { + double count = 0; + + for (int i = 0; i < predicted.size(); i++) { + double diff = Math.abs(predicted.get(i) - target.get(i)) // + / Math.max(predicted.get(i), target.get(i)); + if (diff <= allowedPercentage) { + count++; + } + } + return (double) count / predicted.size(); + } + + /** + * Calculate the Accuracy of the predicted compared to target. + * + * @param target double array of target + * @param predicted double array of predicted + * @param allowedPercentage allowed percentage error + * @return accuracy + */ + public static double accuracy(double[] target, double[] predicted, double allowedPercentage) { + double count = 0; + + for (int i = 0; i < predicted.length; i++) { + double diff = Math.abs(predicted[i] - target[i]) // + / Math.max(predicted[i], target[i]); + if (diff <= allowedPercentage) { + count++; + } + } + return (double) count / predicted.length; + } + + /** + * Calculates the Mean Absolute Percentage Error (MAPE) between the target and + * predicted values. MAPE is a measure of the average percentage difference + * between corresponding elements of two lists. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @return The mean absolute percentage error between the target and predicted + * values. + * @throws IllegalArgumentException If the input lists have different sizes. 
+ */ + public static double meanAbslutePercentage(ArrayList target, ArrayList predicted) { + if (predicted.size() != target.size()) { + throw new IllegalArgumentException("Input lists must have the same size"); + } + + double sumPercentageError = 0.0; + for (int i = 0; i < predicted.size(); i++) { + double absoluteError = Math.abs(predicted.get(i) - target.get(i)); + double percentageError = absoluteError / target.get(i) * 100.0; + sumPercentageError += percentageError; + } + + return sumPercentageError / predicted.size(); + } + + /** + * Calculates the two-tailed p-value using the t-statistic for the differences + * between predicted and actual values. + * + * @param target The list of target values. + * @param predicted The list of predicted values. + * @return The two-tailed p-value for the differences between predicted and + * actual values. + * @throws IllegalArgumentException If the input lists have different sizes. + */ + public static double pvalue(ArrayList target, ArrayList predicted) { + if (predicted.size() != target.size()) { + throw new IllegalArgumentException("Input lists must have the same size."); + } + + List differences = new ArrayList<>(); + for (int i = 0; i < predicted.size(); i++) { + differences.add(predicted.get(i) - target.get(i)); + } + + double[] differencesArray = differences.stream()// + .mapToDouble(Double::doubleValue).toArray(); + double mean = StatUtils.mean(differencesArray); + double stdDev = Math.sqrt(StatUtils.variance(differencesArray)); + + // Calculate the t-statistic + double tStat = mean / (stdDev / Math.sqrt(predicted.size())); + + // Degrees of freedom + int degreesOfFreedom = predicted.size() - 1; + + // Create a T-distribution with the appropriate degrees of freedom + TDistribution tDistribution = new TDistribution(degreesOfFreedom); + + // Calculate the two-tailed p-value + double pValue = 2 * (1.0 - tDistribution.cumulativeProbability(Math.abs(tStat))); + + return pValue; + } + + /** + * Generates and prints a 
performance report containing various statistical + * metrics and error measures between the actual and predicted data. The report + * includes average, standard deviation, mean absolute error, RMS error, mean + * squared error, mean absolute percentage error, and accuracy with a specified + * error margin. Note: This method assumes that the necessary statistical + * methods (e.g., meanAbsoluteError, rmsError, meanSquaredError, + * meanAbslutePercentage, accuracy) are implemented in the same class. The + * p-value calculation is not included in the report by default. + */ + public void statusReport() { + System.out.println("\n.................. Performance Report ............................."); + + // Calculate and display statistics for actual data + double averageActual = DataStatistics.getMean(this.target); + double stdDevActual = DataStatistics.getStandardDeviation(this.target); + System.out.println("Average of actual data = " + averageActual); + System.out.println("Standard deviation of actual data = " + stdDevActual); + + // Calculate and display statistics for predicted data + double averagePredicted = DataStatistics.getMean(this.predicted); + double stdDevPredicted = DataStatistics.getStandardDeviation(this.predicted); + System.out.println("Average of prediction data = " + averagePredicted); + System.out.println("Standard deviation of predicted data = " + stdDevPredicted); + + // Display various error metrics + System.out.println("Mean absolute error = " + meanAbsoluteError(this.target, this.predicted) + + " (average absolute difference between predicted and actual values)"); + + System.out.println("RMS error = " + rmsError(this.target, this.predicted) + " (square root of the MSE)"); + + System.out.println("Mean squared error = " + meanSquaredError(this.target, this.predicted) + + " (average of the squared differences between predicted and actual values)"); + + System.out.println("Mean absolute percentage error = " + meanAbslutePercentage(this.target, 
this.predicted) + + " (measures the average percentage difference between predicted and actual values)"); + + // Display accuracy with the specified error margin + double accuracyPercentage = accuracy(this.target, this.predicted, this.allowedError) * 100; + + System.out.println("Accuracy for " + this.allowedError * 100 + "% error margin = " + accuracyPercentage + "%"); + + System.out.println(""); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/DataModification.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/DataModification.java new file mode 100644 index 00000000000..93df1bc94b4 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/DataModification.java @@ -0,0 +1,730 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import static io.openems.edge.predictor.lstmmodel.common.DataStatistics.getMean; +import static io.openems.edge.predictor.lstmmodel.common.DataStatistics.getStandardDeviation; + +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class DataModification { + + private static final double MIN_SCALED = 0.2; + private static final double MAX_SCALED = 0.8; + + /** + * Scales a list of numeric data values to a specified range. This method scales + * a list of numeric data values to a specified range defined by the minimum + * (min) and maximum (max) values. The scaled data will be within the range + * defined by the minimumScaled (minScaled) and maximumScaled (maxScaled) + * values. + * + * @param data The list of numeric data values to be scaled. + * @param min The original minimum value in the data. + * @param max The original maximum value in the data. 
+ * @return A new list containing the scaled data within the specified range. + */ + public static ArrayList scale(ArrayList data, double min, double max) { + return data.stream()// + .map(value -> MIN_SCALED + ((value - min) / (max - min)) * (MAX_SCALED - MIN_SCALED)) + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * * Scales a list of numeric data values to a specified range. This method + * scales a list of numeric data values to a specified range defined by the + * minimum (min) and maximum (max) values. The scaled data will be within the + * range defined by the minimumScaled (minScaled) and maximumScaled (maxScaled) + * values. + * + * @param data The array of numeric data values to be scaled. + * @param min The original minimum value in the data. + * @param max The original maximum value in the data. + * @return A new list containing the scaled data within the specified range. + */ + public static double[] scale(double[] data, double min, double max) { + return Arrays.stream(data)// + .map(value -> MIN_SCALED + ((value - min) / (max - min)) * (MAX_SCALED - MIN_SCALED))// + .toArray(); + } + + /** + * Re-scales a single data point from the scaled range to the original range. + * This method re-scales a single data point from the scaled range (defined by + * 'minScaled' and 'maxScaled') back to the original range, which is specified + * by 'minOriginal' and 'maxOriginal'. It performs the reverse scaling operation + * for a single data value. + * + * @param scaledData The data point to be rescaled from the scaled range to the + * original range. + * @param minOriginal The minimum value of the training dataset (original data + * range). + * @param maxOriginal The maximum value of the training dataset (original data + * range). + * @return The rescaled data point in the original range. 
+ */ + public static double scaleBack(double scaledData, double minOriginal, double maxOriginal) { + return calculateScale(scaledData, MIN_SCALED, MAX_SCALED, minOriginal, maxOriginal); + } + + /** + * Scales back a list of double values from a scaled range to the original + * range. This method takes a list of scaled values and scales them back to + * their original range based on the specified minimum and maximum values of the + * original range. + * + * @param data The list of double values to be scaled back. + * @param minOriginal The minimum value of the original range. + * @param maxOriginal The maximum value of the original range. + * @return A new ArrayList containing the scaled back values. + */ + public static ArrayList scaleBack(ArrayList data, double minOriginal, double maxOriginal) { + return data.stream()// + .map(value -> calculateScale(value, MIN_SCALED, MAX_SCALED, minOriginal, maxOriginal))// + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * * Scales back a list of double values from a scaled range to the original + * range. This method takes a list of scaled values and scales them back to + * their original range based on the specified minimum and maximum values of the + * original range. + * + * @param data The list of double values to be scaled back. + * @param minOriginal The minimum value of the original range. + * @param maxOriginal The maximum value of the original range. + * @return A new ArrayList containing the scaled back values. + */ + public static double[] scaleBack(double[] data, double minOriginal, double maxOriginal) { + return Arrays.stream(data)// + .map(value -> calculateScale(value, MIN_SCALED, MAX_SCALED, minOriginal, maxOriginal))// + .toArray(); + } + + /** + * Scales a value from a scaled range back to the original range. + * + * @param valScaled The value in the scaled range to be converted back to the + * original range. + * @param minScaled The minimum value of the scaled range. 
+ * @param maxScaled The maximum value of the scaled range. + * @param minOriginal The minimum value of the original range. + * @param maxOriginal The maximum value of the original range. + * @return The value converted back to the original range. + */ + private static double calculateScale(double valScaled, double minScaled, double maxScaled, double minOriginal, + double maxOriginal) { + return ((valScaled - minScaled) * (maxOriginal - minOriginal) / (maxScaled - minScaled)// + ) + minOriginal; + } + + /** + * Normalize a 2D array of data using standardization (z-score normalization). + * This method normalizes a 2D array of data by applying standardization + * (z-score normalization) to each row independently. The result is a new 2D + * array of normalized data. + * + * @param data The 2D array of data to be normalized. + * @param hyperParameters instance of class HyperParameters + * @return A new 2D array containing the standardized (normalized) data. + */ + public static double[][] normalizeData(double[][] data, HyperParameters hyperParameters) { + double[][] standData; + standData = new double[data.length][data[0].length];// Here error + for (int i = 0; i < data.length; i++) { + standData[i] = standardize(data[i], hyperParameters); + } + return standData; + } + + /** + * Normalizes the data based on the given target values, using standardization. + * This method calculates the standardization of each data point in the input + * data array with respect to the corresponding target value. It utilizes the + * mean and standard deviation of the input data array to perform the + * standardization. + * + * @param data The input data array containing the features to be + * normalized. + * @param target The target values to which the data will be + * standardized. + * @param hyperParameters The {@link HyperParameters} required for + * normalization. + * @return A double array containing the normalized data. 
+ */ + + public static double[] normalizeData(double[][] data, double[] target, HyperParameters hyperParameters) { + double[] standData; + standData = new double[target.length]; + for (int i = 0; i < data.length; i++) { + standData[i] = standardize(target[i], getMean(data[i]), getStandardDeviation(data[i]), hyperParameters); + } + return standData; + } + + /** + * Standardizes a 1D array of data using Z-score normalization. This method + * standardizes a 1D array of data by applying Z-score normalization. It + * calculates the mean and standard deviation of the input data and then + * standardizes each data point. + * + * @param inputData The 1D array of data to be standardized. + * @param hyperParameters instance of {@link HyperParameters} + * @return A new 1D array containing the standardized (normalized) data. + */ + public static double[] standardize(double[] inputData, HyperParameters hyperParameters) { + double meanCurrent = getMean(inputData); + + double stdDeviationCurrent = getStandardDeviation(inputData); + double meanTarget = hyperParameters.getMean(); + double standerDeviationTarget = hyperParameters.getStanderDeviation(); + + double[] standardizedData = new double[inputData.length]; + for (int i = 0; i < inputData.length; i++) { + standardizedData[i] = meanTarget + + ((inputData[i] - meanCurrent) * (standerDeviationTarget / stdDeviationCurrent)); + } + return standardizedData; + } + + /** + * Standardizes a given input data point using mean and standard deviation. This + * method standardizes the input data point based on the provided mean and + * standard deviation of the current data and the target mean and standard + * deviation specified in the {@link HyperParameters}. + * + * @param inputData The input data point to be standardized. + * @param mean The mean of the current data. + * @param standerdDev The standard deviation of the current data. 
+ * @param hyperParameters The {@link HyperParameters} containing the target mean + * and standard deviation. + * @return The standardized value of the input data point. + */ + public static double standardize(double inputData, double mean, double standerdDev, + HyperParameters hyperParameters) { + + double meanCurrent = mean; + + double stdDeviationCurrent = standerdDev; + double meanTarget = hyperParameters.getMean(); + double standerDeviationTarget = hyperParameters.getStanderDeviation(); + return meanTarget + ((inputData - meanCurrent) * (standerDeviationTarget / stdDeviationCurrent)); + + } + + /** + * Reverse standardizes a data point that was previously standardized using + * Z-score normalization. This method reverses the standardization process for a + * single data point that was previously standardized using Z-score + * normalization. It requires the mean and standard deviation of the original + * data along with the Z-score value (zvalue) to perform the reverse + * standardization. + * + * @param mean The mean of the original data. + * @param standardDeviation The standard deviation of the original data. + * @param zvalue The Z-score value of the standardized data point. + * @param hyperParameters instance of {@link HyperParameters} + * @return The reverse standardized value in the original data's scale. + */ + public static double reverseStandrize(double zvalue, double mean, double standardDeviation, + HyperParameters hyperParameters) { + + double reverseStand = 0; + double meanTarget = hyperParameters.getMean(); + double standardDeviationTarget = hyperParameters.getStanderDeviation(); + + reverseStand = ((zvalue - meanTarget) * (standardDeviation / standardDeviationTarget) + mean); + return reverseStand; + } + + /** + * Reverse standardizes a list of data points based on given mean, standard + * deviation, and {@link HyperParameters}. 
This method reverse standardizes each + * data point in the input list based on the provided mean, standard deviation, + * and {@link HyperParameters}. It returns a new Array containing the reverse + * standardized values. + * + * @param data The list of data points to be reverse standardized. + * @param mean The list of means corresponding to the data points. + * @param standDeviation The list of standard deviations corresponding to the + * data points. + * @param hyperParameters The {@link HyperParameters} containing the target mean + * and standard deviation. + * @return A new list containing the reverse standardized values. + */ + public static double[] reverseStandrize(ArrayList data, ArrayList mean, + ArrayList standDeviation, HyperParameters hyperParameters) { + double[] revNorm = new double[data.size()]; + for (int i = 0; i < data.size(); i++) { + revNorm[i] = (reverseStandrize(data.get(i), mean.get(i), standDeviation.get(i), hyperParameters)); + } + return revNorm; + } + + /** + * Reverse standardizes a list of data points based on given mean, standard + * deviation, and {@link HyperParameters}. This method reverse standardizes each + * data point in the input list based on the provided mean, standard deviation, + * and {@link HyperParameters}. It returns a new list containing the reverse + * standardized values. + * + * @param data The Array of data points to be reverse standardized. + * @param mean The Array of means corresponding to the data points. + * @param standDeviation The Array of standard deviations corresponding to the + * data points. + * @param hyperParameters The {@link HyperParameters} containing the target mean + * and standard deviation. + * @return A new Array containing the reverse standardized values. 
+ */ + public static double[] reverseStandrize(double[] data, double[] mean, double[] standDeviation, + HyperParameters hyperParameters) { + double[] revNorm = new double[data.length]; + for (int i = 0; i < data.length; i++) { + revNorm[i] = (reverseStandrize(data[i], mean[i], standDeviation[i], hyperParameters)); + } + return revNorm; + } + + /** + * Reverse standardizes a list of data points based on given mean, standard + * deviation, and {@link HyperParameters}. This method reverse standardizes each + * data point in the input list based on the provided mean, standard deviation, + * and {@link HyperParameters}. It returns a new Array containing the reverse + * standardized values. + * + * @param data The Array of data points to be reverse standardized. + * @param mean The mean corresponding to the data points. + * @param standDeviation The standard deviation corresponding to the data + * points. + * @param hyperParameters The {@link HyperParameters} containing the target mean + * and standard deviation. + * @return A new Array containing the reverse standardized values. + */ + public static double[] reverseStandrize(ArrayList data, double mean, double standDeviation, + HyperParameters hyperParameters) { + double[] revNorm = new double[data.size()]; + for (int i = 0; i < data.size(); i++) { + revNorm[i] = (reverseStandrize(data.get(i), mean, standDeviation, hyperParameters)); + } + return revNorm; + } + + /** + * Reverse standardizes a list of data points based on given mean, standard + * deviation, and {@link HyperParameters}. This method reverse standardizes each + * data point in the input list based on the provided mean, standard deviation, + * and {@link HyperParameters}. It returns a new list containing the reverse + * standardized values. + * + * @param data The list of data points to be reverse standardized. + * @param mean The mean corresponding to the data points. + * @param standDeviation The standard deviation corresponding to the data + * points. 
+ * @param hyperParameters The {@link HyperParameters} containing the target mean + * and standard deviation. + * @return A new list containing the reverse standardized values. + */ + public static double[] reverseStandrize(double[] data, double mean, double standDeviation, + HyperParameters hyperParameters) { + double[] revNorm = new double[data.length]; + for (int i = 0; i < data.length; i++) { + revNorm[i] = (reverseStandrize(data[i], mean, standDeviation, hyperParameters)); + } + return revNorm; + } + + /** + * Modifies the given time-series data for long-term prediction by grouping it + * based on hours and minutes. + * + * @param data The {@link ArrayList} of Double values representing the + * time-series data. + * @param date The {@link ArrayList} of OffsetDateTime objects corresponding to + * the timestamps of the data. + * @return An {@link ArrayList} of {@link ArrayList} of {@link ArrayList}, + * representing the modified data grouped by hours and minutes. + */ + + public static ArrayList>> groupDataByHourAndMinute(ArrayList data, + ArrayList date) { + + ArrayList>> dataGroupedByMinute = new ArrayList<>(); + ArrayList>> dateGroupedByMinute = new ArrayList<>(); + + GroupBy groupByHour = new GroupBy(data, date); + groupByHour.hour(); + + for (int i = 0; i < groupByHour.getGroupedDataByHour().size(); i++) { + GroupBy groupByMinute = new GroupBy(groupByHour.getGroupedDataByHour().get(i), + groupByHour.getGroupedDateByHour().get(i)); + + groupByMinute.minute(); + dataGroupedByMinute.add(groupByMinute.getGroupedDataByMinute()); + dateGroupedByMinute.add(groupByMinute.getGroupedDateByMinute()); + } + return dataGroupedByMinute; + } + + /** + * Modify the data for trend term prediction. + * + * @param data The ArrayList of Double values data. + * @param date The ArrayList of Double values date. 
+ * @param hyperParameters The {@link HyperParameters} + * @return The ArrayList of modified values + */ + public static ArrayList> modifyFortrendPrediction(ArrayList data, + ArrayList date, HyperParameters hyperParameters) { + + ArrayList>> firstModification = groupDataByHourAndMinute(data, date); + + // Flatten the structure of the first modification + ArrayList> secondModification = flatten3dto2d(firstModification); + + // Apply windowing to create the third modification + ArrayList> thirdModification = applyWindowing(secondModification, hyperParameters); + + return thirdModification; + } + + private static ArrayList> flatten3dto2d(// + ArrayList>> data) { + return data.stream()// + .flatMap(twoDList -> twoDList.stream())// + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Decreases the dimensionality of a 4D ArrayList to a 3D ArrayList. This method + * flattens the input 4D ArrayList to a 3D ArrayList by merging the innermost + * ArrayLists into one. It returns the resulting 3D ArrayList. + * + * @param model The 4D ArrayList to decrease in dimensionality. + * @return The resulting 3D ArrayList after decreasing the dimensionality. + */ + public static ArrayList>> flattern4dto3d( + ArrayList>>> model) { + + return model.stream()// + .flatMap(threeDList -> threeDList.stream())// + .collect(Collectors.toCollection(ArrayList::new)); + } + + private static ArrayList> applyWindowing(ArrayList> data, + HyperParameters hyperParameters) { + ArrayList> windowedData = new ArrayList<>(); + int windowSize = hyperParameters.getWindowSizeTrend(); + + for (int i = 0; i < data.size(); i++) { + ArrayList> toCombine = new ArrayList<>(); + + for (int j = 0; j <= windowSize; j++) { + int index = (j + i) % data.size(); + toCombine.add(data.get(index)); + } + windowedData.add(combinedArray(toCombine)); + } + return windowedData; + } + + /** + * Flatten the array by combining. + * + * @param values The ArrayList of Double values. 
+ * @return reGroupedsecond The flattened ArrayList + */ + public static ArrayList combinedArray(ArrayList> values) { + int minSize = values.stream()// + .mapToInt(ArrayList::size)// + .min()// + .orElse(0); + + ArrayList reGroupedsecond = new ArrayList<>(); + + for (int i = 0; i < minSize; i++) { + for (ArrayList innerList : values) { + reGroupedsecond.add(innerList.get(i)); + } + } + return reGroupedsecond; + } + + /** + * Splits a list of Double values into multiple batches and returns the batches. + * The method divides the original list into a specified number of groups, + * ensuring that each group has an approximately equal number of elements. It + * handles any remainder by distributing the extra elements among the first few + * groups. + * + * @param originalList The original list of Double values to be split into + * batches. + * @param numberOfGroups The desired number of groups to split the list into. + * @return An ArrayList of ArrayLists, where each inner ArrayList represents a + * batch of Double values. + */ + public static ArrayList> getDataInBatch(ArrayList originalList, int numberOfGroups) { + ArrayList> splitGroups = new ArrayList<>(); + + int originalSize = originalList.size(); + int groupSize = originalSize / numberOfGroups; + int remainder = originalSize % numberOfGroups; + + int currentIndex = 0; + for (int i = 0; i < numberOfGroups; i++) { + int groupCount = groupSize + (i < remainder ? 1 : 0); + ArrayList group = new ArrayList<>(originalList.subList(currentIndex, currentIndex + groupCount)); + splitGroups.add(group); + currentIndex += groupCount; + } + return splitGroups; + } + + /** + * Splits a list of OffsetDateTime into multiple batches and returns the + * batches. The method divides the original list into a specified number of + * groups, ensuring that each group has an approximately equal number of + * elements. It handles any remainder by distributing the extra elements among + * the first few groups.
+ * + * @param originalList The original list of OffsetDateTime to be split into + * batches. + * @param numberOfGroups The desired number of groups to split the list into. + * @return An ArrayList of ArrayLists, where each inner ArrayList represents a + * batch of OffsetDateTime objects. + */ + public static ArrayList> getDateInBatch(ArrayList originalList, + int numberOfGroups) { + ArrayList> splitGroups = new ArrayList<>(); + + int originalSize = originalList.size(); + int groupSize = originalSize / numberOfGroups; + int remainder = originalSize % numberOfGroups; + + int currentIndex = 0; + for (int i = 0; i < numberOfGroups; i++) { + int groupCount = groupSize + (i < remainder ? 1 : 0); + ArrayList group = new ArrayList<>( + originalList.subList(currentIndex, currentIndex + groupCount)); + splitGroups.add(group); + currentIndex += groupCount; + } + + return splitGroups; + } + + /** + * Removes negative values from the given ArrayList of Doubles by replacing them + * with 0. + * + * @param data The ArrayList of Doubles containing numeric values. + * @return ArrayList<Double> A new ArrayList<Double> with negative + * values replaced by zero. + */ + public static ArrayList removeNegatives(ArrayList data) { + return data.stream()// + // Replace negative values with 0 + .map(value -> value == null || Double.isNaN(value) ? Double.NaN : Math.max(value, 0)) + .collect(Collectors.toCollection(ArrayList::new)); + + } + + /** + * Replaces all negative values in the input array with 0. NaN values in the + * array remain unchanged. + * + * @param data the input array of doubles + * @return a new array with negative values replaced by 0 + */ + public static double[] removeNegatives(double[] data) { + return Arrays.stream(data)// + .map(value -> Double.isNaN(value) ? Double.NaN : Math.max(value, 0))// + .toArray(); + } + + /** + * Scales each element in the input ArrayList by a specified scaling factor. + * + * @param data The ArrayList of Double values to be scaled. 
+ * @param scalingFactor The factor by which each element in the data ArrayList + * will be multiplied. + * @return A new ArrayList containing the scaled values. + */ + public static ArrayList constantScaling(ArrayList data, double scalingFactor) { + return data.stream().map(val -> val * scalingFactor).collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Scales each element in the input ArrayList by a specified scaling factor. + * + * @param data The Array of Double values to be scaled. + * @param scalingFactor The factor by which each element in the data Array will + * be multiplied. + * @return A new Array containing the scaled values. + */ + public static double[] constantScaling(double[] data, double scalingFactor) { + return Arrays.stream(data).map(val -> val * scalingFactor).toArray(); + } + + /** + * Reshapes a 3D ArrayList into a 4D ArrayList structure. This method takes a + * three-dimensional ArrayList of data and reshapes it into a four-dimensional + * ArrayList structure. The reshaping is performed by dividing the original data + * into blocks of size 4x24. The resulting four-dimensional ArrayList contains + * these blocks. + * + * + * @param dataList The 3D list to be reshaped. + * @param hyperParameters The hyperparameters containing the interval used to + * reshape the list. + * @return A reshaped 4D list. 
+ */ + public static ArrayList>>> reshape( + ArrayList>> dataList, HyperParameters hyperParameters) { + + // Calculate the dimensions for reshaping + int rowsPerDay = 60 / hyperParameters.getInterval() * 24; + int numDays = dataList.size() / rowsPerDay; + + // Initialize the reshaped 4D list + ArrayList>>> reshapedData = new ArrayList<>(); + + int dataIndex = 0; + for (int day = 0; day < numDays; day++) { + ArrayList>> dailyData = new ArrayList<>(); + for (int row = 0; row < rowsPerDay; row++) { + dailyData.add(dataList.get(dataIndex)); + dataIndex++; + } + reshapedData.add(dailyData); + } + + return reshapedData; + } + + /** + * Updates the model with the specified weights based on the given indices and + * model type. This method extracts the optimum weights from the provided 4D + * ArrayList of models using the given indices and model type. It updates the + * hyperparameters with the extracted weights based on the model type. + * + * @param allModel The 4D ArrayList containing all models. + * @param indices The list of indices specifying the location of optimum + * weights in the models. + * @param fileName The name of the file to save the final model. + * @param modelType The type of the model ("trend.txt" or + * "seasonality.txt"). + * @param hyperParameters The hyperparameters to update with the extracted + * weights. 
+ */ + public static void updateModel(ArrayList>>> allModel, // + List> indices, // + String fileName, // + String modelType, // + HyperParameters hyperParameters) { + + ArrayList>> optimumWeights = new ArrayList>>(); + + for (List idx : indices) { + ArrayList> tempWeights = allModel// + .get(idx.get(0))// + .get(idx.get(1)); + optimumWeights.add(tempWeights); + } + + switch (modelType.toLowerCase()) { + case "trend": + hyperParameters.updatModelTrend(optimumWeights); + break; + case "seasonality": + hyperParameters.updateModelSeasonality(optimumWeights); + break; + default: + throw new IllegalArgumentException("Invalid model type: " + modelType); + } + } + + /** + * Performs element-wise multiplication of two arrays. + * + * @param featureA the first array + * @param featureB the second array + * @return a new array where each element is the product of the corresponding + * elements of featureA and featureB + * @throws IllegalArgumentException if the input arrays are of different lengths + */ + public static double[] elementWiseMultiplication(double[] featureA, double[] featureB) { + if (featureA.length != featureB.length) { + throw new IllegalArgumentException("The input arrays must have the same length."); + } + return IntStream.range(0, featureA.length)// + .mapToDouble(i -> featureA[i] * featureB[i])// + .toArray(); + } + + /** + * Performs element-wise multiplication of two ArrayLists. 
+ * + * @param featureA the first ArrayList + * @param featureB the second ArrayList + * @return a new ArrayList where each element is the result of multiplying the + * corresponding elements of featureA and featureB + * @throws IllegalArgumentException if the input ArrayLists are of different + * lengths + */ + public static ArrayList elementWiseMultiplication(ArrayList featureA, ArrayList featureB) { + if (featureA.size() != featureB.size()) { + throw new IllegalArgumentException("The input ArrayLists must have the same length."); + } + ArrayList result = new ArrayList<>(); + IntStream.range(0, featureA.size()).forEach(i -> result.add(featureA.get(i) * featureB.get(i))); + return result; + } + + /** + * Performs element-wise division of two ArrayLists. If an element in featureB + * is zero, the corresponding element of featureA is used as the result. + * + * @param featureA the first ArrayList + * @param featureB the second ArrayList + * @return a new ArrayList where each element is the result of dividing the + * corresponding elements of featureA by featureB, or the element of + * featureA if the element in featureB is zero + * @throws IllegalArgumentException if the input ArrayLists are of different + * lengths + */ + public static ArrayList elementWiseDiv(ArrayList featureA, ArrayList featureB) { + if (featureA.size() != featureB.size()) { + throw new IllegalArgumentException("The input ArrayLists must have the same length."); + } + ArrayList result = new ArrayList<>(); + IntStream.range(0, featureA.size()) + .forEach(i -> result.add((featureB.get(i) == 0) ? featureA.get(i) : featureA.get(i) / featureB.get(i))); + return result; + } + + /** + * Performs element-wise division of two arrays. If an element in featureB is + * zero, the corresponding element of featureA is used as the result.
+ * + * @param featureA the first array + * @param featureB the second array + * @return a new array where each element is the result of dividing the + * corresponding elements of featureA by featureB, or the element of + * featureA if the element in featureB is zero + * @throws IllegalArgumentException if the input arrays are of different lengths + */ + public static double[] elementWiseDiv(double[] featureA, double[] featureB) { + if (featureA.length != featureB.length) { + throw new IllegalArgumentException("The input arrays must have the same length."); + } + return IntStream.range(0, featureA.length)// + .mapToDouble(i -> (featureB[i] == 0) ? featureA[i] : featureA[i] / featureB[i])// + .toArray(); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Differencing.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Differencing.java new file mode 100644 index 00000000000..9e4a61cd611 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Differencing.java @@ -0,0 +1,55 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import java.util.stream.IntStream; + +public class Differencing { + + /** + * First order Differencing. + * + * @param data data for Differencing + * @return the first order Differencing + */ + public static double[] firstOrderDifferencing(double[] data) { + if (data.length < 2) { + throw new IllegalArgumentException("Data array must contain at least two elements."); + } + + return IntStream.range(0, data.length - 1)// + .mapToDouble(i -> data[i] - data[i + 1])// + .toArray(); + } + + /** + * First order accumulating.
+ * + * @param data data for Differencing + * @param init data for init + * @return the first order Differencing + */ + public static double[] firstOrderAccumulating(double[] data, double init) { + if (data.length == 0) { + throw new IllegalArgumentException("Data array must not be empty."); + } + + double[] accumulating = new double[data.length]; + + accumulating[0] = data[0] + init; + + IntStream.range(1, data.length)// + .forEach(i -> accumulating[i] = accumulating[i - 1] + data[i]); + + return accumulating; + } + + /** + * first Order Accumulating. + * + * @param data data for Differencing + * @param init data for init + * @return the first order Differencing + */ + public static double firstOrderAccumulating(double data, double init) { + return data + init; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/FilterOutliers.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/FilterOutliers.java new file mode 100644 index 00000000000..192c2623df9 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/FilterOutliers.java @@ -0,0 +1,97 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.stream.IntStream; + +import org.apache.commons.math3.stat.descriptive.rank.Percentile; + +import io.openems.edge.predictor.lstmmodel.utilities.MathUtils; + +public class FilterOutliers { + + /** + * Filters out outliers from the dataset until no outliers are detected. 
+ * + * @param data the input dataset + * @return the filtered dataset with outliers removed + */ + public static double[] filterOutlier(double[] data) { + if (data == null || data.length == 0) { + throw new IllegalArgumentException("Input data must not be null or empty."); + } + + double[] filteredData = Arrays.copyOf(data, data.length); + int iterationCount = 0; + boolean hasOutliers = true; + + while (hasOutliers && iterationCount <= 100) { + var outlierIndices = detect(filteredData); + + if (outlierIndices.isEmpty()) { + hasOutliers = false; + } else { + filteredData = filter(filteredData, outlierIndices); + } + + iterationCount++; + } + + return filteredData; + } + + /** + * Applies the hyperbolic tangent function to data points at the specified + * indices. + * + * @param data the input dataset + * @param indices the indices of data points to be transformed + * @return the transformed dataset + */ + public static double[] filter(double[] data, ArrayList indices) { + + if (data == null || indices == null) { + throw new IllegalArgumentException("Input data and indices must not be null."); + } + + if (indices.isEmpty()) { + return data; + } + + double[] result = data.clone(); + for (int index : indices) { + if (index >= 0 && index < result.length) { + result[index] = MathUtils.tanh(result[index]); + } else { + throw new IllegalArgumentException("Index out of bounds: " + index); + } + } + return result; + } + + /** + * Detects outliers in the dataset using the interquartile range (IQR) method. 
+ * + * @param data the input dataset + * @return a list of indices of the detected outliers + */ + public static ArrayList detect(double[] data) { + + if (data == null || data.length == 0) { + throw new IllegalArgumentException("Input data must not be null or empty."); + } + + Percentile perc = new Percentile(); + var q1 = perc.evaluate(data, 25);// 25th percentile (Q1) (First percentile) + var q3 = perc.evaluate(data, 75);// 75th percentile (Q3) (Third percentile) + var iqr = q3 - q1; + var upperLimit = q3 + 1.5 * iqr; + var lowerLimit = q1 - 1.5 * iqr; + + // Detect outliers + return IntStream.range(0, data.length)// + .filter(i -> data[i] < lowerLimit || data[i] > upperLimit) + .collect(ArrayList::new, ArrayList::add, ArrayList::addAll); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/GroupBy.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/GroupBy.java new file mode 100644 index 00000000000..c6cc5de3ae6 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/GroupBy.java @@ -0,0 +1,98 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import java.time.OffsetDateTime; +import java.time.temporal.ChronoField; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class GroupBy { + private final ArrayList data; + private final ArrayList date; + + private final ArrayList> groupedDateByMin = new ArrayList<>(); + private final ArrayList> groupedDataByMin = new ArrayList<>(); + private final ArrayList> groupedDateByHour = new ArrayList<>(); + private final ArrayList> groupedDataByHour = new ArrayList<>(); + + /** + * Group by Temporal filed. + * + * @param chronoField {@link ChronoField} + * @param groupedDateList The list of groupedDateList. + * @param groupedDataList The list of groupedDataList. 
+ */ + public void groupByTemporalField(ChronoField chronoField, List> groupedDateList, + List> groupedDataList) { + + List uniqueList = this.extractUniqueAndSortedValues(chronoField); + + for (Integer uniqueValue : uniqueList) { + List groupedDateTemp = this.groupDatesByUniqueValue(uniqueValue, chronoField); + List groupedDataTemp = this.groupDataByUniqueValue(uniqueValue, chronoField); + + groupedDateList.add(new ArrayList<>(groupedDateTemp)); + groupedDataList.add(new ArrayList<>(groupedDataTemp)); + } + } + + private List extractUniqueAndSortedValues(ChronoField chronoField) { + return this.date.stream()// + .map(date -> date.get(chronoField))// + .distinct()// + .sorted()// + .collect(Collectors.toList()); + } + + private List groupDatesByUniqueValue(Integer uniqueValue, ChronoField chronoField) { + return this.date.stream()// + .filter(date -> uniqueValue.equals(date.get(chronoField)))// + .collect(Collectors.toList()); + } + + private List groupDataByUniqueValue(Integer uniqueValue, ChronoField chronoField) { + return IntStream.range(0, this.data.size())// + .filter(i -> { + double dateValue = this.date.get(i).get(chronoField); + return Double.compare(dateValue, uniqueValue.doubleValue()) == 0; + })// + .mapToObj(i -> this.data.get(i))// + .collect(Collectors.toList()); + } + + /** + * grouping by hour. + */ + public void hour() { + this.groupByTemporalField(ChronoField.HOUR_OF_DAY, this.groupedDateByHour, this.groupedDataByHour); + } + + /** + * grouping by minute. 
+ */ + public void minute() { + this.groupByTemporalField(ChronoField.MINUTE_OF_HOUR, this.groupedDateByMin, this.groupedDataByMin); + } + + public ArrayList> getGroupedDataByHour() { + return this.groupedDataByHour; + } + + public ArrayList> getGroupedDateByHour() { + return this.groupedDateByHour; + } + + public ArrayList> getGroupedDataByMinute() { + return this.groupedDataByMin; + } + + public ArrayList> getGroupedDateByMinute() { + return this.groupedDateByMin; + } + + public GroupBy(ArrayList data, List date) { + this.data = new ArrayList<>(data); + this.date = new ArrayList<>(date); + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverage.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverage.java new file mode 100644 index 00000000000..2b6e4bea70a --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverage.java @@ -0,0 +1,30 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +public class MovingAverage { + + public static final int WINDOW_SIZE = 3; + + /** + * Compute the Moving average for the data array. 
+ * + * @param data the data for calculating the Moving average + * @return the moving average + */ + public static double[] movingAverage(double[] data) { + + double[] paddedInputData = new double[data.length + WINDOW_SIZE - 1]; + System.arraycopy(data, 0, paddedInputData, WINDOW_SIZE / 2, data.length); + + double[] movingAverages = new double[data.length]; + + for (int i = 0; i < data.length; i++) { + double sum = 0; + for (int j = 0; j < WINDOW_SIZE; j++) { + sum += paddedInputData[i + j]; + } + movingAverages[i] = sum / WINDOW_SIZE; + } + + return movingAverages; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Shuffle.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Shuffle.java new file mode 100644 index 00000000000..4b19f331df3 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessing/Shuffle.java @@ -0,0 +1,69 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class Shuffle { + + private double[][] data; + private double[] target; + + public Shuffle(double[][] data, double[] target) { + this.data = this.copy2DArray(data); + this.target = Arrays.copyOf(target, target.length); + this.shuffleIt(); + } + + /** + * Shuffles the data and target arrays to randomize the order of elements. This + * method shuffles the data and target arrays simultaneously, ensuring that the + * corresponding data and target values remain aligned. 
+ */ + public void shuffleIt() { + List indices = IntStream.range(0, this.data.length)// + .boxed()// + .collect(Collectors.toList()); + + Collections.shuffle(indices, new Random(100)); + + CompletableFuture dataFuture = CompletableFuture + .runAsync(() -> this.shuffleData(new ArrayList<>(indices))); + CompletableFuture targetFuture = CompletableFuture + .runAsync(() -> this.shuffleTarget(new ArrayList<>(indices))); + + CompletableFuture combinedFuture = CompletableFuture.allOf(dataFuture, targetFuture); + combinedFuture.join(); + } + + private void shuffleData(List indices) { + this.data = indices.stream()// + .map(i -> Arrays.copyOf(this.data[i], this.data[i].length))// + .toArray(double[][]::new); + } + + private void shuffleTarget(List indices) { + this.target = indices.stream()// + .mapToDouble(i -> this.target[i])// + .toArray(); + } + + public double[] getTarget() { + return this.target; + } + + public double[][] getData() { + return this.data; + } + + private double[][] copy2DArray(double[][] array) { + return Arrays.stream(array)// + .map(row -> Arrays.copyOf(row, row.length))// + .toArray(double[][]::new); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ConstantScalingPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ConstantScalingPipe.java new file mode 100644 index 00000000000..5e3789f1588 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ConstantScalingPipe.java @@ -0,0 +1,19 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.constantScaling; + +public class ConstantScalingPipe implements Stage { + + private double scalingFactor; + + public ConstantScalingPipe(double factor) { + this.scalingFactor = factor; + } + + @Override + public Object execute(Object 
input) { + return (input instanceof double[] in)// + ? constantScaling(in, this.scalingFactor)// + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/DifferencingPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/DifferencingPipe.java new file mode 100644 index 00000000000..f14a0237b38 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/DifferencingPipe.java @@ -0,0 +1,13 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import io.openems.edge.predictor.lstmmodel.preprocessing.Differencing; + +public class DifferencingPipe implements Stage { + + @Override + public Object execute(Object input) { + return (input instanceof double[] in)// + ? Differencing.firstOrderDifferencing(in)// + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/FilterOutliersPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/FilterOutliersPipe.java new file mode 100644 index 00000000000..0f38db656cf --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/FilterOutliersPipe.java @@ -0,0 +1,13 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import io.openems.edge.predictor.lstmmodel.preprocessing.FilterOutliers; + +public class FilterOutliersPipe implements Stage { + + @Override + public Object execute(Object input) { + return (input instanceof double[] in)// + ? 
FilterOutliers.filterOutlier(in) + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupToStiffWindowPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupToStiffWindowPipe.java new file mode 100644 index 00000000000..11c1c202055 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupToStiffWindowPipe.java @@ -0,0 +1,118 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArray; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArrayList; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to2DArray; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class GroupToStiffWindowPipe implements Stage { + private int window; + + public GroupToStiffWindowPipe(int windowSize) { + this.window = windowSize; + + } + + @Override + public Object execute(Object input) { + + if (input instanceof double[] inputData) { + + var inputDataList = to1DArrayList(inputData); + + var resultArray = new double[2][][]; + var stiffedTargetGroup = new double[1][]; + + stiffedTargetGroup[0] = groupToStiffedTarget(inputDataList, this.window); + var stiffedWindowGroup = groupToStiffedWindow(inputDataList, this.window); + + resultArray[0] = stiffedWindowGroup; + resultArray[1] = stiffedTargetGroup; + + return resultArray; + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } + + /** + * Groups the values in the input ArrayList into windows of a specified size and + * converts the grouped data into a 2D array representing the stiffed windowed 
+ * structure. + * + * @param values The input ArrayList of Double values to be grouped into + * windows. + * @param windowSize The size of each window for grouping the values. + * @return A 2D array representing the stiffed windowed structure of the grouped + * values. + */ + public static double[][] groupToStiffedWindow(ArrayList values, int windowSize) { + if (windowSize < 1 || windowSize > values.size()) { + throw new IllegalArgumentException("Invalid window size"); + } + + List indices = IntStream.range(0, values.size() - windowSize + 1) // + .filter(i -> i % (windowSize + 1) == 0) // + .boxed() // + .collect(Collectors.toList()); // + + List> windowedData = indices.stream() // + .map(i -> values.subList(i, i + windowSize)) // + .map(ArrayList::new) // + .collect(Collectors.toList()); // + + return to2DArray(windowedData); + } + + /** + * Groups the values in the input ArrayList into a stiffed target structure, + * extracting every nth element to form windows of a specified size and converts + * the grouped data into a 1D array. + * + * @param val The input ArrayList of Double values from which the stiffed + * target structure is created. + * @param windowSize The size of each window, representing the step size for + * selecting elements. + * @return A 1D array representing the stiffed target structure of the grouped + * values. + */ + public static double[] groupToStiffedTarget(ArrayList val, int windowSize) { + if (windowSize < 1 || windowSize > val.size()) { + throw new IllegalArgumentException("Invalid window size"); + } + + List windowedData = IntStream.range(0, val.size())// + .filter(j -> j % (windowSize + 1) == windowSize)// + .mapToObj(val::get)// + .collect(Collectors.toList()); + + return to1DArray(windowedData); + } + + /** + * Groups the values in the input Array into a stiffed target structure, + * extracting every nth element to form windows of a specified size and converts + * the grouped data into a 1D array. 
+ * + * @param val The input Array of Double values from which the stiffed + * target structure is created. + * @param windowSize The size of each window, representing the step size for + * selecting elements. + * @return A 1D array representing the stiffed target structure of the grouped + * values. + */ + public static double[] groupToStiffedTarget(double[] val, int windowSize) { + if (windowSize < 1 || windowSize > val.length) { + throw new IllegalArgumentException("Invalid window size"); + } + + return IntStream.range(0, val.length)// + .filter(j -> j % (windowSize + 1) == windowSize)// + .mapToDouble(j -> val[j]).toArray(); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupbyPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupbyPipe.java new file mode 100644 index 00000000000..442b0142789 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GroupbyPipe.java @@ -0,0 +1,29 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.groupDataByHourAndMinute; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArrayList; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to3DArray; + +import java.time.OffsetDateTime; +import java.util.ArrayList; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class GroupbyPipe implements Stage { + private ArrayList dates; + + public GroupbyPipe(HyperParameters hype, ArrayList date) { + this.dates = date; + } + + @Override + public Object execute(Object input) { + if (input instanceof double[] in) { + var inList = to1DArrayList(in); + var groupedByHourAndMinuteList = groupDataByHourAndMinute(inList, this.dates); + return 
to3DArray(groupedByHourAndMinuteList); + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GrouptoWindowpipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GrouptoWindowpipe.java new file mode 100644 index 00000000000..ef86032f2ad --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/GrouptoWindowpipe.java @@ -0,0 +1,75 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.utilities.SlidingWindowSpliterator.windowed; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class GrouptoWindowpipe implements Stage { + + public static final Function>, double[][]> twoDListToTwoDArray = UtilityConversion::to2DArray; + + private int window; + + public GrouptoWindowpipe(int windowSize) { + this.window = windowSize; + } + + @Override + public Object execute(Object input) { + if (input instanceof double[] inputData) { + try { + double[] windowedTarget = this.getTargetData(inputData); + double[][] windowedData = this.getWindowDataTrain(inputData); + + return new double[][][] { windowedData, new double[][] { windowedTarget } }; + } catch (Exception e) { + throw new RuntimeException("Error processing input data", e); + } + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } + + private double[][] getWindowDataTrain(double[] data) { + var lower = 0; + var upper = data.length - 1; + + var subList = IntStream.range(lower, upper)// + .mapToObj(index -> data[index]) // + 
package io.openems.edge.predictor.lstmmodel.preprocessingpipeline;

import java.time.OffsetDateTime;
import java.util.ArrayList;

import io.openems.edge.predictor.lstmmodel.common.HyperParameters;
import io.openems.edge.predictor.lstmmodel.interpolation.InterpolationManager;
import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArrayList;
import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArray;

/**
 * Pipeline stage that runs the input series through the
 * {@link InterpolationManager} and returns the interpolated series.
 */
public class InterpolationPipe implements Stage<Object, Object> {

	private HyperParameters hyperParameters;

	// NOTE(review): the dates parameter is not used by this stage; presumably
	// kept for constructor symmetry with the other pipes — confirm.
	public InterpolationPipe(HyperParameters hype, ArrayList<OffsetDateTime> dates) {
		this.hyperParameters = hype;
	}

	/**
	 * Interpolates the input series.
	 *
	 * @param input a {@code double[]} of raw values
	 * @return a {@code double[]} of interpolated values
	 * @throws IllegalArgumentException if input is not a {@code double[]}
	 */
	@Override
	public Object execute(Object input) {

		if (input instanceof double[] in) {
			var inList = to1DArrayList(in);
			var inter = new InterpolationManager(inList, this.hyperParameters);
			return to1DArray(inter.getInterpolatedData());
		} else {
			throw new IllegalArgumentException("Input must be an instance of double[]");
		}
	}
}
public void setDates(ArrayList date) { + this.dates = date; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/MovingAveragePipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/MovingAveragePipe.java new file mode 100644 index 00000000000..e3576a4d810 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/MovingAveragePipe.java @@ -0,0 +1,13 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import io.openems.edge.predictor.lstmmodel.preprocessing.MovingAverage; + +public class MovingAveragePipe implements Stage { + + @Override + public Object execute(Object input) { + return (input instanceof double[] in) // + ? MovingAverage.movingAverage(in) // + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/NormalizePipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/NormalizePipe.java new file mode 100644 index 00000000000..1ae4c848dd8 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/NormalizePipe.java @@ -0,0 +1,40 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.normalizeData; +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.standardize; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class NormalizePipe implements Stage { + + private HyperParameters hyperParameters; + + public NormalizePipe(HyperParameters hyper) { + this.hyperParameters = hyper; + } + + @Override + public Object execute(Object input) { + try { + if (input instanceof 
package io.openems.edge.predictor.lstmmodel.preprocessingpipeline;

/**
 * Represents a pipeline that processes data through a series of stages.
 *
 * @param <O> The type of the output produced by the pipeline.
 * @param <I> The type of the input consumed by the pipeline.
 */
public interface PiplineInterface<O, I> {

	/**
	 * Adds a stage to the pipeline.
	 *
	 * @param stage The stage to be added to the pipeline.
	 */
	void add(Stage<O, I> stage);

	/**
	 * Executes the pipeline, processing the data through the added stages.
	 *
	 * @return The result of executing the pipeline.
	 */
	Object execute();
}
+ * + * @return this + */ + public PreprocessingPipeImpl constantscale() { + return this.addStage(new ConstantScalingPipe(this.scalingFactor)); + } + + /** + * Add trainTestSplit stage. + * + * @return this + */ + public PreprocessingPipeImpl trainTestSplit() { + return this.addStage(new TrainandTestSplitPipe(this.hyperParameter)); + } + + /** + * Add filterOutliers stage. + * + * @return this + */ + public PreprocessingPipeImpl filterOutliers() { + return this.addStage(new FilterOutliersPipe()); + } + + /** + * Add normalize stage. + * + * @return this + */ + public PreprocessingPipeImpl normalize() { + return this.addStage(new NormalizePipe(this.hyperParameter)); + } + + /** + * Add groupToWIndowTrend stage. + * + * @return this + */ + public PreprocessingPipeImpl groupToWIndowTrend() { + return this.addStage(new GrouptoWindowpipe(this.hyperParameter.getWindowSizeTrend())); + } + + /** + * Add groupToWIndowSeasonality stage. + * + * @return this + */ + public PreprocessingPipeImpl groupToWIndowSeasonality() { + return this.addStage(new GrouptoWindowpipe(this.hyperParameter.getWindowSizeSeasonality())); + } + + /** + * Add shuffle stage. + * + * @return this + */ + public PreprocessingPipeImpl shuffle() { + return this.addStage(new ShufflePipe()); + } + + /** + * Add groupByHoursAndMinutes stage. + * + * @return this + */ + public PreprocessingPipeImpl groupByHoursAndMinutes() { + return this.addStage(new GroupbyPipe(this.hyperParameter, this.dates)); + } + + /** + * Add Remove Negatives. + * + * @return this + */ + + public PreprocessingPipeImpl removeNegatives() { + return this.addStage(new RemoveNegativesPipe()); + } + + /** + * Add groupToStiffedWindow stage. + * + * @return this + */ + public PreprocessingPipeImpl groupToStiffedWindow() { + return this.addStage(new GroupToStiffWindowPipe(this.hyperParameter.getWindowSizeTrend())); + } + + /** + * Add interpolate stage. 
+ * + * @return this + */ + public PreprocessingPipeImpl interpolate() { + return this.addStage(new InterpolationPipe(this.hyperParameter, this.dates)); + } + + /** + * Add modifyForShortTermPrediction stage. + * + * @return this + */ + public PreprocessingPipeImpl modifyForTrendPrediction() { + return this.addStage(new ModifyDataForTrend(this.dates, this.hyperParameter)); + } + + /** + * Add differencing stage. + * + * @return this + */ + public PreprocessingPipeImpl differencing() { + return this.addStage(new DifferencingPipe()); + } + + /** + * Add reverseScale stage. + * + * @return this + */ + public PreprocessingPipeImpl reverseScale() { + return this.addStage(new ReverseScalingPipe(this.hyperParameter)); + } + + /** + * Add reverseNormalize stage. + * + * @return this + */ + public PreprocessingPipeImpl reverseNormalize() { + return this.addStage(new ReverseNormalizationPipe(this.mean, this.standardDeviation, this.hyperParameter)); + } + + @Override + public Object execute() { + Object preprocessingInput = this.inputData; + for (Stage i : this.stages) { + preprocessingInput = i.execute(preprocessingInput); + } + this.outputData = preprocessingInput; + this.stages = new ArrayList>(); + return this.outputData; + } + + @Override + public void add(Stage stage) { + this.stages.add(stage); + } + + public PreprocessingPipeImpl setData(Object val) { + this.inputData = val; + return this; + } + + public PreprocessingPipeImpl setMean(Object val) { + this.mean = val; + return this; + } + + public PreprocessingPipeImpl setStandardDeviation(Object val) { + this.standardDeviation = val; + return this; + } + + public PreprocessingPipeImpl setDates(ArrayList date) { + this.dates = date; + return this; + } + + public PreprocessingPipeImpl setScalingFactor(double val) { + this.scalingFactor = val; + return this; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/RemoveNegativesPipe.java 
b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/RemoveNegativesPipe.java new file mode 100644 index 00000000000..38b6bdd6f02 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/RemoveNegativesPipe.java @@ -0,0 +1,13 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.removeNegatives; + +public class RemoveNegativesPipe implements Stage { + + @Override + public Object execute(Object input) { + return (input instanceof double[] in) // + ? removeNegatives(in) // + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseNormalizationPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseNormalizationPipe.java new file mode 100644 index 00000000000..3702354f6b2 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseNormalizationPipe.java @@ -0,0 +1,42 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; + +public class ReverseNormalizationPipe implements Stage { + private Object mean; + private Object standerDeviation; + private HyperParameters hyperParameters; + + public ReverseNormalizationPipe(Object average, Object std, HyperParameters hyp) { + this.mean = average; + this.standerDeviation = std; + this.hyperParameters = hyp; + } + + @Override + public Object execute(Object input) { + + try { + if (input instanceof double[] inputArray) { + if (this.mean instanceof double[] meanArray // + && 
this.standerDeviation instanceof double[] sdArray) { + return DataModification.reverseStandrize(inputArray, meanArray, sdArray, this.hyperParameters); + } else if (this.mean instanceof Double meanValue // + && this.standerDeviation instanceof Double sdValue) { + return DataModification.reverseStandrize(inputArray, meanValue, sdValue, this.hyperParameters); + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } else if (input instanceof Double inputArray) { + double mean = (double) this.mean; + double std = (double) this.standerDeviation; + return DataModification.reverseStandrize(inputArray, mean, std, this.hyperParameters); + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } catch (Exception e) { + throw new RuntimeException("Error processing input data", e); + } + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseScalingPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseScalingPipe.java new file mode 100644 index 00000000000..1709c48a4ca --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ReverseScalingPipe.java @@ -0,0 +1,19 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.scaleBack; + +public class ReverseScalingPipe implements Stage { + private HyperParameters hype; + + public ReverseScalingPipe(HyperParameters hyperParameters) { + this.hype = hyperParameters; + } + + @Override + public Object execute(Object input) { + return (input instanceof double[] inputArray) // + ? 
scaleBack(inputArray, this.hype.getScalingMin(), this.hype.getScalingMax()) // + : new IllegalArgumentException("Input must be an instance of double[]"); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ScalingPipe.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ScalingPipe.java new file mode 100644 index 00000000000..b7b2cc08283 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/preprocessingpipeline/ScalingPipe.java @@ -0,0 +1,65 @@ +package io.openems.edge.predictor.lstmmodel.preprocessingpipeline; + +import java.util.Arrays; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class ScalingPipe implements Stage { + + private static final double MIN_SCALED = 0.2; + private static final double MAX_SCALED = 0.8; + + private HyperParameters hyperParameter; + + public ScalingPipe(HyperParameters hyperParameters) { + this.hyperParameter = hyperParameters; + } + + @Override + public Object execute(Object value) { + if (value instanceof double[][] v) { + return this.scaleSecondCase(v); + } else if (value instanceof double[] v) { + return (this.scaleFirstCase(v)); + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + } + + /** + * Scales the data in the second case of a two-dimensional array using the + * preprocessing pipeline. + * + * @param value The two-dimensional array containing data to be scaled, where + * each row represents a separate case. + * @return A two-dimensional array containing the scaled data, where each row + * corresponds to the scaled data of the respective case. 
package io.openems.edge.predictor.lstmmodel.preprocessingpipeline;

import io.openems.edge.predictor.lstmmodel.preprocessing.Shuffle;

/**
 * Pipeline stage that shuffles training rows and their target values via the
 * project's {@link Shuffle} helper, keeping each row aligned with its target.
 */
public class ShufflePipe implements Stage<Object, Object> {

	/**
	 * Shuffles a train/target pair.
	 *
	 * @param input a {@code double[][][]} where {@code [0]} is the 2D train
	 *              matrix and {@code [1][0]} is the 1D target vector
	 * @return a {@code double[][][]} of the same layout with shuffled content
	 * @throws IllegalArgumentException if input is not a {@code double[][][]}
	 */
	@Override
	public Object execute(Object input) {
		if (!(input instanceof double[][][] data)) {
			throw new IllegalArgumentException("Input must be a 3-dimensional double array.");
		}

		// Convention used across the pipeline: index 0 holds the training
		// matrix, index 1 row 0 holds the target vector.
		double[][] trainData = data[0];
		double[] targetData = data[1][0];

		Shuffle shuffle = new Shuffle(trainData, targetData);

		double[][] shuffledData = shuffle.getData();
		double[] shuffledTarget = shuffle.getTarget();

		return new double[][][] { shuffledData, { shuffledTarget } };
	}
}
+ */ + @Override + public Object execute(Object value) { + if (value instanceof double[] valueTemp) { + double splitPercentage = this.hyp.getDataSplitTrain(); + int dataSize = valueTemp.length - 1; + + int trainLowerIndex = 0; + int trainUpperIndex = (int) (splitPercentage * dataSize); + + int testLowerIndex = trainUpperIndex; + int testUpperIndex = dataSize + 1; + + double[][] combinedData = { // train data + IntStream.range(trainLowerIndex, trainUpperIndex) // + .mapToDouble(index -> valueTemp[index]) // + .toArray(), // target data + IntStream.range(testLowerIndex, testUpperIndex) // + .mapToDouble(index -> valueTemp[index]) // + .toArray() // + }; + + return combinedData; + } else { + throw new IllegalArgumentException("Input must be an instance of double[]"); + } + + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/LstmTrain.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/LstmTrain.java new file mode 100644 index 00000000000..2d2a67e59f4 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/LstmTrain.java @@ -0,0 +1,173 @@ +package io.openems.edge.predictor.lstmmodel.train; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.constantScaling; +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.removeNegatives; + +import java.time.OffsetDateTime; +import java.time.ZonedDateTime; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Sets; +import com.google.gson.JsonElement; + +import io.openems.common.exceptions.OpenemsError.OpenemsNamedException; +import 
package io.openems.edge.predictor.lstmmodel.train;

import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.constantScaling;
import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.removeNegatives;

import java.time.OffsetDateTime;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Sets;
import com.google.gson.JsonElement;

import io.openems.common.exceptions.OpenemsError.OpenemsNamedException;
import io.openems.common.timedata.Resolution;
import io.openems.common.types.ChannelAddress;
import io.openems.edge.predictor.lstmmodel.LstmModel;
import io.openems.edge.predictor.lstmmodel.common.HyperParameters;
import io.openems.edge.predictor.lstmmodel.common.ReadAndSaveModels;
import io.openems.edge.timedata.api.Timedata;

/**
 * Background task that queries historic channel data from {@link Timedata},
 * splits it chronologically into training and validation sets and starts a
 * {@link TrainAndValidateBatch} run. Training results are reported back to the
 * parent {@link LstmModel} component via its channels.
 */
public class LstmTrain implements Runnable {

	private final Logger log = LoggerFactory.getLogger(LstmTrain.class);

	private final Timedata timedata;
	private final ChannelAddress channelAddress;
	private final LstmModel parent;
	// Number of days of history queried for training.
	private final long days;

	public LstmTrain(Timedata timedata, ChannelAddress channelAddress, LstmModel parent, long days) {
		this.timedata = timedata;
		this.channelAddress = channelAddress;
		this.parent = parent;
		this.days = days;
	}

	@Override
	public void run() {

		var nowDate = ZonedDateTime.now();

		// Train on full past days only: from 00:00 `days` days ago up to
		// yesterday 23:45.
		var until = nowDate.minusDays(1).withHour(23).withMinute(45).withSecond(0).withNano(0);
		var fromDate = until.minusDays(this.days).withHour(0).withMinute(0).withSecond(0).withNano(0);

		HyperParameters hyperParameters = ReadAndSaveModels.read(this.channelAddress.getChannelId());

		SortedMap<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> queryResult;
		try {
			queryResult = this.timedata.queryHistoricData(null, fromDate, until, Sets.newHashSet(this.channelAddress),
					new Resolution(hyperParameters.getInterval(), ChronoUnit.MINUTES));
		} catch (OpenemsNamedException e) {
			// Fix: the exception was previously only printStackTrace()'d and the
			// method continued with empty maps. Log with cause, flag the
			// cannot-train condition and stop here instead.
			this.log.error("Querying historic data for LSTM training failed", e);
			this.parent._setCannotTrainCondition(true);
			return;
		}

		// Chronological split: first 66 % for training, remainder for
		// validation.
		SortedMap<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> trainMap = new TreeMap<>();
		SortedMap<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> validateMap = new TreeMap<>();

		int trainSize = (int) (queryResult.size() * 0.66); // 66% train and 33% validation
		int count = 0;
		for (Map.Entry<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> entry : queryResult.entrySet()) {
			if (count < trainSize) {
				trainMap.put(entry.getKey(), entry.getValue());
			} else {
				validateMap.put(entry.getKey(), entry.getValue());
			}
			count++;
		}

		// Get the training data
		var trainingData = this.getData(trainMap);

		if (this.cannotTrainConditions(trainingData)) {
			this.parent._setCannotTrainCondition(true);
			this.log.info("Cannot proceed with training: Data is all null or insufficient data.");
			return;
		}
		// Get the training Date
		var trainingDate = this.getDate(trainMap);
		// Get the validation data
		var validationData = this.getData(validateMap);
		// Get the validationDate
		var validationDate = this.getDate(validateMap);

		/**
		 * TODO Read an save model.adapt method ReadAndSaveModels.adapt(hyperParameters,
		 * validateBatchData, validateBatchDate);
		 */
		new TrainAndValidateBatch(//
				constantScaling(removeNegatives(trainingData), 1), trainingDate, //
				constantScaling(removeNegatives(validationData), 1), validationDate, //
				hyperParameters);

		this.parent._setLastTrainedTime(hyperParameters.getLastTrainedDate().toString());
		this.parent._setModelError(Collections.min(hyperParameters.getRmsErrorSeasonality()));
		this.parent._setCannotTrainCondition(false);
	}

	/**
	 * Extracts data values from a query result, in timestamp order.
	 *
	 * @param queryResult the {@link SortedMap} query result
	 * @return an {@link ArrayList} of Double values; {@code null} entries mark
	 *         missing (JSON null) samples
	 */
	public ArrayList<Double> getData(SortedMap<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> queryResult) {

		ArrayList<Double> data = new ArrayList<>();

		queryResult.values().stream()//
				.map(SortedMap::values)//
				.flatMap(Collection::stream)//
				.map(v -> {
					if (v.isJsonNull()) {
						return null;
					}
					return v.getAsDouble();
				}).forEach(value -> data.add(value));

		return data;
	}

	/**
	 * Extracts {@link OffsetDateTime} objects from the keys of a query result.
	 *
	 * @param queryResult the {@link SortedMap} with {@link ZonedDateTime} keys
	 * @return an {@link ArrayList} of {@link OffsetDateTime} objects
	 */
	public ArrayList<OffsetDateTime> getDate(
			SortedMap<ZonedDateTime, SortedMap<ChannelAddress, JsonElement>> queryResult) {
		return queryResult.keySet().stream()//
				.map(ZonedDateTime::toOffsetDateTime)//
				.collect(Collectors.toCollection(ArrayList::new));
	}

	/**
	 * Decides whether the data is unusable for training.
	 *
	 * @param array the data to be checked
	 * @return true if the data is empty, all-null, or 50 % or more of it is
	 *         null/NaN; false otherwise
	 */
	private boolean cannotTrainConditions(ArrayList<Double> array) {
		if (array.isEmpty()) {
			return true; // Cannot train with no data
		}

		boolean allNulls = array.stream().allMatch(Objects::isNull);
		if (allNulls) {
			return true; // Cannot train with all null data
		}

		var nonNanCount = array.stream().filter(d -> d != null && !Double.isNaN(d)).count();
		var validProportion = (double) nonNanCount / array.size();
		return validProportion <= 0.5; // Cannot train with 50% or more invalid data
	}
}
timestamps, and + * hyperparameters. The training process involves preprocessing the data for + * short-term prediction, generating initial weights, and fitting the model for + * each modified data segment. The trained weights are saved to the model file. + * + * @param data The ArrayList of Double values representing the + * time-series data. + * @param date The ArrayList of OffsetDateTime objects corresponding + * to the timestamps of the data. + * @param hyperParameters The hyperparameters configuration for training the + * trend model. + * @return weightMatrix Trained models. + */ + public synchronized ArrayList>>> trainTrend(ArrayList data, + ArrayList date, HyperParameters hyperParameters) { + + var weightMatrix = new ArrayList>>>(); + var weightTrend = new ArrayList>(); + PreprocessingPipeImpl preProcessing = new PreprocessingPipeImpl(hyperParameters); + preProcessing.setData(to1DArray(data)); + preProcessing.setDates(date); + + var modifiedData = (double[][]) preProcessing// + .interpolate()// + .movingAverage()// + .scale()// + .filterOutliers()// + .modifyForTrendPrediction()// + .execute(); + + for (int i = 0; i < modifiedData.length; i++) { + + weightTrend = (hyperParameters.getCount() == 0) // + ? generateInitialWeightMatrix(hyperParameters.getWindowSizeTrend(), hyperParameters)// + : hyperParameters.getlastModelTrend().get(i); + + preProcessing.setData(modifiedData[i]); + + var preProcessed = (double[][][]) preProcessing// + .groupToStiffedWindow()// + .normalize()// + .shuffle()// + .execute(); + + var model = new EngineBuilder() // + .setInputMatrix(preProcessed[0])// + .setTargetVector(preProcessed[1][0]) // + .build(); + model.fit(hyperParameters.getGdIterration(), weightTrend, hyperParameters); + weightMatrix.add(model.getWeights()); + + } + + return weightMatrix; + } + + /** + * Trains the seasonality model using the specified data, timestamps, and + * hyperparameters. 
The training process involves preprocessing the data, + * grouping it by hour and minute, and fitting the model for each group. The + * trained weights are saved to the model file. + * + * @param data The ArrayList of Double values representing the + * time-series data. + * @param date The ArrayList of OffsetDateTime objects corresponding + * to the timestamps of the data. + * @param hyperParameters The hyperparameters configuration for training the + * seasonality model. + * @return weightMatrix Trained seasonality models. + */ + + public synchronized ArrayList>>> trainSeasonality(ArrayList data, + ArrayList date, HyperParameters hyperParameters) { + + ArrayList>>> weightMatrix = new ArrayList<>(); + ArrayList> weightSeasonality = new ArrayList<>(); + int windowsSize = hyperParameters.getWindowSizeSeasonality(); + + PreprocessingPipeImpl preprocessing = new PreprocessingPipeImpl(hyperParameters); + + preprocessing.setData(to1DArray(data));// + preprocessing.setDates(date);// + + var dataGroupedByMinute = (double[][][]) preprocessing// + .interpolate()// + .movingAverage()// + .scale()// + .filterOutliers()// + .groupByHoursAndMinutes()// + .execute(); + int k = 0; + + for (int i = 0; i < dataGroupedByMinute.length; i++) { + for (int j = 0; j < dataGroupedByMinute[i].length; j++) { + + hyperParameters.setGdIterration(DynamicItterationValue + .setIteration(hyperParameters.getAllModelErrorSeason(), k, hyperParameters)); + + if (hyperParameters.getCount() == 0) { + weightSeasonality = generateInitialWeightMatrix(windowsSize, hyperParameters); + } else { + + weightSeasonality = hyperParameters.getlastModelSeasonality().get(k); + + } + + preprocessing.setData(dataGroupedByMinute[i][j]); + + var preProcessedSeason = (double[][][]) preprocessing// + //.differencing()// + .groupToWIndowSeasonality()// + .normalize()// + .shuffle()// + .execute(); + + var model = new EngineBuilder()// + .setInputMatrix(preProcessedSeason[0])// + .setTargetVector(preProcessedSeason[1][0])// 
+ .build(); + + model.fit(hyperParameters.getGdIterration(), weightSeasonality, hyperParameters); + weightMatrix.add(model.getWeights()); + + k = k + 1; + } + } + + return weightMatrix; + + } + + /** + * Generates the initial weight matrix for the LSTM model based on the specified + * window size and hyperparameters. + * + * @param windowSize The size of the window for the initial weight matrix. + * @param hyperParameters The hyperparameters used for generating the initial + * weight matrix. + * @return The initial weight matrix as an ArrayList of ArrayList of Double + * values. + */ + public static ArrayList> generateInitialWeightMatrix(int windowSize, + HyperParameters hyperParameters) { + + ArrayList> initialWeight = new ArrayList<>(); + String[] parameterTypes = { "wi", "wo", "wz", "ri", "ro", "rz", "yt", "ct" }; + + for (String type : parameterTypes) { + ArrayList temp = new ArrayList<>(); + for (int i = 1; i <= windowSize; i++) { + double value = switch (type) { + case "wi" -> hyperParameters.getWiInit(); + case "wo" -> hyperParameters.getWoInit(); + case "wz" -> hyperParameters.getWzInit(); + case "ri" -> hyperParameters.getRiInit(); + case "ro" -> hyperParameters.getRoInit(); + case "rz" -> hyperParameters.getRzInit(); + case "yt" -> hyperParameters.getYtInit(); + case "ct" -> hyperParameters.getCtInit(); + default -> throw new IllegalArgumentException("Invalid parameter type"); + }; + temp.add(value); + } + initialWeight.add(temp); + } + return initialWeight; + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/TrainAndValidateBatch.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/TrainAndValidateBatch.java new file mode 100644 index 00000000000..80f3fab9c84 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/train/TrainAndValidateBatch.java @@ -0,0 +1,91 @@ +package io.openems.edge.predictor.lstmmodel.train; + +import 
java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.concurrent.CompletableFuture; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.common.ReadAndSaveModels; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; +import io.openems.edge.predictor.lstmmodel.validator.ValidationSeasonalityModel; +import io.openems.edge.predictor.lstmmodel.validator.ValidationTrendModel; + +public class TrainAndValidateBatch { + + public TrainAndValidateBatch(// + ArrayList trainData, // + ArrayList trainDate, // + ArrayList validateData, // + ArrayList validateDate, // + HyperParameters hyperParameter) { + + /* + * var checkTrain = trainData.size() / hyperParameter.getBatchSize() + * + * if ( checkTrain <= hyperParameter.getWindowSizeSeasonality() || checkTrain <= + * hyperParameter.getWindowSizeTrend() ) { throw new Exception; } + */ + + var batchedData = DataModification.getDataInBatch(// + trainData, hyperParameter.getBatchSize()); + var batchedDate = DataModification.getDateInBatch(// + trainDate, hyperParameter.getBatchSize()); + + + + for (int epoch = hyperParameter.getEpochTrack(); epoch < hyperParameter.getEpoch(); epoch++) { + + int k = hyperParameter.getCount(); + + for (int batch = hyperParameter.getBatchTrack(); batch < hyperParameter.getBatchSize(); batch++) { + + hyperParameter.setCount(k); + System.out.println("=====> Batch = " + hyperParameter.getBatchTrack() // + + "/" + hyperParameter.getBatchSize()); + System.out.println("=====> Epoch= " + epoch // + + "/" + hyperParameter.getEpoch()); + + MakeModel makeModels = new MakeModel(); + + var trainDataTemp = batchedData.get(batch); + var trainDateTemp = batchedDate.get(batch); + + CompletableFuture firstTaskFuture = CompletableFuture + + // Train the Seasonality model + .supplyAsync(() -> makeModels.trainSeasonality(trainDataTemp, trainDateTemp, hyperParameter)) + + // Validate this Seasonality model + 
.thenAccept(untestedSeasonalityMoadels -> new ValidationSeasonalityModel().validateSeasonality( + validateData, validateDate, untestedSeasonalityMoadels, hyperParameter)); + + CompletableFuture secondTaskFuture = CompletableFuture + + // Train the trend model + .supplyAsync(() -> makeModels.trainTrend(trainDataTemp, trainDateTemp, hyperParameter)) + + // validate the trend model + .thenAccept(untestedSeasonalityMoadels -> new ValidationTrendModel().validateTrend(validateData, + validateDate, untestedSeasonalityMoadels, hyperParameter)); + + k = k + 1; + try { + CompletableFuture.allOf(firstTaskFuture, secondTaskFuture).get(); + } catch (Exception e) { + e.printStackTrace(); + } + + hyperParameter.setBatchTrack(batch + 1); + hyperParameter.setCount(k); + ReadAndSaveModels.save(hyperParameter); + } + hyperParameter.setBatchTrack(0); + hyperParameter.setEpochTrack(hyperParameter.getEpochTrack() + 1); + hyperParameter.update(); + ReadAndSaveModels.save(hyperParameter); + + } + hyperParameter.setEpochTrack(0); + ReadAndSaveModels.save(hyperParameter); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRate.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRate.java new file mode 100644 index 00000000000..26bbdd8a6e3 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRate.java @@ -0,0 +1,47 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class AdaptiveLearningRate { + /** + * Adjusts the learning rate based on the given percentage. * the total + * iterations. + * + * @param hyperParameters An instance of class HyperParameter + * @return The adapted learning rate calculated using a cosine annealing + * strategy. 
+ */ + + public double scheduler(HyperParameters hyperParameters) { + var maximum = hyperParameters.getLearningRateUpperLimit(); + var minimum = hyperParameters.getLearningRateLowerLimit(); + var tCurByTmax = (double) hyperParameters.getEpochTrack() / hyperParameters.getEpoch(); + var cosineValue = Math.cos(tCurByTmax * Math.PI); + return (minimum + 0.5 * (maximum - minimum) * (1 + cosineValue)); + } + + /** + * Performs the Adagrad optimization step to adjust the learning rate based on + * the gradient information. + * + * @param globalLearningRate The global learning rate for the optimization + * process. + * @param localLearningRate The local learning rate, which is dynamically + * adjusted during the optimization. + * @param gradient The gradient value computed during the + * optimization. + * @param iteration The iteration number, used to determine if this is + * the first iteration. + * @return The adapted learning rate based on the Adagrad optimization strategy. + */ + public double adagradOptimizer(double globalLearningRate, double localLearningRate, double gradient, + int iteration) { + if (iteration == 0 || localLearningRate == 0 || (globalLearningRate == 0 && gradient == 0)) { + return globalLearningRate; + } + + double adjustedRate = Math.pow(globalLearningRate / localLearningRate, 2) // + + Math.pow(gradient, 2); + return globalLearningRate / Math.sqrt(adjustedRate); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Cell.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Cell.java new file mode 100644 index 00000000000..540c3e09453 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Cell.java @@ -0,0 +1,315 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import java.util.Random; + +import io.openems.edge.predictor.lstmmodel.utilities.MathUtils; + +public class Cell { + + private double error; + private double 
wI; + private double wO; + private double wZ; + + private double rI; + private double rO; + private double rZ; + + private double yT; + private double ytMinusOne; + + private double cT; + private double ctMinusOne; + private double oT; + private double zT; + + private double iT; + private double dlByDy; + private double dlByDo; + private double dlByDc; + private double dlByDi; + private double dlByDz; + private double delI; + private double delO; + private double delZ; + + private double xT; + private double outputDataLoc; + + private double delF; + + public Cell(double xt, double outputData) { + this(xt, outputData, 1, 1, 1, 1, 1, 1, 0); + } + + public Cell(double xt, double outputData, double wI, double wO, double wZ, double rI, double rO, double rZ, + double yT) { + this.dlByDc = 0; + this.error = 0; + this.wI = wI; + this.wO = wO; + this.wZ = wZ; + this.rI = rI; + this.rO = rO; + this.rZ = rZ; + this.cT = 0; + this.oT = 0; + this.zT = 0; + this.yT = 0; + this.ytMinusOne = 0; + this.ctMinusOne = 0; + this.ytMinusOne = this.yT; + this.dlByDy = 0; + this.dlByDo = 0; + this.dlByDc = 0; + this.dlByDi = 0; + this.dlByDz = 0; + this.delI = 0; + this.delO = 0; + this.delZ = 0; + this.iT = 0; + this.xT = xt; + this.outputDataLoc = outputData; + } + + /** + * Forward propagation. 
+ */ + public void forwardPropogation() { + double dropOutProb; + boolean decissionFlag = this.decisionDropout(); + if (decissionFlag) { + dropOutProb = 0.0; + this.iT = MathUtils.sigmoid(this.wI * this.xT + this.rI * this.ytMinusOne); + this.oT = MathUtils.sigmoid(this.wO * this.xT + this.rO * this.ytMinusOne); + this.zT = MathUtils.tanh(this.wZ * this.xT + this.rZ * this.ytMinusOne); + this.cT = this.ctMinusOne + this.iT * this.zT * dropOutProb; + this.yT = this.ytMinusOne * (1 - dropOutProb) + this.oT * MathUtils.tanh(this.cT) * dropOutProb; + this.error = Math.abs(this.yT - this.outputDataLoc) / Math.sqrt(2); + } else { + this.iT = MathUtils.sigmoid(this.wI * this.xT + this.rI * this.ytMinusOne); + this.oT = MathUtils.sigmoid(this.wO * this.xT + this.rO * this.ytMinusOne); + this.zT = MathUtils.tanh(this.wZ * this.xT + this.rZ * this.ytMinusOne); + this.cT = this.ctMinusOne + this.iT * this.zT; + this.yT = this.oT * MathUtils.tanh(this.cT); + this.error = Math.abs(this.yT - this.outputDataLoc) / Math.sqrt(2); + } + } + + /** + * Backward propagation. + */ + public void backwardPropogation() { + this.dlByDy = this.error; + this.dlByDo = this.dlByDy * MathUtils.tanh(this.cT); + this.dlByDc = this.dlByDy * this.oT * MathUtils.tanhDerivative(this.cT) + this.dlByDc; + this.dlByDi = this.dlByDc * this.zT; + this.dlByDz = this.dlByDc * this.iT; + this.delI = this.dlByDi * MathUtils.sigmoidDerivative(this.wI * this.xT + this.rI * this.ytMinusOne); + this.delO = this.dlByDo * MathUtils.sigmoidDerivative(this.wO * this.xT + this.rO * this.ytMinusOne); + this.delZ = this.dlByDz * MathUtils.tanhDerivative(this.wZ * this.xT + this.rZ * this.ytMinusOne); + } + + /** + * Generates a random decision with dropout probability. This method generates a + * random boolean decision with a dropout probability of 10%. It uses a random + * number generator to determine whether the decision is true or false. 
The + * probability of returning true is 10%, and the probability of returning false + * is 90%. + * + * @return true with a 10% probability, false with a 90% probability. + */ + public boolean decisionDropout() { + Random random = new Random(); + int randomNumber = random.nextInt(10) + 1; + if (randomNumber > 7) { + return true; + } + return false; + } + + public double getError() { + return this.error; + } + + public void setError(double error) { + this.error = error; + } + + public double getWi() { + return this.wI; + } + + public void setWi(double wi) { + this.wI = wi; + } + + public double getWo() { + return this.wO; + } + + public void setWo(double wo) { + this.wO = wo; + } + + public double getWz() { + return this.wZ; + } + + public void setWz(double wz) { + this.wZ = wz; + } + + public double getRi() { + return this.rI; + } + + public void setRi(double ri) { + this.rI = ri; + } + + public double getRo() { + return this.rO; + } + + public void setRo(double ro) { + this.rO = ro; + } + + public double getRz() { + return this.rZ; + } + + public void setRz(double rz) { + this.rZ = rz; + } + + public double getCt() { + return this.cT; + } + + public void setCt(double ct) { + this.cT = ct; + } + + public double getCtMinusOne() { + return this.ctMinusOne; + } + + public void setCtMinusOne(double ct) { + this.ctMinusOne = ct; + } + + public double getYtMinusOne() { + return this.ytMinusOne; + } + + public void setYtMinusOne(double yt) { + this.ytMinusOne = yt; + } + + public double getYt() { + return this.yT; + } + + public void setYt(double yt) { + this.yT = yt; + } + + public void setIt(double iT) { + this.iT = iT; + } + + public double getIt() { + return this.iT; + } + + public double getOt() { + return this.oT; + } + + public double getZt() { + return this.zT; + } + + public void setDlByDy(double dlByDy) { + this.dlByDy = dlByDy; + } + + public double getDlByDy() { + return this.dlByDy; + } + + public void setDlByDo(double dlByDo) { + this.dlByDo = dlByDo; + } + 
+ public double getDlByDo() { + return this.dlByDo; + } + + public void setDlByDc(double dlByDc) { + this.dlByDc = dlByDc; + } + + public double getDlByDc() { + + return this.dlByDc; + } + + public void setDlByDi(double dlByDi) { + this.dlByDi = dlByDi; + } + + public double getDlByDi() { + return this.dlByDi; + } + + public void setDlByDz(double dlByDz) { + this.dlByDz = dlByDz; + } + + public double getDlByDz() { + return this.dlByDz; + } + + public void setDelI(double delI) { + this.delI = delI; + } + + public double getDelI() { + return this.delI; + } + + public void setDelO(double delO) { + this.delO = delO; + } + + public double getDelF() { + return this.delF; + } + + public void setDelF(double delF) { + this.delF = delF; + } + + public double getDelO() { + return this.delO; + } + + public void setDelZ(double delZ) { + this.delZ = delZ; + } + + public double getDelZ() { + return this.delZ; + } + + public void setXt(double xt) { + this.xT = xt; + } + + public double getXt() { + return this.xT; + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Engine.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Engine.java new file mode 100644 index 00000000000..fafc9bd4f17 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Engine.java @@ -0,0 +1,265 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import static io.openems.edge.predictor.lstmmodel.preprocessing.DataModification.scaleBack; +import static io.openems.edge.predictor.lstmmodel.utilities.MathUtils.sigmoid; +import static io.openems.edge.predictor.lstmmodel.utilities.MathUtils.tanh; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.getMinIndex; + +import java.util.ArrayList; + +import io.openems.edge.predictor.lstmmodel.common.DataStatistics; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import 
io.openems.edge.predictor.lstmmodel.util.Lstm.LstmBuilder;

/**
 * Drives the LSTM training: fits one {@link Lstm} per input window, chaining
 * the trained weights from window to window, and offers prediction/validation
 * over a weight matrix.
 */
public class Engine {

	private double[][] inputMatrix;
	private double[] targetVector;
	private double[][] validateData;
	private double[] validateTarget;
	private double learningRate;

	// One weight matrix (8 rows: wi, wo, wz, ri, ro, rz, yt, ct) per window
	private ArrayList<ArrayList<ArrayList<Double>>> weights = new ArrayList<>();
	// NOTE(review): finalWeights is never populated in this class; predict()
	// fails on an empty list — confirm it is filled by a caller
	private ArrayList<ArrayList<Double>> finalWeights = new ArrayList<>();

	/**
	 * Trains the LSTM network and updates the weight matrix.
	 *
	 * @param epochs          number of forward/backward propagation passes
	 * @param val             the initial weights
	 * @param hyperParameters an instance of {@link HyperParameters}
	 */
	public void fit(int epochs, ArrayList<ArrayList<Double>> val, HyperParameters hyperParameters) {

		var rate = new AdaptiveLearningRate();
		this.learningRate = rate.scheduler(hyperParameters);

		// First window: seed every cell from the supplied weight matrix
		var ls = new LstmBuilder(this.inputMatrix[0], this.targetVector[0])//
				.setLearningRate(this.learningRate) //
				.setEpoch(epochs)//
				.build();
		ls.initilizeCells();

		ls.setWi(val);
		ls.setWo(val);
		ls.setWz(val);
		ls.setRi(val);
		ls.setRo(val);
		ls.setRz(val);
		ls.setCt(val);
		ls.setYt(val);

		var weightMatrix = ls.train();
		this.weights.add(weightMatrix);

		// Subsequent windows: continue from the previously trained weights
		for (int i = 1; i < this.inputMatrix.length; i++) {

			this.learningRate = rate.scheduler(hyperParameters);
			ls = new LstmBuilder(this.inputMatrix[i], this.targetVector[i])//
					.setLearningRate(this.learningRate) //
					.setEpoch(epochs) //
					.build();
			ls.initilizeCells();

			for (int j = 0; j < ls.cells.size(); j++) {
				ls.cells.get(j).setWi(weightMatrix.get(0).get(j));
				ls.cells.get(j).setWo(weightMatrix.get(1).get(j));
				ls.cells.get(j).setWz(weightMatrix.get(2).get(j));
				ls.cells.get(j).setRi(weightMatrix.get(3).get(j));
				ls.cells.get(j).setRo(weightMatrix.get(4).get(j));
				ls.cells.get(j).setRz(weightMatrix.get(5).get(j));
				ls.cells.get(j).setYtMinusOne(weightMatrix.get(6).get(j));
				ls.cells.get(j).setCtMinusOne(weightMatrix.get(7).get(j));
			}

			weightMatrix = ls.train();
			this.weights.add(weightMatrix);
		}
	}

	/**
	 * Predicts one value per input row using the final weights.
	 *
	 * @param inputData      input data for the prediction
	 * @param hyperParameter an instance of {@link HyperParameters}
	 * @return the predicted values
	 */
	public double[] predict(double[][] inputData, HyperParameters hyperParameter) {

		var result = new double[inputData.length];
		for (int i = 0; i < inputData.length; i++) {
			result[i] = this.singleValuePredict(inputData[i], //
					this.finalWeights.get(0), //
					this.finalWeights.get(1), //
					this.finalWeights.get(2), //
					this.finalWeights.get(3), //
					this.finalWeights.get(4), //
					this.finalWeights.get(5), //
					this.finalWeights.get(6), //
					this.finalWeights.get(7), //
					hyperParameter);
		}
		return result;
	}

	/**
	 * Predicts with an explicit weight matrix; used to find the best model.
	 *
	 * @param inputData      input rows
	 * @param target         expected values (unused here; kept for the caller)
	 * @param val            the weight matrix to evaluate
	 * @param hyperParameter an instance of {@link HyperParameters}
	 * @return the predicted values
	 */
	public double[] validate(double[][] inputData, double[] target, ArrayList<ArrayList<Double>> val,
			HyperParameters hyperParameter) {

		var result = new double[inputData.length];
		for (int i = 0; i < inputData.length; i++) {
			result[i] = this.singleValuePredict(inputData[i], //
					val.get(0), //
					val.get(1), //
					val.get(2), //
					val.get(3), //
					val.get(4), //
					val.get(5), //
					val.get(6), //
					val.get(7), //
					hyperParameter);
		}

		return result;
	}

	/**
	 * Runs one LSTM pass over {@code inputData} and returns a single
	 * back-scaled prediction.
	 *
	 * @param inputData      one input window
	 * @param wi             weight wi
	 * @param wo             weight wo
	 * @param wz             weight wz
	 * @param Ri             recurrent weight Ri
	 * @param Ro             recurrent weight Ro
	 * @param Rz             recurrent weight Rz
	 * @param ytV            vector containing the cell outputs
	 * @param ctV            vector containing the cell states
	 * @param hyperParameter an instance of {@link HyperParameters}
	 * @return the predicted single value
	 */
	private double singleValuePredict(double[] inputData, //
			ArrayList<Double> wi, //
			ArrayList<Double> wo, //
			ArrayList<Double> wz, //
			ArrayList<Double> Ri, //
			ArrayList<Double> Ro, //
			ArrayList<Double> Rz, //
			ArrayList<Double> ytV, //
			ArrayList<Double> ctV, //
			HyperParameters hyperParameter) {

		var ct = 0.;
		var ctMinusOne = 0.;
		var yt = 0.;
		var standData = inputData;

		for (int i = 0; i < wi.size(); i++) {
			ctMinusOne = ctV.get(i);
			double it = sigmoid(wi.get(i) * standData[i] + Ri.get(i) * yt);
			double ot = sigmoid(wo.get(i) * standData[i] + Ro.get(i) * yt);
			double zt = tanh(wz.get(i) * standData[i] + Rz.get(i) * yt);
			ct = ctMinusOne + it * zt;
			yt = ot * tanh(ct);
		}
		// Undo the scaling that was applied during preprocessing
		return scaleBack(yt, hyperParameter.getScalingMin(), hyperParameter.getScalingMax());
	}

	/**
	 * Selects the best weight matrix by RMS error on the validation data.
	 *
	 * @param weightMatrix   all candidate weight matrices
	 * @param hyperParameter an instance of {@link HyperParameters}
	 * @return the index of the matrix with the lowest RMS error
	 */
	public int selectWeight(ArrayList<ArrayList<ArrayList<Double>>> weightMatrix, HyperParameters hyperParameter) {

		var rms = new double[weightMatrix.size()];
		for (int k = 0; k < weightMatrix.size(); k++) {
			var val = weightMatrix.get(k);
			var pre = this.validate(this.validateData, this.validateTarget, val, hyperParameter);
			rms[k] = DataStatistics.computeRms(this.validateTarget, pre);
		}
		return getMinIndex(rms);
	}

	public Engine(EngineBuilder builder) {
		this.inputMatrix = builder.inputMatrix;
		this.targetVector = builder.targetVector;
		this.validateData = builder.validateData;
		this.validateTarget = builder.validateTarget;
	}

	public static class EngineBuilder {

		private double[][] inputMatrix;
		private double[] targetVector;
		private double[][] validateData;
		private double[] validateTarget;

		// NOTE(review): 'validatorCounter' is accepted but unused — kept for
		// caller compatibility
		public EngineBuilder(double[][] inputMatrix, double[] targetVector, double[][] validateData,
				double[] validateTarget, int validatorCounter) {
			this.inputMatrix = inputMatrix;
			this.targetVector = targetVector;
			this.validateData = validateData;
			this.validateTarget = validateTarget;
		}

		public EngineBuilder() {
		}

		public EngineBuilder setInputMatrix(double[][] inputMatrix) {
			this.inputMatrix = inputMatrix;
			return this;
		}

		public EngineBuilder setTargetVector(double[] targetVector) {
			this.targetVector = targetVector;
			return this;
		}

		public EngineBuilder setValidateData(double[][] validateData) {
			this.validateData = validateData;
			return this;
		}

		public EngineBuilder setValidateTarget(double[] validateTarget) {
			this.validateTarget = validateTarget;
			return this;
		}

		public Engine build() {
			return new Engine(this);
		}
	}

	public ArrayList<ArrayList<ArrayList<Double>>> getWeights() {
		return this.weights;
	}

}
diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Lstm.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Lstm.java
new file
mode 100644 index 00000000000..8108e6d81ae --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/util/Lstm.java @@ -0,0 +1,354 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.stream.IntStream; + +public class Lstm { + + private double[] inputData; + private double outputData; + private double derivativeLWrtRi = 0; + private double derivativeLWrtRo = 0; + private double derivativeLWrtRz = 0; + private double derivativeLWrtWi = 0; + private double derivativeLWrtWo = 0; + private double derivativeLWrtWz = 0; + private double learningRate; + private int epoch = 100; + + protected ArrayList cells; + + public Lstm(LstmBuilder builder) { + this.inputData = builder.inputData; + this.outputData = builder.outputData; + this.learningRate = builder.learningRate; + this.epoch = builder.epoch; + } + + /** + * Forward propagation. + */ + public void forwardprop() { + try { + for (int i = 0; i < this.cells.size(); i++) { + this.cells.get(i).forwardPropogation(); + if (i < this.cells.size() - 1) { + this.cells.get(i + 1).setYtMinusOne(this.cells.get(i).getYt()); + this.cells.get(i + 1).setCtMinusOne(this.cells.get(i).getCt()); + this.cells.get(i).setError( + (Math.abs(this.cells.get(i).getYt() - this.cells.get(i + 1).getXt()) / Math.sqrt(2))); + } + } + } catch (IndexOutOfBoundsException e) { + e.printStackTrace(); + } + } + + /** + * Backward propagation. 
+ */ + public void backwardprop() { + + ArrayList gradients= new ArrayList(); + + for (int i = this.cells.size() - 1; i >= 0; i--) { + if (i < this.cells.size() - 1) { + this.cells.get(i).setDlByDc(this.cells.get(i + 1).getDlByDc()); + } + this.cells.get(i).backwardPropogation(); + } + + for (int i = 0; i < this.cells.size(); i++) { + this.derivativeLWrtRi += this.cells.get(i).getYtMinusOne() * this.cells.get(i).getDelI(); + this.derivativeLWrtRo += this.cells.get(i).getYtMinusOne() * this.cells.get(i).getDelO(); + this.derivativeLWrtRz += this.cells.get(i).getYtMinusOne() * this.cells.get(i).getDelZ(); + + this.derivativeLWrtWi += this.cells.get(i).getXt() * this.cells.get(i).getDelI(); + this.derivativeLWrtWo += this.cells.get(i).getXt() * this.cells.get(i).getDelO(); + this.derivativeLWrtWz += this.cells.get(i).getXt() * this.cells.get(i).getDelZ(); + + + +// localLearningRate1 = rate.adagradOptimizer(this.learningRate, localLearningRate1, this.derivativeLWrtWi, i); +// localLearningRate2 = rate.adagradOptimizer(this.learningRate, localLearningRate2, this.derivativeLWrtWo, i); +// localLearningRate3 = rate.adagradOptimizer(this.learningRate, localLearningRate3, this.derivativeLWrtWz, i); +// localLearningRate4 = rate.adagradOptimizer(this.learningRate, localLearningRate4, this.derivativeLWrtRi, i); +// localLearningRate5 = rate.adagradOptimizer(this.learningRate, localLearningRate5, this.derivativeLWrtRo, i); +// localLearningRate6 = rate.adagradOptimizer(this.learningRate, localLearningRate6, this.derivativeLWrtRz, i); +// +// this.cells.get(i).setWi(this.cells.get(i).getWi() - localLearningRate1 * this.derivativeLWrtWi); +// this.cells.get(i).setWo(this.cells.get(i).getWo() - localLearningRate2 * this.derivativeLWrtWo); +// this.cells.get(i).setWz(this.cells.get(i).getWz() - localLearningRate3 * this.derivativeLWrtWz); +// this.cells.get(i).setRi(this.cells.get(i).getRi() - localLearningRate4 * this.derivativeLWrtRi); +// 
this.cells.get(i).setRo(this.cells.get(i).getRo() - localLearningRate5 * this.derivativeLWrtRo); +// this.cells.get(i).setRz(this.cells.get(i).getRz() - localLearningRate6 * this.derivativeLWrtRz); + } + for(int i = 0;i gradients) { + var rate = new AdaptiveLearningRate(); + + var localLearningRate1 = 0.; + var localLearningRate2 = 0.; + var localLearningRate3 = 0.; + var localLearningRate4 = 0.; + var localLearningRate5 = 0.; + var localLearningRate6 = 0.; + for (int i = 0; i < this.cells.size(); i++) { + + + localLearningRate1 = rate.adagradOptimizer(this.learningRate, localLearningRate1, gradients.get(0), i); + localLearningRate2 = rate.adagradOptimizer(this.learningRate, localLearningRate2, gradients.get(1), i); + localLearningRate3 = rate.adagradOptimizer(this.learningRate, localLearningRate3, gradients.get(2), i); + localLearningRate4 = rate.adagradOptimizer(this.learningRate, localLearningRate4, gradients.get(3), i); + localLearningRate5 = rate.adagradOptimizer(this.learningRate, localLearningRate5, gradients.get(4), i); + localLearningRate6 = rate.adagradOptimizer(this.learningRate, localLearningRate6, gradients.get(5), i); + + this.cells.get(i).setWi(this.cells.get(i).getWi() - localLearningRate1 * gradients.get(0)); + this.cells.get(i).setWo(this.cells.get(i).getWo() - localLearningRate2 * gradients.get(1)); + this.cells.get(i).setWz(this.cells.get(i).getWz() - localLearningRate3 * gradients.get(2)); + this.cells.get(i).setRi(this.cells.get(i).getRi() - localLearningRate4 * gradients.get(3)); + this.cells.get(i).setRo(this.cells.get(i).getRo() - localLearningRate5 * gradients.get(4)); + this.cells.get(i).setRz(this.cells.get(i).getRz() - localLearningRate6 * gradients.get(5)); + } + + } + + /** + * Train to get the weight matrix. 
+ * + * @return weight matrix trained weight matrix + */ + public ArrayList> train() { + + MatrixWeight mW = new MatrixWeight(); + for (int i = 0; i < this.epoch; i++) { + + this.forwardprop(); + this.backwardprop(); + + var wiList = new ArrayList(); + var woList = new ArrayList(); + var wzList = new ArrayList(); + var riList = new ArrayList(); + var roList = new ArrayList(); + var rzList = new ArrayList(); + var ytList = new ArrayList(); + var ctList = new ArrayList(); + + for (int j = 0; j < this.cells.size(); j++) { + wiList.add(this.cells.get(j).getWi()); // + woList.add(this.cells.get(j).getWo()); // + wzList.add(this.cells.get(j).getWz()); // + riList.add(this.cells.get(j).getRi()); // + roList.add(this.cells.get(j).getRo()); // + rzList.add(this.cells.get(j).getRz()); // + ytList.add(this.cells.get(j).getYt()); // + ctList.add(this.cells.get(j).getCt()); // + } + + mW.getErrorList().add(this.cells.get(this.cells.size() - 1).getError()); + mW.getWi().add(wiList); + mW.getWo().add(woList); + mW.getWz().add(wzList); + mW.getRi().add(riList); + mW.getRo().add(roList); + mW.getRz().add(rzList); + mW.getOut().add(ytList); + mW.getCt().add(ctList); + } + + int globalMinimaIndex = findGlobalMinima(mW.getErrorList()); + + var returnArray = new ArrayList>(); + + returnArray.add(mW.getWi().get(globalMinimaIndex)); + returnArray.add(mW.getWo().get(globalMinimaIndex)); + returnArray.add(mW.getWz().get(globalMinimaIndex)); + returnArray.add(mW.getRi().get(globalMinimaIndex)); + returnArray.add(mW.getRo().get(globalMinimaIndex)); + returnArray.add(mW.getRz().get(globalMinimaIndex)); + returnArray.add(mW.getOut().get(globalMinimaIndex)); + returnArray.add(mW.getCt().get(globalMinimaIndex)); + + return returnArray; + } + + /** + * Get the index of the Global minima. element arr.get(index x) is a local + * minimum if it is less than both its neighbors and an arr can have multiple + * local minima. 
+ * + * @param data {@link java.util.ArrayList} of double + * @return index index of the global minima in the data + */ + public static int findGlobalMinima(ArrayList data) { + return IntStream.range(0, data.size())// + .boxed()// + .min(Comparator.comparingDouble(i -> Math.abs(data.get(i))))// + .orElse(-1); + } + + public double[] getInputData() { + return this.inputData; + } + + public double getOutputData() { + return this.outputData; + } + + public double getDerivativeLWrtRi() { + return this.derivativeLWrtRi; + } + + public double getDerivativeLWrtRo() { + return this.derivativeLWrtRo; + } + + public double getDerivativeLWrtRz() { + return this.derivativeLWrtRz; + } + + public double getDerivativeLWrtWi() { + return this.derivativeLWrtWi; + } + + public double getDerivativeLWrtWo() { + return this.derivativeLWrtWo; + } + + public double getDerivativeLWrtWz() { + return this.derivativeLWrtWz; + } + + public double getLearningRate() { + return this.learningRate; + } + + public ArrayList getCells() { + return this.cells; + } + + public synchronized void setWi(ArrayList> val) { + for (int i = 0; i < val.get(0).size(); i++) { + try { + this.cells.get(i).setWi(val.get(0).get(i)); + } catch (ArithmeticException | IndexOutOfBoundsException e) { + e.printStackTrace(); + } + } + } + + public synchronized void setWo(ArrayList> val) { + for (int i = 0; i < val.get(1).size(); i++) { + this.cells.get(i).setWo(val.get(1).get(i)); + } + } + + public synchronized void setWz(ArrayList> val) { + for (int i = 0; i < val.get(2).size(); i++) { + this.cells.get(i).setWz(val.get(2).get(i)); + } + } + + public synchronized void setRi(ArrayList> val) { + for (int i = 0; i < val.get(3).size(); i++) { + this.cells.get(i).setRi(val.get(3).get(i)); + } + } + + public synchronized void setRo(ArrayList> val) { + for (int i = 0; i < val.get(4).size(); i++) { + this.cells.get(i).setRo(val.get(4).get(i)); + } + } + + public synchronized void setRz(ArrayList> val) { + for (int i = 0; i < 
val.get(5).size(); i++) { + this.cells.get(i).setRz(val.get(5).get(i)); + } + } + + public synchronized void setYt(ArrayList> val) { + for (int i = 0; i < val.get(6).size(); i++) { + this.cells.get(i).setYt(val.get(6).get(i)); + } + } + + public synchronized void setCt(ArrayList> val) { + for (int i = 0; i < val.get(7).size(); i++) { + this.cells.get(i).setCt(val.get(7).get(i)); + } + } + + /** + * Please build the model with input and target. + * + */ + public static class LstmBuilder { + + protected double[] inputData; + protected double outputData; + + protected double learningRate; // + protected int epoch = 100; // + + public LstmBuilder(double[] inputData, double outputData) { + this.inputData = inputData; + this.outputData = outputData; + } + + public LstmBuilder setInputData(double[] inputData) { + this.inputData = inputData; + return this; + } + + public LstmBuilder setOutputData(double outputData) { + this.outputData = outputData; + return this; + } + + public LstmBuilder setLearningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } + + public LstmBuilder setEpoch(int epoch) { + this.epoch = epoch; + return this; + } + + public Lstm build() { + return new Lstm(this); + } + + } + + /** + * Initializes the cell with the default data. 
/**
 * Accumulator for per-epoch LSTM weight snapshots.
 *
 * <p>
 * Each outer list is indexed by epoch; each inner list holds one Double per
 * cell. The getters expose the mutable lists directly so callers (e.g.
 * {@code Lstm.train()}) can append to them.
 */
public class MatrixWeight {

	// Gate weights, one inner list per epoch
	private final ArrayList<ArrayList<Double>> wI = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> wO = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> wZ = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> wF = new ArrayList<>();
	// Recurrent weights, one inner list per epoch
	private final ArrayList<ArrayList<Double>> rI = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> rO = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> rZ = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> rF = new ArrayList<>();

	// Cell outputs and cell states, one inner list per epoch
	private final ArrayList<ArrayList<Double>> out = new ArrayList<>();
	private final ArrayList<ArrayList<Double>> cT = new ArrayList<>();
	// One error value per epoch
	private final ArrayList<Double> errorList = new ArrayList<>();

	public ArrayList<ArrayList<Double>> getWi() {
		return this.wI;
	}

	public ArrayList<ArrayList<Double>> getWo() {
		return this.wO;
	}

	public ArrayList<ArrayList<Double>> getWz() {
		return this.wZ;
	}

	public ArrayList<ArrayList<Double>> getRi() {
		return this.rI;
	}

	public ArrayList<ArrayList<Double>> getRo() {
		return this.rO;
	}

	public ArrayList<ArrayList<Double>> getRz() {
		return this.rZ;
	}

	public ArrayList<ArrayList<Double>> getOut() {
		return this.out;
	}

	public ArrayList<ArrayList<Double>> getCt() {
		return this.cT;
	}

	public ArrayList<Double> getErrorList() {
		return this.errorList;
	}

	public ArrayList<ArrayList<Double>> getWf() {
		return this.wF;
	}

	public ArrayList<ArrayList<Double>> getRf() {
		return this.rF;
	}
}
b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/DataUtility.java new file mode 100644 index 00000000000..a29c4952888 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/DataUtility.java @@ -0,0 +1,127 @@ +package io.openems.edge.predictor.lstmmodel.utilities; + +import java.time.OffsetDateTime; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Objects; +import java.util.SortedMap; +import java.util.stream.Collectors; + +import com.google.gson.JsonElement; + +import io.openems.common.types.ChannelAddress; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class DataUtility { + + /** + * Extracts data values. + * + * @param queryResult The SortedMap queryResult. + * @return An ArrayList of Double values extracted from non-null JsonElement + * values. + */ + public static ArrayList getData( + SortedMap> queryResult) { + + ArrayList data = new ArrayList<>(); + + queryResult.values().stream()// + .map(SortedMap::values)// + .flatMap(Collection::stream)// + .map(v -> { + if (v.isJsonNull()) { + return null; + } + return v.getAsDouble(); + }).forEach(value -> data.add(value)); + + // TODO remove this later + if (isAllNulls(data)) { + System.out.println("Data is all null, use a different predictor"); + } + + return data; + } + + /** + * Checks if all elements in an ArrayList are null. + * + * @param array The ArrayList to be checked. + * @return true if all elements in the ArrayList are null, false otherwise. + */ + private static boolean isAllNulls(ArrayList array) { + return array.stream().allMatch(Objects::isNull); + } + + /** + * Combines trend and seasonality predictions into a single list of values. + * + * @param trendPrediction The list of predicted trend values. + * @param seasonalityPrediction The list of predicted seasonality values. 
+ * @return A combined list containing both trend and seasonality predictions. + * + */ + public static ArrayList combine(ArrayList trendPrediction, + ArrayList seasonalityPrediction) { + + for (int l = 0; l < trendPrediction.size(); l++) { + seasonalityPrediction.set(l, trendPrediction.get(l)); + } + return seasonalityPrediction; + } + + /** + * Concatenates two lists of {@code Double} values into a single list. + * + * + * @param list1 the first list of {@code Double} values, may be {@code null} + * @param list2 the second list of {@code Double} values, may be {@code null} + * @return a new {@link ArrayList} containing all elements from both input + * lists, or an empty list if both inputs are {@code null} + */ + public static ArrayList concatenateList(ArrayList list1, ArrayList list2) { + ArrayList result = new ArrayList<>(); + + if (list1 != null) { + result.addAll(list1); + } + + if (list2 != null) { + result.addAll(list2); + } + + return result; + } + + /** + * Extracts OffsetDateTime objects from the keys of a SortedMap containing + * ZonedDateTime keys. + * + * @param queryResult The SortedMap containing ZonedDateTime keys and associated + * data. + * @return An ArrayList of OffsetDateTime objects extracted from the + * ZonedDateTime keys. + */ + public static ArrayList getDate( + SortedMap> queryResult) { + return queryResult.keySet().stream()// + .map(ZonedDateTime::toOffsetDateTime)// + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Get minute. 
+ * + * @param nowDate the now date + * @param hyperParameters the hyperparameter + * @return int minute + */ + public static Integer getMinute(ZonedDateTime nowDate, HyperParameters hyperParameters) { + int interval = hyperParameters.getInterval(); + int minute = nowDate.getMinute(); + return (int) (minute / interval) * interval; + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/MathUtils.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/MathUtils.java new file mode 100644 index 00000000000..1a8b95224ff --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/MathUtils.java @@ -0,0 +1,45 @@ +package io.openems.edge.predictor.lstmmodel.utilities; + +public class MathUtils { + + /** + * Returns the hyperbolic tangent of a double value. + * + * @param val double value + * @return The hyperbolic tangent of double value + */ + public static double tanh(double val) { + return Math.tanh(val); + } + + /** + * Returns the sigmoid of a double value. + * + * @param val double value + * @return The sigmoid of a double value + */ + public static double sigmoid(double val) { + return 1 / (1 + Math.pow(Math.E, -val)); + } + + /** + * Returns the sigmoid derivative of a double value. + * + * @param val double value + * @return The sigmoid derivative of a double value + */ + public static double sigmoidDerivative(double val) { + return sigmoid(val) * (1 - sigmoid(val)); + } + + /** + * Returns the tanh derivative of a double value. 
public class SlidingWindowSpliterator<T> implements Spliterator<Stream<T>> {

	/**
	 * Creates a stream of sliding windows over the given collection.
	 *
	 * @param <T>        the element type
	 * @param stream     the source {@link Collection}
	 * @param windowSize the size of each window
	 * @return a {@link Stream} of windows, each itself a {@link Stream} of
	 *         {@code windowSize} consecutive source elements
	 */
	public static <T> Stream<Stream<T>> windowed(Collection<T> stream, int windowSize) {
		return StreamSupport.stream(new SlidingWindowSpliterator<>(stream, windowSize), false);
	}

	private final Queue<T> buffer;
	private final Iterator<T> sourceIterator;
	private final int windowSize;
	private final int size;

	private SlidingWindowSpliterator(Collection<T> source, int windowSize) {
		this.buffer = new ArrayDeque<>(windowSize);
		this.sourceIterator = Objects.requireNonNull(source).iterator();
		this.windowSize = windowSize;
		this.size = calculateSize(source, windowSize);
	}

	@SuppressWarnings("unchecked")
	@Override
	public boolean tryAdvance(Consumer<? super Stream<T>> action) {
		if (this.windowSize < 1) {
			return false;
		}

		while (this.sourceIterator.hasNext()) {
			this.buffer.add(this.sourceIterator.next());

			if (this.buffer.size() == this.windowSize) {
				// Emit a snapshot of the current window, then slide by one
				action.accept(Arrays.stream((T[]) this.buffer.toArray(new Object[0])));
				this.buffer.poll();
				return true;
			}
		}

		return false;
	}

	@Override
	public Spliterator<Stream<T>> trySplit() {
		return null; // splitting is not supported
	}

	@Override
	public long estimateSize() {
		return this.size;
	}

	@Override
	public int characteristics() {
		return ORDERED | NONNULL | SIZED;
	}

	// Number of windows: 0 if the source is shorter than one window
	private static int calculateSize(Collection<?> source, int windowSize) {
		return source.size() < windowSize //
				? 0 //
				: source.size() - windowSize + 1;
	}

}
0 : source.size() - windowSize + 1; + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/UtilityConversion.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/UtilityConversion.java new file mode 100644 index 00000000000..cd72590f730 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/utilities/UtilityConversion.java @@ -0,0 +1,229 @@ +package io.openems.edge.predictor.lstmmodel.utilities; + +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + +public class UtilityConversion { + + /** + * Convert {@link java.util.ArrayList} to double[][]. + * + * @param data {@link java.util.ArrayList} double + * @return result converted double [][] + */ + public static double[][] to2DArray(ArrayList> data) { + return data.stream() // + .map(UtilityConversion::to1DArray) // + .toArray(double[][]::new); + } + + /** + * Convert {@link java.util.List} of double. + * + * @param data {@link java.util.List} of Double + * @return result converted double [][] + */ + public static double[][] to2DArray(List> data) { + return data.stream() // + .map(UtilityConversion::to1DArray) // + .toArray(double[][]::new); + } + + /** + * Convert {@link java.util.List} of double to double[]. + * + * @param data {@link java.util.List} of double + * @return result converted double [] + */ + public static double[] to1DArray(List data) { + return data.stream() // + .mapToDouble(d -> { + if (d == null || d.isNaN() || Double.isNaN(d)) { + return Double.NaN; + } + return d.doubleValue(); + }).toArray(); + } + + /** + * Convert {@link java.util.List} of {@link OffsetDateTime} to + * {@link OffsetDateTime}[]. 
+ * + * @param data {@link java.util.List} of {@link OffsetDateTime} + * @return result converted {@link OffsetDateTime} [] + */ + public static OffsetDateTime[] to1DArray(ArrayList data) { + return data.stream().toArray(OffsetDateTime[]::new); + } + + /** + * Converts an ArrayList of Double values to an array of Integer values. + * + * @param data The ArrayList of Double values to be converted. + * @return An array of Integer values where each element represents the + * converted value from the input ArrayList. + */ + public static Integer[] toInteger1DArray(ArrayList data) { + return data.stream() // + .mapToInt(d -> d.intValue())// + .boxed()// + .toArray(Integer[]::new); + } + + /** + * Converts a three-dimensional ArrayList of Double values into a + * three-dimensional array. + * + * @param data The three-dimensional ArrayList to be converted. + * @return A three-dimensional array containing the elements of the input + * ArrayList. + */ + public static double[][][] to3DArray(ArrayList>> data) { + double[][][] returnArray = new double[data.size()][][]; + for (int i = 0; i < data.size(); i++) { + returnArray[i] = to2DArray(data.get(i)); + } + return returnArray; + } + + /** + * Converts a three-dimensional array of Double values into a two-dimensional + * array. This method converts the input three-dimensional array into a + * three-dimensional ArrayList, then into a two-dimensional ArrayList, and + * finally into a two-dimensional array. + * + * @param data The three-dimensional array to be converted.s + * @return A two-dimensional array containing the elements of the input array. + */ + public static double[][] to2DList(double[][][] data) { + return to2DArray(to2DArrayList(to3DArrayList(data))); + } + + /** + * Converts a two-dimensional array of double values to a two-dimensional + * ArrayList of Double values. + * + * @param data The two-dimensional array of double values to be converted. 
+ * @return A two-dimensional ArrayList of Double values representing the + * converted data. + */ + public static ArrayList> to2DArrayList(double[][] data) { + ArrayList> toReturn = new ArrayList>(); + for (int i = 0; i < data.length; i++) { + ArrayList temp = new ArrayList(); + for (int j = 0; j < data[i].length; j++) { + temp.add(data[i][j]); + + } + toReturn.add(temp); + } + return toReturn; + } + + /** + * Converts a three-dimensional ArrayList of Double values into a + * two-dimensional ArrayList. + * + * @param data The three-dimensional ArrayList to be converted. + * @return A two-dimensional ArrayList containing the elements of the input + * ArrayList. + */ + public static ArrayList> to2DArrayList(ArrayList>> data) { + var resized = new ArrayList>(); + + for (int i = 0; i < data.size(); i++) { + for (int j = 0; j < data.get(i).size(); j++) { + resized.add(data.get(i).get(j)); + } + } + return resized; + } + + /** + * Convert double[] to {@link java.util.ArrayList} of Double. + * + * @param toBeConverted array of double + * @return result converted Array list + */ + public static ArrayList to1DArrayList(double[] toBeConverted) { + + return DoubleStream.of(toBeConverted) // + .boxed() // + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Convert double[] to {@link java.util.ArrayList} of OffsetDateTime. + * + * @param toBeConverted array of OffsetDateTime + * @return result converted Array list + */ + public static ArrayList to1DArrayList(OffsetDateTime[] toBeConverted) { + return Arrays.stream(toBeConverted).collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * convert 2DArrayList To 1DArray. + * + * @param data the data + * @return converted the converted + */ + public static ArrayList to1DArrayList(ArrayList> data) { + return (ArrayList) data.stream()// + .flatMap(Collection::stream)// + .collect(Collectors.toList()); + } + + /** + * Convert {@link java.util.List} of Integer to {@link java.util.List} of + * Double. 
+ * + * @param toBeConverted the {@link java.util.List} of Integer + * @return result {@link java.util.List} of Double. + */ + public static List toBoxed1DList(List toBeConverted) { + return toBeConverted.stream() // + .mapToDouble(i -> i == null ? null : i) // + .boxed() // + .collect(Collectors.toList()); + } + + /** + * Converts a three-dimensional array into a three-dimensional ArrayList of + * Double values. + * + * @param data The three-dimensional array to be converted. + * @return A three-dimensional ArrayList containing the elements of the input + * array. + */ + public static ArrayList>> to3DArrayList(double[][][] data) { + var returnArray = new ArrayList>>(); + for (int i = 0; i < data.length; i++) { + returnArray.add(to2DArrayList(data[i])); + } + return returnArray; + } + + /** + * Get the index of the Min element in an array. + * + * @param arr double array. + * @return iMin index of the min element in an array. + */ + public static int getMinIndex(double[] arr) { + if (arr == null || arr.length == 0) { + throw new IllegalArgumentException("Array must not be empty or null"); + } + + return IntStream.range(0, arr.length)// + .boxed()// + .min((i, j) -> Double.compare(arr[i], arr[j]))// + .orElseThrow(); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationSeasonalityModel.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationSeasonalityModel.java new file mode 100644 index 00000000000..98fd48f7ccf --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationSeasonalityModel.java @@ -0,0 +1,161 @@ +package io.openems.edge.predictor.lstmmodel.validator; + +import java.io.File; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import io.openems.edge.predictor.lstmmodel.LstmPredictor; +import 
io.openems.edge.predictor.lstmmodel.common.DataStatistics; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import static io.openems.edge.predictor.lstmmodel.performance.PerformanceMatrix.rmsError; +import static io.openems.edge.predictor.lstmmodel.performance.PerformanceMatrix.accuracy; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; +import io.openems.edge.predictor.lstmmodel.preprocessingpipeline.PreprocessingPipeImpl; +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class ValidationSeasonalityModel { + + public static final String SEASONALITY = "seasonality"; + + /** + * Validate the Seasonality. + * + * @param values the values + * @param dates the dates + * @param untestedSeasonalityWeight Models to validate. + * @param hyperParameters the hyperParameters + */ + + public void validateSeasonality(ArrayList values, ArrayList dates, + ArrayList>>> untestedSeasonalityWeight, + HyperParameters hyperParameters) { + + ArrayList> rmsTemp2 = new ArrayList>(); + + PreprocessingPipeImpl preProcessing = new PreprocessingPipeImpl(hyperParameters); + double[][][] dataGroupedByMinute = (double[][][]) preProcessing.setData(UtilityConversion.to1DArray(values))// + .setDates(dates)// + .interpolate()// + .movingAverage()// + .scale()// + .filterOutliers()// + .groupByHoursAndMinutes()// + .execute(); + + ArrayList>>> allModels = DataModification + .reshape((DataModification.flattern4dto3d(untestedSeasonalityWeight)), hyperParameters); + + for (int h = 0; h < allModels.size(); h++) { + ArrayList rmsTemp1 = new ArrayList(); + int k = 0; + for (int i = 0; i < dataGroupedByMinute.length; i++) { + for (int j = 0; j < dataGroupedByMinute[i].length; j++) { + + double[][][] intermediate = (double[][][]) preProcessing.setData(dataGroupedByMinute[i][j])// + // .differencing()// + .groupToWIndowSeasonality()// + .execute(); + + double[][][] preProcessed = (double[][][]) preProcessing.setData(intermediate)// 
+ .normalize()// + .shuffle()// + .execute(); + + ArrayList> val = allModels.get(h).get(k); + + double[] result = (double[]) preProcessing// + .setData(UtilityConversion + .to1DArray(LstmPredictor.predictPre(preProcessed[0], val, hyperParameters)))// + .setMean(DataStatistics.getMean(intermediate[0]))// + .setStandardDeviation(DataStatistics.getStandardDeviation(intermediate[0]))// + .reverseNormalize()// + .reverseScale()// + .execute(); + + double rms = rmsError(// + (double[]) preProcessing// + .setData(intermediate[1][0])// + .reverseScale()// + .execute(), + result) // + * // + (1 - accuracy((double[]) preProcessing.setData(intermediate[1][0])// + .reverseScale()// + .execute(), result, 0.01)); + + rmsTemp1.add(rms); + k = k + 1; + + } + + } + rmsTemp2.add(rmsTemp1); + } + List> optInd = findOptimumIndex(rmsTemp2, SEASONALITY, hyperParameters); + + DataModification.updateModel(allModels, optInd, + Integer.toString(hyperParameters.getCount()) + hyperParameters.getModelName() + SEASONALITY, + SEASONALITY, hyperParameters); + } + + /** + * Find the indices of the minimum values in each column of a 2D matrix. This + * method takes a 2D matrix represented as a List of Lists and finds the row + * indices of the minimum values in each column. The result is returned as a + * List of Lists, where each inner list contains two integers: the row index and + * column index of the minimum value. + * + * @param matrix A 2D matrix represented as a List of Lists of doubles. + * @param variable the variable + * @param hyperParameters the hyperParameters + * @return A List of Lists containing the row and column indices of the minimum + * values in each column. If the input matrix is empty, an empty list is + * returned. 
+ */ + + public static List> findOptimumIndex(ArrayList> matrix, String variable, + HyperParameters hyperParameters) { + List> minimumIndices = new ArrayList<>(); + + if (matrix.isEmpty() || matrix.get(0).isEmpty()) { + return minimumIndices; // Empty matrix, return empty list + } + + int numColumns = matrix.get(0).size(); + + for (int col = 0; col < numColumns; col++) { + double min = matrix.get(0).get(col); + List minIndices = new ArrayList<>(Arrays.asList(0, col)); + + for (int row = 0; row < matrix.size(); row++) { + double value = matrix.get(row).get(col); + + if (value < min) { + min = value; + minIndices.set(0, row); + } + } + + minimumIndices.add(minIndices); + } + + ArrayList err = new ArrayList(); + for (int i = 0; i < minimumIndices.size(); i++) { + + err.add(matrix.get(minimumIndices.get(i).get(0)).get(minimumIndices.get(i).get(1))); + } + hyperParameters.setAllModelErrorSeason(err); + double errVal = DataStatistics.getStandardDeviation(err, hyperParameters.getTargetError()); + hyperParameters.setRmsErrorSeasonality(errVal); + System.out.println("=====> Average RMS error for " + variable + " = " + errVal); + return minimumIndices; + } + + @FunctionalInterface + interface FilePathGenerator { + String generatePath(File file, String fileName); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationTrendModel.java b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationTrendModel.java new file mode 100644 index 00000000000..20087ac3966 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/validator/ValidationTrendModel.java @@ -0,0 +1,157 @@ +package io.openems.edge.predictor.lstmmodel.validator; + +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import io.openems.edge.predictor.lstmmodel.LstmPredictor; +import 
io.openems.edge.predictor.lstmmodel.common.DataStatistics; +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.performance.PerformanceMatrix; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; +import io.openems.edge.predictor.lstmmodel.preprocessingpipeline.PreprocessingPipeImpl; +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class ValidationTrendModel { + public static final String TREND = "trend"; + + /** + * validate Trend. + * + * @param values the value + * @param dates the date + * @param untestedTrendWeights Untested Models. + * @param hyperParameters the hyperParam + */ + + public void validateTrend(ArrayList values, ArrayList dates, + ArrayList>>> untestedTrendWeights, HyperParameters hyperParameters) { + + ArrayList>>> allModels = DataModification + .reshape((DataModification.flattern4dto3d(untestedTrendWeights)), hyperParameters); + + ArrayList> rmsErrors = this.validateModels(// + values, // + dates, // + allModels, // + hyperParameters); + + List> optInd = findOptimumIndex(rmsErrors, TREND, hyperParameters); + + this.updateModels(// + allModels, // + optInd, + Integer.toString(hyperParameters.getCount()) + hyperParameters.getModelName() + TREND, // + TREND, + hyperParameters); + } + + /** + * Find the indices of the maximum values in each column of a 2D matrix. This + * method takes a 2D matrix represented as a List of Lists and finds the row + * indices of the maximum values in each column. The result is returned as a + * List of Lists, where each inner list contains two integers: the row index and + * column index of the maximum value. + * + * @param matrix A 2D matrix represented as a List of Lists of doubles. + * @param var the var + * @param hyperParameters the hyperParam + * @return A List of Lists containing the row and column indices of the maximum + * values in each column. 
If the input matrix is empty, an empty list is + * returned. + */ + + public static List> findOptimumIndex(ArrayList> matrix, String var, + HyperParameters hyperParameters) { + List> minimumIndices = new ArrayList<>(); + + if (matrix.isEmpty() || matrix.get(0).isEmpty()) { + return minimumIndices; // Empty matrix, return empty list + } + + int numColumns = matrix.get(0).size(); + + for (int col = 0; col < numColumns; col++) { + double min = matrix.get(0).get(col); + List minIndices = new ArrayList<>(Arrays.asList(0, col)); + + for (int row = 0; row < matrix.size(); row++) { + double value = matrix.get(row).get(col); + + if (value < min) { + min = value; + minIndices.set(0, row); + } + } + + minimumIndices.add(minIndices); + } + + ArrayList err = new ArrayList(); + for (int i = 0; i < minimumIndices.size(); i++) { + err.add(matrix.get(minimumIndices.get(i).get(0)).get(minimumIndices.get(i).get(1))); + } + hyperParameters.setAllModelErrorTrend(err); + double errVal = DataStatistics.getStandardDeviation(err, hyperParameters.getTargetError()); + hyperParameters.setRmsErrorTrend(errVal); + System.out.println("=====> Average RMS error for " + var + " = " + errVal); + return minimumIndices; + } + + private ArrayList> validateModels(ArrayList value, ArrayList dates, + ArrayList>>> allModels, HyperParameters hyperParameters) { + + PreprocessingPipeImpl validateTrendPreProcess = new PreprocessingPipeImpl(hyperParameters); + double[][] modifiedData = (double[][]) validateTrendPreProcess// + .setData(UtilityConversion.to1DArray(value))// + .setDates(dates)// + .interpolate()// + .movingAverage()// + .scale()// + .filterOutliers()// + .modifyForTrendPrediction()// + .execute(); + + ArrayList> rmsTemp2 = new ArrayList<>(); + + for (ArrayList>> modelsForData : allModels) { + ArrayList rmsTemp1 = new ArrayList<>(); + + for (int j = 0; j < modifiedData.length; j++) { + double[][][] intermediate = (double[][][]) validateTrendPreProcess.setData(modifiedData[j])// + // 
.differencing()// + .groupToStiffedWindow()// + .execute(); + + double[][][] preprocessed = (double[][][]) validateTrendPreProcess.setData(intermediate)// + .normalize()// + .shuffle()// + .execute(); + + double[] result = (double[]) validateTrendPreProcess// + .setData(UtilityConversion.to1DArray( + LstmPredictor.predictPre(preprocessed[0], modelsForData.get(j), hyperParameters)))// + .setMean(DataStatistics.getMean(intermediate[0]))// + .setStandardDeviation(DataStatistics.getStandardDeviation(intermediate[0]))// + .reverseNormalize()// + .reverseScale()// + .execute(); + + double rms = PerformanceMatrix.rmsError( + (double[]) validateTrendPreProcess.setData(intermediate[1][0]).reverseScale().execute(), + result) * (1 - PerformanceMatrix.accuracy((double[]) validateTrendPreProcess.setData(intermediate[1][0]).reverseScale().execute(), + result, 0.01)); + rmsTemp1.add(rms); + } + rmsTemp2.add(rmsTemp1); + } + return rmsTemp2; + } + + private void updateModels(ArrayList>>> allModels, List> optInd, + String modelFileName, String modelType, HyperParameters hyperParameters) { + DataModification.updateModel(allModels, optInd, modelFileName, modelType, hyperParameters); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/.gitignore b/io.openems.edge.predictor.lstmmodel/test/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/Data.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/Data.java new file mode 100644 index 00000000000..f0ef3173560 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/Data.java @@ -0,0 +1,249 @@ +package io.openems.edge.predictor.lstmmodel; + +public class Data { + + public static final Integer[] data = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19628148, 6550910, 22317828, 15206562, 9058611, + 23406944, 9219693, 15800971, 21296530, 4527141, 22562770, 15267745, 9018811, 23595324, 8430786, 16389406, + 21525472, 4077697, 22768848, 15157240, 9138508, 23968964, 7363085, 17083467, 21832143, 3594137, 22989236, + 15085879, 9228392, 24124509, 6571137, 17642223, 21475288, 4499115, 22445357, 14785938, 9713856, 23915723, + 6550867, 17656891, 20921635, 5474489, 22062513, 14542193, 10049813, 23690578, 6654455, 17689559, 20435467, + 6028892, 21951157, 14613696, 10015734, 23385682, 7012689, 17595474, 20197424, 6435626, 21802029, 14528000, + 10124220, 23141960, 2640169, 8624727, 15593695, 21078270, 6648533, 20702052, 16174167, 8652929, 23103674, + 9415632, 15315525, 21073269, 6429999, 20912249, 16072814, // + + 8688754, 23011505, 9745347, 15174361, 21202293, 6012314, 21216329, 15949181, 8593743, 9973991, 23510081, + 4963040, 19950775, 19491470, 5166975, 23717579, 12303753, 11982384, 23957607, 3778278, 20679632, 19257248, + 5348607, 23838658, 12263030, 12045896, 23524143, 4315132, 20480132, 18894586, 6167012, 23353918, 11973767, + 12412515, 23142252, 5014560, 20084238, 18574591, 6843968, 23042868, 11588716, 12864209, 22745707, 5830336, + 19668607, 18191762, 7343509, 22899931, 11226024, 13150314, 22541918, 6444932, 19260087, 18102400, 7389853, + 22935268, 10986745, 13260801, 22431435, 7086697, 18887108, 18044029, 7377111, 23035500, 11036566, 13216566, + 22184679, 7455357, 18808808, 18012193, 7241047, 23093461, 11080178, 13203141, 22075695, 7330308, 19025666, + 18174654, 6971185, 23062072, 11262408, 13158150, 22087116, 7025189, 19303211, 18293406, 6804257, 23044594, + 11373599, 13171686, 22254260, 6480438, 19696518, 18338323, 6581887, 
23151076, // + + 11604769, 12952137, 22589192, 5622771, 20203554, 18302906, 6441386, 23357620, 11445514, 13091251, 22956998, + 4673724, 20807399, 18288487, 6105346, 23848082, 11367278, 13054746, 23273602, 3760297, 21381618, 18225835, + 5981923, 24207759, 11145130, 13123652, 23265447, 3621941, 21545214, 17901138, 6758239, 23756140, 10548255, + 13757071, 22886070, 4390518, 21041361, 17532483, 7454236, 23428798, 10129041, 14229790, 22436513, 5324878, + 20568731, 17178127, 7998347, 23282483, 9856018, 14433316, 22268708, 5949598, 20130590, 17129247, 8022482, + 23263787, 9602355, 14605403, 21927623, 6715756, 19772138, 16903530, 8167990, 23303579, 9620705, 14640926, + 21673839, 6883218, 19889343, 17009546, 7990983, 23226453, 9858403, 14569637, 21493377, 6926327, 20014532, + 17145773, 7809541, 23072085, 10109997, 14492824, 21526944, 6685882, 20217357, 17121589, 7727077, 23114353, + 10420701, 14239556, 21705368, 6126628, 20583521, 17070504, 7684599, 23136513, // + + 10611913, 14154820, 22024580, 5428125, 20992158, 17062487, 7446859, 23470283, 10582385, 14141467, 22284292, + 4557025, 21599418, 16885615, 7338373, 23841840, 10565609, 14034098, 22524607, 3819144, 22069946, 17094514, + 7115064, 24051681, 9727720, 14636115, 22991141, 3044869, 22416925, 16925266, 7464993, 24025258, 8997311, + 15272431, 22354895, 4227531, 21784802, 16444165, 8257695, 23734762, 8626851, 15632950, 22143188, 4954302, + 21267194, 16244320, 8589471, 23587662, 8393340, 15865528, 21702258, 5925147, 20787704, 16105981, 8753773, + 23559390, 8347191, 15864910, 21366536, 6330591, 20781495, 16012461, 8822725, 23382794, 8488262, 15917033, + 20935476, 6680033, 20800007, 16061273, 8776216, 23166338, 8785728, 15844450, 20941593, 6636992, 20860341, + 15951301, 8846804, 23106750, 9177339, 15578916, 20943560, 6438064, 21033893, 15783067, 8924339, 22983467, + 9605892, 15331428, 21078934, 6019274, 21317309, 15728719, 8783992, 23180832, // + + 9713243, 15249007, 21279408, 5367330, 21768778, 15580151, 8696040, 23417895, 
9887160, 15097248, 21460661, + 4764798, 22218240, 15720414, 8487344, 23605753, 9446081, 15407128, 21751813, 4180490, 22494522, 15815748, + 8439888, 23815168, 8522399, 16095019, 21935882, 3721703, 22741700, 15606351, 8664988, 24207759, 7370746, + 16889912, 21944239, 3906499, 22590875, 15419180, 9032978, 23986993, 6997075, 17218623, 21527687, 4848544, + 22031346, 15128377, 9466351, 23820788, 7001541, 17206217, 21004085, 5708207, 21741097, 14890121, 9803186, + 23603337, 7024611, 17341331, 20510146, 6244005, 21716133, 14890115, 9786500, 23306787, 7363558, 17284785, + 20320208, 6511591, 21608014, 14750756, 9930128, 23124925, 7913388, 16942690, 20203263, 6587292, 21624962, + 14513392, 10113276, 22919554, 8398547, 16678506, 20181005, 6516502, 21762500, 14295383, 10193554, 22936572, + 8698312, 16537405, 20296021, 6041988, 22100770, 14183007, 10083567, 23105447, // + + 8989410, 16271687, 20399219, 5615836, 22402447, 14305604, 9902154, 23119856, 8923746, 16371915, 20571323, + 5160131, 22645294, 14358318, 9888208, 23240084, 8265843, 16909934, 20805330, 4710098, 22820691, 14360568, + 9931777, 23539247, 7375623, 9415822, 22307506, 15560049, 8671193, 23573577, 9112068, 15751951, 21660721, + 4082015, 22642615, 15402282, 8842632, 23818278, 8039991, 16575794, 21745508, 3806327, 22913415, 15130961, + 9141456, 24229806, 6818899, 17440294, 21759736, 3957445, 22736323, 14901893, 9568919, 24021685, 6471824, + 17780626, 21042461, 5155479, 22294767, 14334638, 10243117, 23786048, 6294118, 18010859, 20419313, 5826465, + 22215726, 14341022, 10276711, 23433900, 6612880, 18009829, 19996368, 6366523, 22052626, 14094256, 10512193, + 23152641, 7019419, 17898658, 19751155, 6615911, 22068261, 13659363, 10943383, 22892049, 7609590, 17637458, + 19507319, 6792320, 22143856, 13279319, 11195651, 22809641, 7998949, 17453715, // + + 19491914, 6485574, 22473502, 12939263, 11339674, 22836780, 8388122, 17241483, 19445575, 6201345, 22766294, + 12881923, 11373551, 22762814, 8080644, 17647664, 19537048, 
5905348, 22931837, 12797754, 11564040, 22736543, + 7535507, 18215295, 19550015, 5732163, 22978501, 12702490, 11702830, 22951044, 6636079, 18872397, 19686218, + 5394917, 23188512, 12718043, 11684092, 23347746, 5486549, 19652009, 19587479, 5265496, 23508795, 12366861, + 12010603, 23807618, 4155776, 20472076, 19484893, 4953684, 23976940, 12168409, 12097796, 23668636, 4048813, + 20641313, 18826537, 6032773, 23556207, 11790335, 12600559, 23219029, 4655481, 20475133, 18392416, 6829019, + 23237060, 11159992, 13233021, 22704457, 5504649, 20077577, 17803590, 7545229, 23082117, 10665191, 13669123, + 22413994, 6213296, 19705135, 17678742, 7647986, 23130765, 10269325, 13953565, 22181609, 6892237, 19343592, + 17440090, 7731004, 23264261, 10163309, 14101344, 21889465, 7023527, 19556940, // + + 17361822, 7695152, 23189274, 10150679, 14282087, 21611223, 6991054, 19858916, 17332204, 7643778, 23066653, + 10275184, 14319879, 21627003, 6630320, 20202211, 17215008, 7625807, 23111607, 10505851, 14214845, 21830942, + 5959063, 20687659, 16974341, 7694019, 23208585, 10415890, 14335779, 22019651, 5135061, 21303623, 16776551, + 7543203, 23599951, 10487839, 14225557, 22214630, 4365688, 21876947, 16744734, 7479200, 23818705, 9988650, + 14615526, 22463795, 3626725, 22357371, 16724667, 7538762, 24106610, 8821290, 15516785, 22500738, 3454023, + 22466628, 16044371, 8388949, 23997234, 7947773, 16293538, 22026503, 4430789, 21968253, 15687162, 8972812, + 23832462, 7550030, 16706012, 21471067, 5444039, 21499281, 15367490, 9361479, 23699015, 7411397, 16862884, + 20969087, 5993594, 21507869, 15116464, 9616895, 23434183, 7508189, 17030382, 20426235, 6503849, 21547151, + 15059192, 9671367, 23138310, 7804793, 16991347, 20251935, 6628005, 21579210, // + + 14656789, 9984782, 22985062, 8236105, 16782028, 20165913, 6512453, 21780210, 14212013, 10297958, 22920498, + 8518032, 16683932, 20152224, 6152185, 22153479, 13915924, 10344687, 23061539, 8820218, 16522693, 20140536, + 5759600, 22556871, 13838808, 10368950, 
23018087, 8559741, 16890487, 20254554, 5347371, 22787393, 13751540, + 10525072, 23064180, 7856609, 17554570, 20300641, 5118343, 22908214, 13603239, 10736493, 23356200, 6794841, + 18308454, 20525840, 4672346, 23187248, 13609829, 10748145, 23773670, 5514352, 19145320, 20475073, 4490305, + 23522685, 13058746, 11228823, 24110847, 4459320, 19869016, 19849358, 5115280, 23451562, 12779960, 11578253, + 23632326, 4796332, 19842826, 19187801, 6102427, 23125288, 12634340, 11828408, 23198848, 5345817, 19700246, + 18824251, 6737926, 22873861, 12037827, 12409570, 22814959, 6063019, 19345395, 18476572, 7197820, 22768287, + 11529247, 12896938, 22576794, 6692266, 19042234, 18276846, 7288199, 22904648, // + + 11141858, 13137570, 22401204, 7292607, 18730801, 18097388, 7279018, 23100747, 11030513, 13177245, 22183872, + 7350550, 18926763, 18097829, 7116866, 23103889, 10983888, 13362456, 21948964, 7214260, 19312614, 18045112, + 7073893, 23037177, 11027992, 13480188, 22038919, 6691771, 19728807, 18023490, 6935789, 23135144, 11185965, + 13392257, 22219528, 5940309, 20299028, 17790274, 6956401, 23242452, 11070140, 13566321, 22526643, 4974710, + 20957420, 17625559, 6757956, 23692292, 10958044, 13635770, 22785082, 4046866, 21625268, 17475702, 6732056, + 24046416, 10450486, 13963008, 23011575, 3246213, 22235904, 17422398, 6909740, 24127021, 9440164, 14839700, + 22639874, 3890974, 21883594, 16710762, 7984298, 23793381, 8857051, 15456540, 22169292, 4854947, 21368201, + 16397896, 8449589, 23598584, 8388962, 15876931, 21776886, 5699911, 20961300, 16036047, 8799304, 23626879, + 8210654, 16017666, 21336027, 6214812, 20910345, 15887038, 8955933, 23411848, // + + 8221308, 16235322, 20814755, 6615974, 21030755, 15725976, 9093706, 23155009, 8475146, 16256652, 20658188, + 6652752, 21126902, 15457722, 9247262, 23073144, 8825443, 16066263, 20602648, 6426678, 21430883, 15143338, + 9482476, 22984419, 9029727, 16066843, 20644480, 5970987, 21820003, 14827391, 9507534, 23192936, 9256136, + 15900757, 20743238, 
5435837, 22278134, 14712817, 9540192, 23242194, 9098919, 16087278, 20859194, 4954511, + 22632370, 14677293, 9611700, 23344214, 8279809, 16837444, 20966693, 4590983, 22822408, 14377899, 9953260, + 23704513, 7075774, 17745043, 21096861, 4234971, 23070139, 14203610, 10057726, 24109993, 5693141, 18681318, + 21025666, 4262303, 23173846, 13850170, 10487687, 24071718, 5320218, 18927447, 20363907, 5177843, 22928321, + 13526949, 10946371, 23664238, 5483270, 19054689, 19622280, 6090506, 22748322, 13435280, 11095644, 23209325, + 5963149, 18964576, 19224673, 6710180, 22539937, 12885926, 11644548, 22931363, // + + 6550727, 18691528, 19021441, 6924526, 22512365, 12333234, 12100024, 22681976, 7094395, 18461092, 18733849, + 7047052, 22653940, 11977218, 12326272, 22637775, 7589822, 18214544, 18635699, 6919708, 22901763, 11783024, + 12440462, 22386064, 7758725, 18322875, 18587777, 6764170, 23010928, 11711592, 12615519, 22277026, 7419076, + 18792359, 18614537, 6595115, 23005535, 11666795, 12784985, 22366400, 6794275, 19289719, 18662622, 6387813, + 23116233, 11766734, 12733555, 22593784, 5906160, 19958795, 18517767, 6349274, 23256681, 11679895, 12904180, + 22978760, 4788741, 20707805, 18292821, 6174101, 23794420, 11312192, 13128047, 23215355, 3795653, 21481949, + 18120240, 6132456, 24208590, 10852530, 13431073, 23245858, 3446308, 21713892, 17753336, 6864070, 23841354, + 10045453, 14248223, 22698325, 4372338, 21262372, 17148096, 7774773, 23565387, 9500381, 14839456, 22277969, + 5258493, 20837354, 16841397, 8169576, 23446618, 9012608, 15210199, 21902217, // + + 6093800, 20440510, 16523055, 8454472, 23499292, 8890729, 15362883, 21514259, 6481556, 20465357, 16341039, + 8577290, 23332686, 8821754, 15592311, 21052378, 6736444, 20672449, 16239111, 8664187, 23106530, 8993414, + 15744851, 20931383, 6644621, 8658437, 18678106, 19513676, 6577549, 22352850, 13307070, 11220911, 23058413, + 6722916, 18356704, 19290179, 6861505, 22307151, 12956656, 11594109, 22704260, 7365840, 18044670, 19153685, + 
6939335, 22387956, 12654702, 11695187, 22763499, 7659613, 17888308, 19120689, 6661917, 22648852, 12391634, + 11833808, 22632820, 8090628, 17726750, 18984666, 6550826, 22923766, 12305494, 11981807, 22545599, 7813823, + 18071344, 19131109, 6206572, 22994555, 12433015, 11971871, 22564600, 7268448, 18601799, 19237693, 5935895, + 23016018, 12334842, 12102054, 22799422, 6400726, 19215368, 19280092, 5703876, 23197162, 12396553, 12050686, + 23194007, 5334410, 19887100, 19164018, 5576837, 23511667, 12065924, 12306782, // + + 23622868, 4080603, 20712579, 19081559, 5228071, 24115611, 11874820, 12368914, 23599995, 3798160, 20981386, + 18679483, 6070552, 23690130, 11504268, 12862353, 23149031, 4501087, 20687059, 18180858, 6941855, 23316438, + 10885051, 13483119, 22652728, 5372541, 20234862, 17655542, 7623619, 23151661, 10473286, 13878866, 22420918, + 6098943, 19852138, 17555250, 7739421, 23197412, 10142209, 14126564, 22161262, 6786968, 19497045, 17392925, + 7796711, 23270766, 10041770, 14184307, 21857793, 6958854, 19598870, 17308433, 7745974, 23203566, 10095393, + 14325029, 21527484, 6963176, 19924856, 17202788, 7758382, 23077156, 10173572, 14459591, 21585809, 6671118, + 20180439, 17166377, 7720317, 23138522, 10424859, 14292828, 21713744, 6101203, 20626549, 16979851, 7722878, + 23157107, 10489919, 14314920, 21946507, 5324051, 21188151, 16830889, 7583891, 23527869, 10442655, 14319637, + 22162271, 4488870, 21779424, 16635980, 7571778, 23790018, 10211336, 14448356, // + + 22388820, 3782216, 22272059, 16707843, 7553882, 24046779, 9025574, 15343164, 22596728, 3260710, 22599398, + 16330045, 8055647, 24046000, 8225177, 16063547, 22174203, 4229892, 21984116, 15911054, 8715461, 23789002, + 7784590, 16436907, 21721867, 5185669, 21530754, 15619104, 9105793, 23699385, 7652702, 16570783, 21163955, + 5993314, 21291221, 15387257, 9396874, 23556216, 7707035, 16702941, 20742721, 6388842, 21301023, 15404540, + 9375147, 23284104, 7939228, 16725232, 20574625, 6552148, 21332800, 15120388, 9592830, 
23074682, 8318688, + 16520514, 20347739, 6550180, 21517598, 14723604, 9901095, 22876358, 8783418, 16381588, 20385837, 6300682, + 21787449, 14556545, 9890382, 23061978, 8876251, 16320341, 20476953, 5771963, 22223396, 14271993, 9941271, + 23166087, 9114773, 16175465, 20439516, 5393255, 22567798, 14192929, 10075120, 23190789, 8472506, 16767177, + 20645769, 4895467, 22797989, 14255205, 10079508, 23359944, 7604190, 17494690, // + + 20835887, 4568692, 22955800, 14037796, 10295790, 23733943, 6464204, 18239932, 21009725, 4130314, 23255522, + 13997646, 10243844, 24207759, 5168454, 19039304, 20656968, 4624942, 23167800, 13426687, 9852496, 1083474, + 14401511, 12504518, 11760470, 22876805, 6882334, 18701825, 19591030, 5553179, 23075691, 12729128, 11637580, + 23195636, 5935042, 19295823, 19676761, 5290348, 23283992, 12795610, 11576564, 23740208, 4809831, 19909917, + 19786417, 4858785, 23746737, 12496082, 11789655, 23946807, 4045253, 20428663, 19288874, 5502069, 23634895, + 12464669, 11913246, 23450203, 4553808, 20282631, 18940760, 6284196, 23205745, 12059152, 12381645, 23092658, + 5274243, 19878808, 18560891, 6918818, 22953031, 11665044, 12761457, 22710132, 6063331, 19423310, 18253005, + 7351379, 22856544, 11399823, 12970589, 22552648, 6701483, 19037485, 18276846, 7288199, 22881013, 11194948, + 13045473, 22437471, 7277522, 18668126, 18194881, 7223503, 23002224, 11272126, // + + 12980994, 22235027, 7570737, 18648389, 18199546, 7079575, 23092576, 11344158, 12950794, 22148536, 7380508, + 18903079, 18374635, 6792264, 23036140, 11483778, 12941615, 22271619, 6982379, 19206892, 18511435, 6564548, + 23054622, 11648403, 12844267, 22451945, 6332730, 19643018, 18571145, 6384535, 23167304, 11776199, 12759213, + 22783475, 5467491, 20164550, 18518505, 6198329, 23459839, 11616870, 12863411, 23207627, 4412746, 20825569, + 18551736, 5816456, 23925069, 11577616, 12768346, 23509420, 3509631, 21396465, 18479896, 5900829, 24054208, + 11274147, 13072790, 23230871, 3883109, 21238146, 18014281, 
6835378, 23588712, 10671534, 13698814, 22781058, + 4767364, 20777898, 17639724, 7457491, 23318302, 10228797, 14105805, 22458203, 5565319, 20305145, 17287345, + 7893967, 23234634, 9918417, 14354779, 22237385, 6273898, 19896154, 17237794, 7938721, 23265383, 9830296, + 14422620, 21909513, 6904759, 19639056, 17060244, 8059794, 23258187, 9884250, // + + 14441784, 21664245, 6998144, 19794057, 17206641, 7803948, 23190880, 10132825, 14321214, 21601674, 6903723, + 19967525, 17282574, 7672093, 23068438, 10289538, 14310632, 21648568, 6520263, 20246686, 17281895, 7593662, + 23100491, 10639937, 14090027, 21884332, 5916088, 20659926, 17146478, 7541935, 23212881, 10609063, 14112918, + 22137518, 5165440, 21112559, 17140767, 7262037, 23593894, 10627470, 14029963, 22449381, 4387880, 21620224, + 17036516, 7188465, 23901148, 10511501, 14014660, 22650195, 3635721, 22129677, 17115191, 7095639, 24162604, + 9521030, 14863812, 22852060, 3359458, 22230710, 16805215, 7694364, 23915937, 8906682, 15361011, 22338476, + 4415153, 21636633, 16337341, 8463659, 23657957, 8619140, 15646278, 22052127, 5162130, 21158554, 16237858, + 8629665, 23547994, 8395596, 15812515, 21657131, 6048249, 20712160, 16036762, 8843879, 23524539, 8450055, + 15797754, 21313416, 6413058, 20689043, 16060796, 8787268, 23328585, 8686991, // + + 15759635, 20993045, 6725825, 20714472, 16162157, 8711698, 23131537, 9014324, 15638737, 20976081, 6673640, + 20780032, 16087399, 8685277, 23096319, 9379459, 15378352, 21026225, 6385770, 21014008, 15944553, 8767459, + 22985980, 9735366, 15207478, 21206809, 5878367, 21330340, 15833491, 8636802, 23238723, 9797233, 15117025, + 21394824, 5187445, 21834285, 15672434, 8602894, 23475313, 9965772, 14986841, 21575828, 4587918, 22295084, + 15844962, 8363823, 23635333, 9337350, 15442834, 21920695, 3963374, 22542040, 15864507, 8379274, 23906847, + 8361422, 16203943, 21999170, 3632656, 22778624, 15630018, 8617641, 24198347, 7377978, 16845211, 21942632, + 4046633, 22412224, 15343078, 9158513, 23920045, 
7123167, 17134635, 21451973, 5069842, 21894060, 15161509, + 9457643, 23796364, 7073572, 17134186, 7133380, 17802216, 7582668, 23030633, 10698918, 13627584, 22449998, + 6367520, 19483090, 17734427, 7613118, 23067972, 10491469, 13734358, 22250894, // + + 7020022, 19162364, 17712121, 7540578, 23162818, 10578078, 13629680, 22051954, 7229406, 19166676, 17728778, + 7437856, 23159761, 10685192, 13642402, 0, 22020842, 5485775, 20930006, 17107835, 7399238, 23432151, + 10689871, 14031720, 22285281, 4605999, 21524237, 16915342, 7318512, 23788841, 10709114, 13897782, 22553064, + 3853756, 22041180, 17155563, 7053101, 24064130, 9831320, 14543299, 22982963, 3064699, 22446253, 17014905, + 7360553, 24058692, 9066570, 15191437, 22443435, 4173088, 21785026, 16523924, 8201314, 23739668, 8686494, + 15574890, 22140380, 4949215, 21274849, 16230939, 8591944, 23592633, 8404004, 15856085, 21718829, 5887075, + 20807869, 16123129, 8733755, 23558215, 8400923, 15822116, 21400148, 6329817, 20717703, 16067168, 8790213, + 23412045, 8557373, 15873198, 20997005, 3278352, 5604555, 18603203, 19601020, 6519954, 22317594, 13443922, + 11111186, 23074649, 6756632, 18240756, 19429273, 6790209, 22220665, 8437208 // + + }; + + public static final Integer[] predictedData = { 8931905, 12081183, 13140283, 7890941, 14283475, 10869967, 9203245, + 14450631, 8856705, 12104258, 13130344, 7731838, 14527199, 10794645, 9064588, 14608937, 8823195, 12172205, + 13051236, 7545362, 14801416, 10708873, 9027609, 14723334, 8754347, 12263946, 12909274, 7489182, 14961291, + 10577290, 9215333, 14666870, 8623358, 12483200, 12643272, 7623844, 14964723, 10380271, 9467619, 14572030, + 8515648, 12750934, 12343975, 7781289, 15019471, 10155035, 9670343, 14638097, 8395602, 12894626, 12140855, + 7848232, 15058938, 10126444, 9702186, 14646087, 8281407, 12949820, 12205576, 7800028, 15005580, 10290326, + 9666610, 14525533, 8189920, 13065421, 12270269, 7888022, 14853165, 10294449, 9823395, 14312157, 8205314, + 13121095, 12306997, 8021604, 
14667519, 10359409, 9873441, 14147271, 8213434, 13151785, 12394432, 8070486, + 14604112, 10274089, 9947941, 14142882, 8263493, 13041683, 12486538, 7995124, 14638527, 10254334, 9974679, + 14108152 }; + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmModelImplTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmModelImplTest.java new file mode 100644 index 00000000000..2aa05c32ca1 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmModelImplTest.java @@ -0,0 +1,50 @@ +package io.openems.edge.predictor.lstmmodel; + +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; + +import org.junit.Test; + +import io.openems.common.test.TimeLeapClock; +import io.openems.common.types.ChannelAddress; +import io.openems.edge.common.test.ComponentTest; +import io.openems.edge.common.test.DummyComponentManager; +import io.openems.edge.predictor.api.prediction.LogVerbosity; +import io.openems.edge.timedata.test.DummyTimedata; + +public class LstmModelImplTest { + + private static final String TIMEDATA_ID = "timedata0"; + private static final String PREDICTOR_ID = "predictor0"; + + private static final ChannelAddress METER1_ACTIVE_POWER = new ChannelAddress("meter1", "ActivePower"); + + @Test + public void test() throws Exception { + final var clock = new TimeLeapClock(Instant.ofEpochSecond(1577836800) /* starts at 1. 
January 2020 00:00:00 */, + ZoneOffset.UTC); + + var values = Data.data; + var timedata = new DummyTimedata(TIMEDATA_ID); + var start = ZonedDateTime.of(2019, 12, 1, 0, 0, 0, 0, ZoneId.of("UTC")); + + for (var i = 0; i < values.length; i++) { + timedata.add(start.plusMinutes(i * 15), METER1_ACTIVE_POWER, values[i]); + } + + var sut = new LstmModelImpl(); + + new ComponentTest(sut) // + .addReference("timedata", timedata) // + .addReference("componentManager", new DummyComponentManager(clock)) // + .activate(MyConfig.create() // + .setId(PREDICTOR_ID) // + .setLogVerbosity(LogVerbosity.NONE) // + .setChannelAddress(METER1_ACTIVE_POWER.toString())// + .build()); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmPredictorTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmPredictorTest.java new file mode 100644 index 00000000000..03a62f2cb06 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/LstmPredictorTest.java @@ -0,0 +1,190 @@ +package io.openems.edge.predictor.lstmmodel; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; + +/** + * Unit test class for the LstmPredictor model. + * + *

+ * This class tests the behavior of the LstmPredictor for different types of + * input patterns such as impulse, step, ramp, and exponential inputs. It + * validates that the predictions generated by the LSTM model match the expected + * outputs within a defined margin of error. + *

+ * + *

+ * The model utilizes pre-defined hyperparameters and scaling factors to process + * the input data. The test cases ensure that the scaling applied before the + * prediction and scaled-back result after the prediction are consistent with + * the hyperparameters. + *

+ * + *

+ * The model trend data used in the test is also related to the specific scaling + * factors defined in {@link HyperParameters}, and the tests confirm the proper + * interaction between the input data and the model's trend. + *

+ * + * + * + *

+ * The tests use the {@link DataModification#scale()} method to scale the input + * data before prediction and {@link DataModification#scaleBack()} to reverse + * the scaling after the prediction. The scaling ranges are provided by the + * hyperparameters through {@link HyperParameters#getScalingMin()} and + * {@link HyperParameters#getScalingMax()}. + *

+ */ +public class LstmPredictorTest { + private static HyperParameters hyperParameters = new HyperParameters(); + + private static ArrayList> modelTrend = new ArrayList<>(Arrays.asList( + createList(0.30000000000000004, -0.10191832534531027, -0.19262844428679757, 0.016925024201681654), + createList(-0.7999999999999999, -0.3142909416393413, -0.3341676120993015, -0.09089772222510135), + createList(-0.4999999999999999, 0.051555896559209405, -0.11477687998526631, 0.10826117268571883), + createList(-0.6, -1.449260711226437, -1.6789748520719996, -1.6707673970279129), + createList(1.9000000000000004, 2.0276163313785935, 2.0457575003167086, 1.716902676376759), + createList(0.09999999999999995, -0.40632251238009526, -0.2902480457595551, -0.21870167929155354), + createList(-0.08825349909436375, -0.10024408682002699, -0.0891597522413061, -0.11174726093461877), + createList(-0.2529282639641216, -0.24738024250988547, -0.18556978270548535, -0.2302537524898713)// + )); + + private static ArrayList createList(Double... values) { + return new ArrayList<>(Arrays.asList(values)); + } + + @Test + public void predictTest() { + + /* + * IMPULSE RESPONSE : impulses are the sudden change in consumption for a very + * short period of time + * + * Example : When someone runs the electric drilling machine + * + * When the change magnitude of data last indexed data is very high compared to + * other data in an array, model identifies it as an impulse. 
+ * + * Model will make a prediction negating the drastic change + * + */ + double result; + + var impulseSimulation = new ArrayList<>(createList(50.0, 55.0, 55.0, 150.0)); + + result = LstmPredictor.predict(// + DataModification.scale(impulseSimulation, // + hyperParameters.getScalingMin(), hyperParameters.getScalingMax()), // + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + assertEquals(result, -4379.836081864531, 0.001); + } + + @Test + public void predictTest1() { + double result; + // STEP RESPONSE : Example: plugging in EV for charging + var stepSimulation1 = new ArrayList<>(createList(55.0, 45.0, 150.0, 150.0)); + + result = LstmPredictor.predict(// + DataModification.scale(stepSimulation1, // + hyperParameters.getScalingMin(), hyperParameters.getScalingMax()), + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + + assertEquals(result, -4382.945607343116, 0.001); + } + + @Test + public void predictTest2() { + double result; + + var stepSimulation2 = new ArrayList<>(createList(45.0, 150.0, 150.0, 150.0)); + result = LstmPredictor.predict(// + DataModification.scale(stepSimulation2, // + hyperParameters.getScalingMin(), hyperParameters.getScalingMax()), + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + + assertEquals(result, 
-4380.577836686382, 0.0001); + + } + + @Test + public void predictTest3() { + double result; + + var stepSimulation3 = new ArrayList<>(createList(150.0, 150.0, 150.0, 150.0)); + result = LstmPredictor.predict(// + DataModification.scale(stepSimulation3, hyperParameters.getScalingMin(), + hyperParameters.getScalingMax()), + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + + assertEquals(result, -4391.590202508077, 0.001); + + } + + @Test + public void predictTest4() { + double result; + + // RESPONSE TO RAMP INPUT + var rampInput = new ArrayList<>(createList(100.0, 200.0, 400.0, 800.0)); + result = LstmPredictor.predict(// + DataModification.scale(rampInput, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()), + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + + assertEquals(result, -4376.960088551304, 0.001); + + } + + @Test + public void predictTest5() { + double result; + + // RESPONSE to exponential input + var expInput = new ArrayList<>(createList(20.0, 400.0, 160000.0, 3200000000.0)); + result = LstmPredictor.predict(// + DataModification.scale(expInput, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()), + modelTrend.get(0), modelTrend.get(1), modelTrend.get(2), modelTrend.get(3), modelTrend.get(4), + modelTrend.get(5), modelTrend.get(7), modelTrend.get(6), hyperParameters); + result = DataModification.scaleBack(result, // + hyperParameters.getScalingMin(), // + hyperParameters.getScalingMax()); + + assertEquals(result, 
-6666.666666666666, 0.001); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/MyConfig.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/MyConfig.java new file mode 100644 index 00000000000..ab7bc65293e --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/MyConfig.java @@ -0,0 +1,63 @@ +package io.openems.edge.predictor.lstmmodel; + +import io.openems.common.test.AbstractComponentConfig; +import io.openems.edge.predictor.api.prediction.LogVerbosity; + +@SuppressWarnings("all") +public class MyConfig extends AbstractComponentConfig implements Config { + + protected static class Builder { + private String id; + private String channelAddress; + private LogVerbosity logVerbosity; + + private Builder() { + } + + public Builder setId(String id) { + this.id = id; + return this; + } + + public Builder setChannelAddress(String channelAddress) { + this.channelAddress = channelAddress; + return this; + } + + public Builder setLogVerbosity(LogVerbosity logVerbosity) { + this.logVerbosity = logVerbosity; + return this; + } + + public MyConfig build() { + return new MyConfig(this); + } + } + + /** + * Create a Config builder. 
+ * + * @return a {@link Builder} + */ + public static Builder create() { + return new Builder(); + } + + private final Builder builder; + + private MyConfig(Builder builder) { + super(Config.class, builder.id); + this.builder = builder; + } + + @Override + public String channelAddress() { + return this.builder.channelAddress; + } + + @Override + public LogVerbosity logVerbosity() { + return this.builder.logVerbosity; + } + +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/StandAlonePredictorTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/StandAlonePredictorTest.java new file mode 100644 index 00000000000..27b8303cbb4 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/StandAlonePredictorTest.java @@ -0,0 +1,482 @@ +package io.openems.edge.predictor.lstmmodel; + +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +//import org.junit.Test; + +//import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.common.ReadAndSaveModels; +import io.openems.edge.predictor.lstmmodel.common.ReadCsv; +import io.openems.edge.predictor.lstmmodel.interpolation.InterpolationManager; +import io.openems.edge.predictor.lstmmodel.performance.PerformanceMatrix; +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; +import static io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion.to1DArrayList; + +/** + * This test class is intended for local testing and is not executed during the + * build process. To run the JUnit test cases, please uncomment the relevant + * annotations. 
Ensure that the necessary data and model files are accessible in + * the specified path before executing the tests. + */ +public class StandAlonePredictorTest { + + public static final String CSV = "1.csv"; + public static final ZonedDateTime GLOBAL_DATE = ZonedDateTime.of(2022, 6, 16, 0, 0, 0, 0, + ZonedDateTime.now().getZone()); + public static final HyperParameters HYPER_PARAMETERS = ReadAndSaveModels.read("ConsumptionActivePower"); + + /** + * Prediction testing. + */ + //@Test + public void predictionTest() { + + ArrayList> predictedSeasonality = new ArrayList>(); + ArrayList predictedTrend = new ArrayList(); + HYPER_PARAMETERS.printHyperParameters(); + + int predictionCount = 1; + for (int i = 0; i < predictionCount; i++) { + ZonedDateTime nowDate = GLOBAL_DATE.plusHours(24 * i); + var tempPredicted = this.predictSeasonality(HYPER_PARAMETERS, nowDate, CSV); + predictedSeasonality.add(tempPredicted); + } + + for (int i = 0; i < predictionCount; i++) { + ZonedDateTime nowDate = GLOBAL_DATE.plusHours(24 * i); + predictedTrend = this.predictTrendOneDay(HYPER_PARAMETERS, nowDate, CSV); + } + + var pre = to1DArrayList(predictedSeasonality); + + var until = GLOBAL_DATE.withMinute(getMinute(GLOBAL_DATE, HYPER_PARAMETERS)).withSecond(0).withNano(0); + var targetFrom = until.plusMinutes(HYPER_PARAMETERS.getInterval()); + var targetTo = targetFrom.plusHours(24 * predictionCount); + + var target = this.getTargetData(targetFrom, targetTo, CSV, HYPER_PARAMETERS); + var rmsSeasonality = PerformanceMatrix.rmsError(target, pre); + var rmsTrend = PerformanceMatrix.rmsError(target, predictedTrend); + + StringBuilder sb = new StringBuilder(); + String format = "%-25s %s%n"; + + sb.append(String.format(format, "Target:", DataModification.constantScaling(target, 1))) + .append(String.format(format, "PredictedSeasonality:", DataModification.constantScaling(pre, 1))) + .append(String.format(format, "Target (raw):", target)) + .append(String.format(format, "Predicted trend:", 
DataModification.constantScaling(predictedTrend, 1))) + .append(String.format(format, "RMS Trend:", rmsTrend)) + .append(String.format(format, "RMS Seasonality:", rmsSeasonality)) + .append(String.format(format, "Accuracy Trend:", + PerformanceMatrix.accuracy(target, predictedTrend, 0.15))) + .append(String.format(format, "Accuracy Seasonality:", PerformanceMatrix.accuracy(target, pre, 0.15))); + + System.out.println(sb.toString()); + } + + // @Test + protected void predictionTestMultivarient() { + ArrayList> predictedSeasonality = new ArrayList>(); + ArrayList predictedTrend = new ArrayList(); + HYPER_PARAMETERS.printHyperParameters(); + + int predictionFor = 1; + for (int i = 0; i < predictionFor; i++) { + ZonedDateTime nowDate = GLOBAL_DATE.plusHours(24 * i); + var tempPredicted = this.predictSeasonalityMultivarent(HYPER_PARAMETERS, nowDate, CSV); + predictedSeasonality.add(tempPredicted); + } + + for (int i = 0; i < predictionFor; i++) { + ZonedDateTime nowDate = GLOBAL_DATE.plusHours(24 * i); + predictedTrend = this.predictTrendOneDayMultivarent(HYPER_PARAMETERS, nowDate, CSV); + + } + var pre = to1DArrayList(predictedSeasonality); + var until = GLOBAL_DATE.withMinute(getMinute(GLOBAL_DATE, HYPER_PARAMETERS)).withSecond(0).withNano(0); + var targetFrom = until.plusMinutes(HYPER_PARAMETERS.getInterval()); + var targetTo = targetFrom.plusHours(24 * predictionFor); + + // changing target data for reference + var target = this.getTargetData(targetFrom, targetTo, CSV, HYPER_PARAMETERS); + var ref = this.getTargetRefrence(targetFrom, targetTo); + + var trend = DataModification.elementWiseDiv(predictedTrend, ref); + var rmsSeasonality = PerformanceMatrix.rmsError(target, pre); + var rmsTrend = PerformanceMatrix.rmsError(target, trend); + + var sb = new StringBuilder(); + String format = "%-25s %s%n"; + + sb.append(String.format(format, "Target:", DataModification.constantScaling(target, 1))) + .append(String.format(format, "PredictedSeasonality:", 
DataModification.constantScaling(pre, 1))) + .append(String.format(format, "Target (raw):", target)) + .append(String.format(format, "Predicted trend:", DataModification.constantScaling(trend, 1))) + .append(String.format(format, "RMS Trend:", rmsTrend)) + .append(String.format(format, "RMS Seasonality:", rmsSeasonality)) + .append(String.format(format, "Accuracy trend:", PerformanceMatrix.accuracy(target, trend, 0.15))) + .append(String.format(format, "Accuracy seasonality:", PerformanceMatrix.accuracy(target, pre, 0.15))); + + System.out.println(sb.toString()); + } + + /** + * Doing what it suppose to do. + * + * @param hyperParameters the Hyperparam + * @param nowDate nowDate + * @param csvFileName csvFileName + * @return predicted the predicted + */ + public ArrayList predictSeasonality(HyperParameters hyperParameters, ZonedDateTime nowDate, + String csvFileName) { + + var until = GLOBAL_DATE.withMinute(getMinute(GLOBAL_DATE, HYPER_PARAMETERS)).withSecond(0).withNano(0); + var windowSize = hyperParameters.getWindowSizeSeasonality(); + + nowDate = nowDate.plusMinutes(hyperParameters.getInterval()); + + var temp = until.minusDays(windowSize); + var fromDate = temp.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0); + + final var data = this.queryData(fromDate, until, csvFileName); + final var date = this.queryDate(fromDate, until, csvFileName); + + var targetFrom = until.plusMinutes(hyperParameters.getInterval()); + + ArrayList predicted = LstmPredictor.getArranged( + LstmPredictor.getIndex(targetFrom.getHour(), targetFrom.getMinute(), hyperParameters), + LstmPredictor.predictSeasonality(data, date, hyperParameters)); + return predicted; + } + + /** + * Doing what it suppose to do. 
+ * + * @param hyperParameters the Hyperparam + * @param nowDate nowDate + * @param csvFileName csvFileName + * @return predicted the predicted + */ + public ArrayList predictSeasonalityMultivarent(HyperParameters hyperParameters, ZonedDateTime nowDate, + String csvFileName) { + + var until = nowDate.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0); + + var windowSize = hyperParameters.getWindowSizeSeasonality(); + + nowDate = nowDate.plusMinutes(hyperParameters.getInterval()); + + var temp = until.minusDays(windowSize); + var fromDate = temp.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0); + + final var data = this.queryData(fromDate, until, csvFileName); + final var date = this.queryDate(fromDate, until, csvFileName); + + var refdata = this.generateRefrence(date); + var toPredictData = DataModification.elementWiseMultiplication(refdata, data); + + var targetFrom = until.plusMinutes(hyperParameters.getInterval()); + + var predicted = LstmPredictor.getArranged( + LstmPredictor.getIndex(targetFrom.getHour(), targetFrom.getMinute(), hyperParameters), + LstmPredictor.predictSeasonality(toPredictData, date, hyperParameters)); + + // postprocess + var targetRef = this.getTargetRefrence(fromDate, until); + return DataModification.elementWiseDiv(predicted, targetRef); + } + + /** + * Gives prediction for seasoality. 
+ * + * @param hyperParameters the Hyperparam + * @param nowDate nowDate + * @param csvFileName csvFileName + * @return predicted the predicted + */ + public ArrayList predictTrendOneDay(HyperParameters hyperParameters, ZonedDateTime nowDate, + String csvFileName) { + ArrayList predicted = new ArrayList(); + for (int i = 0; i < 60 / hyperParameters.getInterval() * 24; i++) { + var nowDateTemp = nowDate.plusMinutes(i * hyperParameters.getInterval()); + var until = nowDateTemp.withMinute(getMinute(nowDateTemp, hyperParameters)).withSecond(0).withNano(0); + var forTrendPrediction = this.queryData( + until.minusMinutes(hyperParameters.getInterval() * hyperParameters.getWindowSizeTrend()), until, + csvFileName); + var dateForTrend = this.queryDate( + until.minusMinutes(hyperParameters.getInterval() * hyperParameters.getWindowSizeTrend()), until, + csvFileName); + predicted.add(LstmPredictor.predictTrend(forTrendPrediction, dateForTrend, until, hyperParameters).get(0)); + } + return predicted; + } + + /** + * Predicts the Trend. 
+ * + * @param hyperParameters the {@link HyperParameters} + * @param nowDate the {@link ZonedDateTime} for now + * @param csvFileName the csv file name + * @return the trend + */ + public ArrayList predictTrendOneDayMultivarent(HyperParameters hyperParameters, ZonedDateTime nowDate, + String csvFileName) { + + ArrayList predicted = new ArrayList(); + for (int i = 0; i < 60 / hyperParameters.getInterval() * 24; i++) { + + var nowDateTemp = nowDate.plusMinutes(i * hyperParameters.getInterval()); + var until = nowDateTemp.withMinute(getMinute(nowDateTemp, hyperParameters)).withSecond(0).withNano(0); + var forTrendPrediction = this.queryData( + until.minusMinutes(hyperParameters.getInterval() * hyperParameters.getWindowSizeTrend()), until, + csvFileName); + var dateForTrend = this.queryDate( + until.minusMinutes(hyperParameters.getInterval() * hyperParameters.getWindowSizeTrend()), until, + csvFileName); + + // modification for multivariant + var tempData = this.generateRefrence(dateForTrend); + + tempData = DataModification.elementWiseMultiplication(forTrendPrediction, tempData); + predicted.add(LstmPredictor.predictTrend(tempData, dateForTrend, until, hyperParameters).get(0)); + } + return predicted; + } + + /** + * Gives target data to compare. + * + * @param from the From + * @param to the to + * @param csvfileName the csvfileName + * @param hyperParameter the hyperParameter + * @return the target data + */ + public ArrayList getTargetData(ZonedDateTime from, ZonedDateTime to, String csvfileName, + HyperParameters hyperParameter) { + InterpolationManager obj = new InterpolationManager(this.queryData(from, to, csvfileName), hyperParameter); + return obj.getInterpolatedData(); + + } + + /** + * Generates a Reference from {@link OffsetDateTime}s. 
+ * + * @param dates the {@link OffsetDateTime}s + * @return the reference + */ + public ArrayList generateRefrence(ArrayList dates) { + // one hour = 360/24 degree + // one minute = 360/(24*60) degree + Objects.requireNonNull(dates, "Date list must not be null"); + + return dates.stream().map(date -> { + double hourAngle = date.getHour() * 15.0; // 360/24 = 15 degrees per hour + double minuteAngle = date.getMinute() * 0.25; // 360/(24*60) = 0.25 degrees per minute + double totalAngle = hourAngle + minuteAngle; + + return 1.5 + Math.cos(Math.toRadians(totalAngle)); + }).collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Generates a Reference from {@link ZonedDateTime}s. + * + * @param dates the {@link ZonedDateTime}s + * @return the reference + */ + public static ArrayList generateReference(ArrayList dates) { + + Objects.requireNonNull(dates, "Date list must not be null"); + + return dates.stream().map(date -> { + double hourAngle = date.getHour() * 15.0; // 360/24 = 15 degrees per hour + double minuteAngle = date.getMinute() * 0.25; // 360/(24*60) = 0.25 degrees per minute + double totalAngle = hourAngle + minuteAngle; + + return 1.5 + Math.cos(Math.toRadians(totalAngle)); + }).collect(Collectors.toCollection(ArrayList::new)); + + } + + ArrayList getTargetRefrence(ZonedDateTime from, ZonedDateTime to) { + + int interval = 5; + int hour = 24; + int dataLen = (60 / interval) * hour; + + List dates = Stream.iterate(from, date -> date.plusMinutes(interval))// + .limit(dataLen)// + .collect(Collectors.toList()); + + return generateReference(new ArrayList<>(dates)); + } + + /** + * Gets the rounded minute value for the provided ZonedDateTime. + * + * @param nowDate The ZonedDateTime for which to determine the rounded + * minute + * @param hyperParameters is the object of class HyperParameters value. + * @return The rounded minute value (0, 15, 30, or 45) based on the minute + * component of the input time. 
+ */ + public static int getMinute(ZonedDateTime nowDate, final HyperParameters hyperParameters) { + Objects.requireNonNull(nowDate, "DateTime must not be null"); + Objects.requireNonNull(hyperParameters, "HyperParameters must not be null"); + + final int interval = hyperParameters.getInterval(); + + if (interval <= 0) { + throw new IllegalArgumentException("Interval must be positive"); + } + if (60 % interval != 0) { + throw new IllegalArgumentException(String.format("Interval %d must be a factor of 60", interval)); + } + + return (nowDate.getMinute() / interval) * interval; + } + + /** + * Queries data from a CSV file for a specified time range and returns the + * relevant data points. + * + * @param fromDate The start date and time for data retrieval. + * @param untilDate The end date and time for data retrieval. + * @param path The file path to the CSV data file. + * @return An ArrayList of data points that fall within the specified time + * range. + */ + public ArrayList queryData(ZonedDateTime fromDate, ZonedDateTime untilDate, String path) { + String dataPath = path; + ReadCsv csv = new ReadCsv(dataPath); + ArrayList data = csv.getData(); + ArrayList dates = csv.getDates(); + ArrayList toReturn = new ArrayList(); + int from = this.getindexOfDate(this.toOffsetDateTime(fromDate), dates); + int till = this.getindexOfDate(this.toOffsetDateTime(untilDate), dates); + toReturn = this.getData(from, till, data); + return toReturn; + } + + /** + * Queries and retrieves a list of OffsetDateTime values from a CSV file that + * fall within a specified time range. + * + * @param fromDate The start date and time for data retrieval. + * @param untilDate The end date and time for data retrieval. + * @param path The file path to the CSV data file. + * @return An ArrayList of OffsetDateTime values that correspond to the + * specified time range. 
+ */ + + public ArrayList queryDate(ZonedDateTime fromDate, ZonedDateTime untilDate, String path) { + String dataPath = path; + ReadCsv csv = new ReadCsv(dataPath); + ArrayList dates = csv.getDates(); + ArrayList toReturn = new ArrayList(); + int from = this.getindexOfDate(this.toOffsetDateTime(fromDate), dates); + int till = this.getindexOfDate(this.toOffsetDateTime(untilDate), dates); + toReturn = this.getDate(from, till, dates); + return toReturn; + } + + /** + * Converts an OffsetDateTime to a ZonedDateTime, retaining the date and time + * components. + * + * @param offsetDateTime The OffsetDateTime to convert to ZonedDateTime. + * @return The converted ZonedDateTime with the same date and time components. + */ + + public ZonedDateTime toZonedDateTime(OffsetDateTime offsetDateTime) { + + Objects.requireNonNull(offsetDateTime, "OffsetDateTime must not be null"); + + return offsetDateTime.atZoneSameInstant(ZoneId.systemDefault()).withSecond(0).withNano(0); + } + + /** + * Converts a ZonedDateTime to an OffsetDateTime, retaining the date, time, and + * offset components. + * + * @param time The ZonedDateTime to convert to OffsetDateTime. + * @return The converted OffsetDateTime with the same date, time, and offset + * components. + */ + + public OffsetDateTime toOffsetDateTime(ZonedDateTime time) { + Objects.requireNonNull(time, "ZonedDateTime must not be null"); + return time.toOffsetDateTime().withSecond(0).withNano(0); + } + + /** + * Find the index of a specific OffsetDateTime within an ArrayList of + * OffsetDateTime values. + * + * @param targetDate The OffsetDateTime to search for within the ArrayList. + * @param dates An ArrayList of OffsetDateTime values to search in. + * @return The index of the specified date in the ArrayList if found; otherwise, + * null. 
+ */ + + public Integer getindexOfDate(OffsetDateTime targetDate, ArrayList dates) { + Objects.requireNonNull(targetDate, "Target date must not be null"); + Objects.requireNonNull(dates, "Date list must not be null"); + return IntStream.range(0, dates.size()).boxed().filter(i -> targetDate.isEqual(dates.get(i))).findFirst().get(); + } + + /** + * Retrieves a subset of data from an ArrayList of Double values based on + * specified indices. + * + * @param fromIndex The starting index (inclusive) for data retrieval. + * @param toIndex The ending index (exclusive) for data retrieval. + * @param data An ArrayList of Double values containing the data. + * @return A new ArrayList containing the subset of data from the specified + * range of indices. + */ + public ArrayList getData(Integer fromIndex, Integer toIndex, ArrayList data) { + if (fromIndex < 0 || toIndex > data.size()) { + throw new IllegalArgumentException("Indices out of bounds. Valid range is 0 to " + data.size()); + } + if (fromIndex > toIndex) { + throw new IllegalArgumentException("fromIndex must be less than or equal to toIndex"); + } + + return data.stream().skip(fromIndex).limit(toIndex - fromIndex) + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Retrieves a subset of OffsetDateTime values from an ArrayList based on + * specified indices. + * + * @param fromIndex The starting index (inclusive) for date retrieval. + * @param toIndex The ending index (exclusive) for date retrieval. + * @param date An ArrayList of OffsetDateTime values containing the dates. + * @return A new ArrayList containing the subset of OffsetDateTime values from + * the specified range of indices. + */ + public ArrayList getDate(Integer fromIndex, Integer toIndex, ArrayList date) { + if (fromIndex < 0 || toIndex > date.size()) { + throw new IllegalArgumentException("Indices out of bounds. 
Valid range is 0 to " + date.size()); + } + if (fromIndex > toIndex) { + throw new IllegalArgumentException("fromIndex must be less than or equal to toIndex"); + } + + return date.stream().skip(fromIndex).limit(toIndex - fromIndex) + .collect(Collectors.toCollection(ArrayList::new)); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataModificationTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataModificationTest.java new file mode 100644 index 00000000000..658f5af68f2 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataModificationTest.java @@ -0,0 +1,170 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; + +public class DataModificationTest { + + @Test + public void testGroupDataByHourAndMinute() { + ArrayList testData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + ArrayList testDate = new ArrayList<>(Arrays.asList(OffsetDateTime.parse("2022-01-01T10:15:30Z"), + OffsetDateTime.parse("2022-01-01T11:30:45Z"), OffsetDateTime.parse("2022-01-01T10:45:00Z"), + OffsetDateTime.parse("2022-01-01T11:15:00Z"), OffsetDateTime.parse("2022-01-01T10:30:00Z"), + OffsetDateTime.parse("2022-01-01T10:15:30Z"), OffsetDateTime.parse("2022-01-01T11:30:45Z"), + OffsetDateTime.parse("2022-01-01T10:45:00Z"), OffsetDateTime.parse("2022-01-01T11:15:00Z"), + OffsetDateTime.parse("2022-01-01T10:30:00Z"))); + + ArrayList>> result = DataModification.groupDataByHourAndMinute(testData, testDate); + + assertEquals(2, result.size()); + } + + @Test + public 
void testCombinedArray() { + ArrayList> testData1 = new ArrayList<>( + Arrays.asList(new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)), + new ArrayList<>(Arrays.asList(4.0, 5.0, 6.0)), new ArrayList<>(Arrays.asList(7.0, 8.0, 9.0)))); + ArrayList expectedResult1 = new ArrayList<>(Arrays.asList(1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0)); + assertEquals(expectedResult1, DataModification.combinedArray(testData1)); + + ArrayList> testData2 = new ArrayList<>(); + ArrayList expectedResult2 = new ArrayList<>(); + assertEquals(expectedResult2, DataModification.combinedArray(testData2)); + } + + @Test + public void testModifyFortrendPrediction() { + ArrayList testData = new ArrayList<>(List.of(1.0, 2.0, 3.0, 4.0, 5.0)); + ArrayList testDates = new ArrayList<>(List.of(OffsetDateTime.parse("2022-01-01T12:30:00Z"), + OffsetDateTime.parse("2022-01-01T12:45:00Z"), OffsetDateTime.parse("2022-01-01T13:00:00Z"), + OffsetDateTime.parse("2022-01-01T13:15:00Z"), OffsetDateTime.parse("2022-01-01T13:30:00Z"))); + HyperParameters testHyperParameters = new HyperParameters(); + + ArrayList> result = DataModification.modifyFortrendPrediction(testData, testDates, + testHyperParameters); + + assertNotNull(result); + assertEquals(5, result.size()); + } + + @Test + public void testScale() { + ArrayList testData = new ArrayList<>(); + testData.add(10.0); + testData.add(20.0); + testData.add(30.0); + + ArrayList scaledData = DataModification.scale(testData, 10.0, 30.0); + + assertEquals(0.2, scaledData.get(0), 0.0001); + assertEquals(0.5, scaledData.get(1), 0.0001); + assertEquals(0.8, scaledData.get(2), 0.0001); + } + + @Test + public void testScaleBack() { + double scaledValue = 0.5; + double minOriginal = 10.0; + double maxOriginal = 30.0; + + double originalValue = DataModification.scaleBack(scaledValue, minOriginal, maxOriginal); + assertEquals(20.0, originalValue, 0.0001); + } + + @Test + public void groupByTest() { + HyperParameters hyperParameters = new HyperParameters(); + ArrayList data = 
new ArrayList(); + ArrayList date = new ArrayList(); + int interval = hyperParameters.getInterval(); + int forDays = 2; + int itter = forDays * 24 * 60 / interval; + // generating data + OffsetDateTime startingDate = OffsetDateTime.of(2023, 1, 1, 0, 0, 0, 0, ZoneOffset.ofHours(1)); + for (int i = 0; i < itter; i++) { + date.add(startingDate.plusMinutes(i * interval)); + data.add(i + 0.00); + + } + + ArrayList>> groupedData = DataModification.groupDataByHourAndMinute(data, date); + for (int i = 0; i < groupedData.size(); i++) { + for (int j = 0; j < groupedData.get(i).get(j).size(); j++) { + assertEquals(groupedData.get(i).get(j).size(), forDays); + } + } + + } + + @Test + public void getDataInBatchTest() { + ArrayList data = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + int numberOfGroups = 2; + ArrayList> result = DataModification.getDataInBatch(data, numberOfGroups); + assertEquals(result.size(), numberOfGroups); + int i = 0; + for (ArrayList outerVal : result) { + for (double val : outerVal) { + assertEquals(val, data.get(i), 0.00001); + i++; + } + } + } + + @Test + public void getDateInBatchTest() { + + ArrayList dateList = new ArrayList(); + HyperParameters hyperParameters = new HyperParameters(); + OffsetDateTime startingDate = OffsetDateTime.of(2023, 1, 1, 0, 0, 0, 0, ZoneOffset.ofHours(1)); + int numberOfGroups = 2; + int j = 0; + + // populating Date list + + for (int i = 0; i < 10; i++) { + dateList.add(startingDate.plusMinutes(i * hyperParameters.getInterval())); + } + + ArrayList> result = DataModification.getDateInBatch(dateList, numberOfGroups); + assertEquals(result.size(), numberOfGroups); + + for (ArrayList outerVal : result) { + for (OffsetDateTime val : outerVal) { + assertEquals(val, dateList.get(j)); + j++; + } + } + } + + @Test + public void removeNegatives() { + + ArrayList inputList = new ArrayList<>(Arrays.asList(5.0, -3.0, 2.0, -7.5)); + ArrayList expectedList = new ArrayList<>(Arrays.asList(5.0, 0.0, 
2.0, 0.0)); + + ArrayList resultList = DataModification.removeNegatives(inputList); + assertEquals(expectedList, resultList); + } + + @Test + public void constantScalingTest() { + + ArrayList inputData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + double scalingFactor = 2.0; + ArrayList expectedOutput = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); + ArrayList actualOutput = DataModification.constantScaling(inputData, scalingFactor); + assertEquals(expectedOutput, actualOutput); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataStatisticsTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataStatisticsTest.java new file mode 100644 index 00000000000..9fa82b84649 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/DataStatisticsTest.java @@ -0,0 +1,58 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; + +public class DataStatisticsTest { + + public static final List DATALIST = Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0); + public static final ArrayList DATA = new ArrayList<>(DATALIST); + public static final ArrayList EMPTYDATA = new ArrayList<>(); + + @Test + public void testGetMean() { + double result = DataStatistics.getMean(DATA); + assertEquals(3.0, result, 0.0001); + } + + @Test + public void testGetMeanEmptyList() { + double result = DataStatistics.getMean((ArrayList) EMPTYDATA); + assertEquals(0.0, result, 0.0001); + } + + @Test + public void testGetStandardDeviation() { + double result = DataStatistics.getStandardDeviation(DATA); + assertEquals(1.41421, result, 0.0001); + } + + @Test + public void testGetStandardDeviationEmptyList() { + assertEquals(Double.NaN, DataStatistics.getStandardDeviation(EMPTYDATA), 0.0001); + } + + @Test + public void 
testGetStanderDeviation() { + double result = DataStatistics.getStandardDeviation((ArrayList) DATA); + assertEquals(1.41421, result, 0.0001); + } + + @Test + public void testGetStanderDeviationEmptyList() { + assertEquals(Double.NaN, DataStatistics.getStandardDeviation((ArrayList) EMPTYDATA), 0.0001); + } + + @Test + public void testComputeRms() { + double[] original = { 1.0, 2.0, 3.0, 4.0, 5.0 }; + double[] computed = { 1.1, 2.2, 3.1, 4.2, 5.1 }; + double expectedRms = 0.1483239; + assertEquals(expectedRms, DataStatistics.computeRms(original, computed), 0.0001); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveObjectTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveObjectTest.java new file mode 100644 index 00000000000..123d1135eb8 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/common/ReadAndSaveObjectTest.java @@ -0,0 +1,70 @@ +package io.openems.edge.predictor.lstmmodel.common; + +import static org.junit.Assert.assertEquals; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + +//import org.junit.Test; + +import io.openems.common.OpenemsConstants; + +/** + * This class contains test methods for saving and reading objects using Gson. + * Uncomment the @Test annotations to run the tests locally. These tests use the + * HyperParameters class and involve saving objects to JSON files and reading + * them back for validation. + */ +public class ReadAndSaveObjectTest { + + /** + * Test method for saving an object to a file using Gson serialization. + * Uncomment the @Test annotation to run the test locally. 
+ */ + // @Test + public void saveObjectGsonTest() { + HyperParameters hyperParameters = new HyperParameters(); + hyperParameters.setModelName("testGson"); + hyperParameters.setCount(30); + hyperParameters.setRmsErrorTrend(0.1234); + hyperParameters.setRmsErrorTrend(0.4567); + ReadAndSaveModels.save(hyperParameters); + } + + /** + * Test method for reading a JSON object from a file. Uncomment the @Test + * annotation to run the test locally. + */ + // @Test + public void readObjectGson() { + HyperParameters hyperParameters = new HyperParameters(); + hyperParameters.setCount(30); + hyperParameters.setModelName("Consumption"); + + HyperParameters hyper = ReadAndSaveModels.read(hyperParameters.getModelName()); + assertEquals(hyper.getCount(), hyperParameters.getCount()); + + // deleting the hyperparametes + try { + Files.delete(Paths.get(this.getModelPath(hyperParameters.getModelName() + "fenHp.fems"))); + } catch (IOException e) { + e.printStackTrace(); + } + + } + + /** + * Gets the absolute path for a model file based on a given suffix. The path is + * constructed within the OpenEMS data directory under the "models" + * subdirectory. + * + * @param suffix The suffix to be appended to the model file path. + * @return The absolute path for the model file. 
+ */ + public String getModelPath(String suffix) { + File file = Paths.get(OpenemsConstants.getOpenemsDataDir()).toFile(); + return file.getAbsolutePath() + File.separator + "models" + File.separator + suffix; + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolationTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolationTest.java new file mode 100644 index 00000000000..803616ccbea --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/CubicalInterpolationTest.java @@ -0,0 +1,50 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +public class CubicalInterpolationTest { + + @Test + public void testCanInterpolate() { + ArrayList validData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, null, Double.NaN)); + CubicalInterpolation inter = new CubicalInterpolation(validData); + + assertTrue(inter.canInterpolate()); + + ArrayList invalidData = new ArrayList<>(Arrays.asList(1.0, null, 3.0, Double.NaN)); + inter.setData(invalidData); + assertFalse(inter.canInterpolate()); + + ArrayList exactlyFourData = new ArrayList<>(Arrays.asList(1.0, 2.0, null, 4.0, 5.0)); + inter.setData(exactlyFourData); + assertTrue(inter.canInterpolate()); + + ArrayList allNullOrNaNData = new ArrayList<>(Arrays.asList(null, Double.NaN, null, Double.NaN)); + inter.setData(allNullOrNaNData); + assertFalse(inter.canInterpolate()); + + ArrayList emptyData = new ArrayList<>(); + inter.setData(emptyData); + assertFalse(inter.canInterpolate()); + } + + @Test + + public void testInterpolate() { + + ArrayList validData = new 
ArrayList<>(Arrays.asList(2.0, 4.0, Double.NaN, 8.0)); + ArrayList expectedResult = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0, 8.0)); + + CubicalInterpolation inter = new CubicalInterpolation(validData); + + ArrayList interpolatedData = inter.compute(); + assertEquals(interpolatedData, expectedResult); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationMangerTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationMangerTest.java new file mode 100644 index 00000000000..5ab943bd3c6 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/InterpolationMangerTest.java @@ -0,0 +1,74 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class InterpolationMangerTest { + + private static HyperParameters hyperParameters = new HyperParameters(); + + @Test + public void calculateMeanShouldReturnNaNForEmptyList() { + ArrayList emptyList = new ArrayList<>(); + double result = InterpolationManager.calculateMean(emptyList); + assertEquals(Double.NaN, result, 0.0001); + } + + @Test + public void calculateMean_shouldReturnMeanWithoutNaN() { + ArrayList dataList = new ArrayList<>(Arrays.asList(1.0, 2.0, Double.NaN, 4.0, 5.0)); + double result = InterpolationManager.calculateMean(dataList); + assertEquals(3.0, result, 0.0001); + } + + @Test + public void testGroup() { + + ArrayList testData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0)); + int group = 3; + ArrayList> expectedGroupedData = new ArrayList<>(Arrays.asList(// + new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)), // + new ArrayList<>(Arrays.asList(4.0, 5.0, 6.0)), // + new 
ArrayList<>(Arrays.asList(3.0, 4.0))// + )); + + ArrayList> result = InterpolationManager.group(testData, group); + assertEquals(expectedGroupedData, result); + } + + @Test + public void testUnGroup() { + + ArrayList> groupedData = new ArrayList<>(); + groupedData.add(new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0))); + groupedData.add(new ArrayList<>(Arrays.asList(4.0, 5.0, 6.0))); + + ArrayList expectedResult = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)); + ArrayList result = InterpolationManager.unGroup(groupedData); + assertEquals(expectedResult, result); + } + + @Test + public void testInterpolationManagerCaseLinear() { + ArrayList data = new ArrayList<>(Arrays.asList(1.0, null, 3.0, Double.NaN, 5.0)); + ArrayList expectedData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0)); + InterpolationManager interpolationManager = new InterpolationManager(data, hyperParameters); + assertEquals(interpolationManager.getInterpolatedData(), expectedData); + } + + @Test + public void testInterPolationManagerCaseCubical() { + ArrayList data = new ArrayList<>( + Arrays.asList(1.0, null, 3.0, Double.NaN, 5.0, 6.0, null, 7.0, 8.0, null, Double.NaN, 9.0)); + ArrayList expectedData = new ArrayList<>(Arrays.asList(1.0, 2.0092714608433737, 3.0, 3.9721856174698793, + 5.0, 6.0, 6.485598644578313, 7.0, 8.0, 8.671770414993308, 8.937416331994646, 9.0)); + InterpolationManager interpolationManager = new InterpolationManager(data, hyperParameters); + assertEquals(interpolationManager.getInterpolatedData(), expectedData); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolationTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolationTest.java new file mode 100644 index 00000000000..0d03ce04ec6 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/interpolation/LinearInterpolationTest.java 
@@ -0,0 +1,59 @@ +package io.openems.edge.predictor.lstmmodel.interpolation; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +public class LinearInterpolationTest { + + @Test + public void determineInterpolatingPointsTest() { + + ArrayList data = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, Double.NaN, 6.0, 7.0, 8.0, + Double.NaN, Double.NaN, 11.0, 12.0, 13.0, Double.NaN, Double.NaN, Double.NaN, 17.0, Double.NaN, 19.0)); + + ArrayList> expectedResults = new ArrayList<>( + Arrays.asList(new ArrayList<>(Arrays.asList(3, 5)), new ArrayList<>(Arrays.asList(7, 10)), + new ArrayList<>(Arrays.asList(12, 16)), new ArrayList<>(Arrays.asList(16, 18)))); + + ArrayList> result = LinearInterpolation.determineInterpolatingPoints(data); + assertEquals(result, expectedResults); + } + + @Test + public void computeInterpolationTest() { + + ArrayList data = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, Double.NaN, 6.0, 7.0, 8.0, + Double.NaN, Double.NaN, 11.0, 12.0, 13.0, Double.NaN, Double.NaN, Double.NaN, 17.0, Double.NaN, 19.0)); + ArrayList expectedResult = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0)); + ArrayList results = LinearInterpolation.interpolate(data); + assertEquals(results, expectedResult); + } + + @Test + public void combineTest() { + + ArrayList data = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, Double.NaN, Double.NaN, Double.NaN, 17.0, 18.0, 19.0)); + ArrayList interpoltedValue = new ArrayList<>(Arrays.asList(14.0, 15.0, 16.0)); + ArrayList expectedResult = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0)); + ArrayList result = LinearInterpolation.combine(data, interpoltedValue, 12, 16); + assertEquals(result, expectedResult); + } + + 
@Test + public void computeInterPolation() { + int xval1 = 12; + int xValue2 = 16; + double yvalue1 = 13; + double yvalue2 = 17; + ArrayList expectedResult = new ArrayList<>(Arrays.asList(14.0, 15.0, 16.0)); + ArrayList result = LinearInterpolation.computeInterpolation(xval1, xValue2, yvalue1, yvalue2); + assertEquals(result, expectedResult); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrixTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrixTest.java new file mode 100644 index 00000000000..6ef42237c01 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/performance/PerformanceMatrixTest.java @@ -0,0 +1,85 @@ +package io.openems.edge.predictor.lstmmodel.performance; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; + +public class PerformanceMatrixTest { + + public static final double delta = 0.0001; + + @Test + public void meanAbsoluteErrorTest() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5, 2.8)); + + double result = PerformanceMatrix.meanAbsoluteError(target, predicted); + assertEquals(0.5666, result, delta); + } + + @Test + public void meanAbsoluteErrorTestWithException() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5)); + assertThrows(IllegalArgumentException.class, () -> PerformanceMatrix.meanAbsoluteError(target, predicted)); + } + + @Test + public void rmsErrorTest() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5, 2.8)); + + double expectedRmsError = 0.6557; + double result = PerformanceMatrix.rmsError(target, predicted); + 
assertEquals(expectedRmsError, result, delta); + } + + @Test + public void rmsErrorWithException() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5)); + + assertThrows(IllegalArgumentException.class, () -> PerformanceMatrix.rmsError(target, predicted)); + } + + @Test + public void meanSquaredErrorTest() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5, 2.8)); + + double expectedMse = 0.43; + double result = PerformanceMatrix.meanSquaredError(target, predicted); + assertEquals(expectedMse, result, delta); + } + + @Test + public void meanSquaredErrorException() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(2.0, 2.5)); + + assertThrows(IllegalArgumentException.class, () -> PerformanceMatrix.meanSquaredError(target, predicted)); + } + + @Test + public void accuracyTest() { + ArrayList target = new ArrayList<>(List.of(1.0, 2.0, 3.0)); + ArrayList predicted = new ArrayList<>(List.of(1.2, 2.3, 3.2)); + + double allowedPercentage = 0.1; + double expectedAccuracy = 0.3333; + assertEquals(expectedAccuracy, PerformanceMatrix.accuracy(target, predicted, allowedPercentage), delta); + } + + @Test + public void accuracyTestWithEmptyList() { + ArrayList target = new ArrayList<>(); + ArrayList predicted = new ArrayList<>(); + + double allowedPercentage = 0.1; + assertEquals(Double.NaN, PerformanceMatrix.accuracy(target, predicted, allowedPercentage), delta); + } +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/CombineFeatureTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/CombineFeatureTest.java new file mode 100644 index 00000000000..d82856a15e3 --- /dev/null +++ 
b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/CombineFeatureTest.java @@ -0,0 +1,29 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.junit.Test; + +public class CombineFeatureTest { + + @Test + public void multiplication() { + double[] featureA = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 }; + double[] featureB = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 }; + double[] expected = { 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0, 100.0, 121.0, 144.0 }; + + assertTrue(Arrays.equals(DataModification.elementWiseMultiplication(featureA, featureB), expected)); + + } + + @Test + public void divisioTest() { + double[] featureA = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 }; + double[] featureB = { 1.0, 2.0, 3.0, 0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 }; + double[] expected = { 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 }; + + assertTrue(Arrays.equals(DataModification.elementWiseDiv(featureA, featureB), expected)); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupByTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupByTest.java new file mode 100644 index 00000000000..6984c1e378b --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupByTest.java @@ -0,0 +1,196 @@ +/* + * package io.openems.edge.predictor.lstm.preprocessing; + * + * import static org.junit.Assert.assertEquals; + * + * import java.time.OffsetDateTime; import java.util.ArrayList; import + * java.util.Arrays; + * + * import org.junit.Test; + * + * public class GroupByTest { + * + * ArrayList testDatesxx = new ArrayList<>(Arrays.asList(// + * OffsetDateTime.parse("2023-01-01T00:00:00Z"), + * 
OffsetDateTime.parse("2023-01-01T00:05:00Z"), + * OffsetDateTime.parse("2023-01-01T00:10:00Z"), + * OffsetDateTime.parse("2023-01-01T00:15:00Z"), + * OffsetDateTime.parse("2023-01-01T00:20:00Z"), + * OffsetDateTime.parse("2023-01-01T00:25:00Z"), + * OffsetDateTime.parse("2023-01-01T00:30:00Z"), + * OffsetDateTime.parse("2023-01-01T00:35:00Z"), + * OffsetDateTime.parse("2023-01-01T00:40:00Z"), + * OffsetDateTime.parse("2023-01-01T00:45:00Z"), + * OffsetDateTime.parse("2023-01-01T00:50:00Z"), + * OffsetDateTime.parse("2023-01-01T00:55:00Z"), + * OffsetDateTime.parse("2023-01-01T01:00:00Z"), + * OffsetDateTime.parse("2023-01-01T01:05:00Z"), + * OffsetDateTime.parse("2023-01-01T01:10:00Z"), + * OffsetDateTime.parse("2023-01-01T01:15:00Z"), + * OffsetDateTime.parse("2023-01-01T01:20:00Z"), + * OffsetDateTime.parse("2023-01-01T01:25:00Z"), + * OffsetDateTime.parse("2023-01-01T01:30:00Z"), + * OffsetDateTime.parse("2023-01-01T01:35:00Z"), + * OffsetDateTime.parse("2023-01-01T01:40:00Z"), + * OffsetDateTime.parse("2023-01-01T01:45:00Z"), + * OffsetDateTime.parse("2023-01-01T01:50:00Z"), + * OffsetDateTime.parse("2023-01-01T01:55:00Z"), + * OffsetDateTime.parse("2023-01-01T02:00:00Z"), + * OffsetDateTime.parse("2023-01-01T02:05:00Z"), + * OffsetDateTime.parse("2023-01-01T02:10:00Z"), + * OffsetDateTime.parse("2023-01-01T02:15:00Z"), + * OffsetDateTime.parse("2023-01-01T02:20:00Z"), + * OffsetDateTime.parse("2023-01-01T02:25:00Z"), + * OffsetDateTime.parse("2023-01-01T02:30:00Z"), + * OffsetDateTime.parse("2023-01-01T02:35:00Z"), + * OffsetDateTime.parse("2023-01-01T02:40:00Z"), + * OffsetDateTime.parse("2023-01-01T02:45:00Z"), + * OffsetDateTime.parse("2023-01-01T02:50:00Z"), + * OffsetDateTime.parse("2023-01-01T02:55:00Z"), + * OffsetDateTime.parse("2023-01-01T03:00:00Z"), + * OffsetDateTime.parse("2023-01-01T03:05:00Z"), + * OffsetDateTime.parse("2023-01-01T03:10:00Z"), + * OffsetDateTime.parse("2023-01-01T03:15:00Z"), + * OffsetDateTime.parse("2023-01-01T03:20:00Z"), + * 
OffsetDateTime.parse("2023-01-01T03:25:00Z"), + * OffsetDateTime.parse("2023-01-01T03:30:00Z"), + * OffsetDateTime.parse("2023-01-01T03:35:00Z"), + * OffsetDateTime.parse("2023-01-01T03:40:00Z"), + * OffsetDateTime.parse("2023-01-01T03:45:00Z"), + * OffsetDateTime.parse("2023-01-01T03:50:00Z"), + * OffsetDateTime.parse("2023-01-01T03:55:00Z"), + * OffsetDateTime.parse("2023-01-01T04:00:00Z"), + * OffsetDateTime.parse("2023-01-01T04:05:00Z"), + * OffsetDateTime.parse("2023-01-01T04:10:00Z"), + * OffsetDateTime.parse("2023-01-01T04:15:00Z"), + * OffsetDateTime.parse("2023-01-01T04:20:00Z"), + * OffsetDateTime.parse("2023-01-01T04:25:00Z"), + * OffsetDateTime.parse("2023-01-01T04:30:00Z"), + * OffsetDateTime.parse("2023-01-01T04:35:00Z"), + * OffsetDateTime.parse("2023-01-01T04:40:00Z"), + * OffsetDateTime.parse("2023-01-01T04:45:00Z"), + * OffsetDateTime.parse("2023-01-01T04:50:00Z"), + * OffsetDateTime.parse("2023-01-01T04:55:00Z"), + * OffsetDateTime.parse("2023-01-01T05:00:00Z"), + * OffsetDateTime.parse("2023-01-01T05:05:00Z"), + * OffsetDateTime.parse("2023-01-01T05:10:00Z"), + * OffsetDateTime.parse("2023-01-01T05:15:00Z"), + * OffsetDateTime.parse("2023-01-01T05:20:00Z"), + * OffsetDateTime.parse("2023-01-01T05:25:00Z"), + * OffsetDateTime.parse("2023-01-01T05:30:00Z"), + * OffsetDateTime.parse("2023-01-01T05:35:00Z"), + * OffsetDateTime.parse("2023-01-01T05:40:00Z"), + * OffsetDateTime.parse("2023-01-01T05:45:00Z"), + * OffsetDateTime.parse("2023-01-01T05:50:00Z"), + * OffsetDateTime.parse("2023-01-01T05:55:00Z"), + * OffsetDateTime.parse("2023-01-01T06:00:00Z"), + * OffsetDateTime.parse("2023-01-01T06:05:00Z"), + * OffsetDateTime.parse("2023-01-01T06:10:00Z"), + * OffsetDateTime.parse("2023-01-01T06:15:00Z"), + * OffsetDateTime.parse("2023-01-01T06:20:00Z"), + * OffsetDateTime.parse("2023-01-01T06:25:00Z"), + * OffsetDateTime.parse("2023-01-01T06:30:00Z"), + * OffsetDateTime.parse("2023-01-01T06:35:00Z"), + * OffsetDateTime.parse("2023-01-01T06:40:00Z"), + * 
OffsetDateTime.parse("2023-01-01T06:45:00Z"), + * OffsetDateTime.parse("2023-01-01T06:50:00Z"), + * OffsetDateTime.parse("2023-01-01T06:55:00Z"), + * OffsetDateTime.parse("2023-01-01T07:00:00Z"), + * OffsetDateTime.parse("2023-01-01T07:05:00Z"), + * OffsetDateTime.parse("2023-01-01T07:10:00Z"), + * OffsetDateTime.parse("2023-01-01T07:15:00Z"), + * OffsetDateTime.parse("2023-01-01T07:20:00Z"), + * OffsetDateTime.parse("2023-01-01T07:25:00Z"), + * OffsetDateTime.parse("2023-01-01T07:30:00Z"), + * OffsetDateTime.parse("2023-01-01T07:35:00Z"), + * OffsetDateTime.parse("2023-01-01T07:40:00Z"), + * OffsetDateTime.parse("2023-01-01T07:45:00Z"), + * OffsetDateTime.parse("2023-01-01T07:50:00Z"), + * OffsetDateTime.parse("2023-01-01T07:55:00Z"), + * OffsetDateTime.parse("2023-01-01T08:00:00Z"), + * OffsetDateTime.parse("2023-01-01T08:05:00Z"), + * OffsetDateTime.parse("2023-01-01T08:10:00Z"), + * OffsetDateTime.parse("2023-01-01T08:15:00Z"), + * + * OffsetDateTime.parse("2023-11-24T00:00:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:05:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:10:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:15:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:20:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:25:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:30:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:35:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:40:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:45:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:50:00+02:00"), + * OffsetDateTime.parse("2023-11-24T00:55:00+02:00") + * + * )); + * + * + * private ArrayList testDataxx = new ArrayList<>(Arrays.asList(421.0, + * 408.0, 360.0, 357.0, 330.0, 329.0, 330.0, 376.0, 356.0, 334.0, 352.0, 319.0, + * 247.0, 185.0, 174.0, 226.0, 317.0, 303.0, 299.0, 368.0, 345.0, 309.0, 302.0, + * 374.0, 366.0, 343.0, 334.0, 340.0, 348.0, 306.0, 306.0, 370.0, 335.0, 283.0, + * 283.0, 278.0, 299.0, 250.0, 244.0, 311.0, 290.0, 
280.0, 282.0, 324.0, 380.0, + * 380.0, 372.0, 379.0, 306.0, 296.0, 312.0, 363.0, 367.0, 334.0, 309.0, 312.0, + * 308.0, 667.0, 386.0, 364.0, 336.0, 312.0, 310.0, 343.0, 317.0, 406.0, 396.0, + * 371.0, 357.0, 363.0, 318.0, 304.0, 302.0, 343.0, 327.0, 292.0, 283.0, 272.0, + * 262.0, 311.0, 331.0, 381.0, 401.0, 421.0, 474.0, 463.0, 426.0, 379.0, 801.0, + * 511.0, 453.0, 351.0, 415.0, 476.0, 508.0, 451.0, 435.0, 421.0, 466.0, 599.0, + * 421.0, 408.0, 360.0, 357.0, 330.0, 329.0, 330.0, 376.0, 356.0, 334.0, 352.0, + * 319.0)); + * + * + * + * @Test public void testGroupByHour() { + * + * System.out.println(testDatesxx.size()); + * System.out.println(testDataxx.size()); GroupBy groupBy = new + * GroupBy(testDataxx, testDatesxx); + * + * groupBy.hour(); + * + * assertEquals(2, groupBy.getGroupedDataByHour().size()); assertEquals(2, + * groupBy.getGroupedDateByHour().size()); + * + * assertEquals(Arrays.asList(1.0, 3.0, 5.0), + * groupBy.getGroupedDataByHour().get(0)); + * assertEquals(Arrays.asList(OffsetDateTime.parse("2022-01-01T10:15:30Z"), + * OffsetDateTime.parse("2022-01-01T10:45:00Z"), // + * OffsetDateTime.parse("2022-01-01T10:30:00Z")), + * groupBy.getGroupedDateByHour().get(0)); + * + * assertEquals(Arrays.asList(2.0, 4.0), groupBy.getGroupedDataByHour().get(1)); + * assertEquals(Arrays.asList(OffsetDateTime.parse("2022-01-01T11:30:45Z"), + * OffsetDateTime.parse("2022-01-01T11:15:00Z")), + * groupBy.getGroupedDateByHour().get(1)); } + * + * + * + * @Test public void testGroupByMinute() { + * + * + * GroupBy groupBy = new GroupBy(testDataxx, testDatesxx); groupBy.minute(); + * + * assertEquals(3, groupBy.getGroupedDataByMinute().size()); assertEquals(3, + * groupBy.getGroupedDateByMinute().size()); + * + * assertEquals(Arrays.asList(1.0, 4.0), + * groupBy.getGroupedDataByMinute().get(0)); assertEquals( + * Arrays.asList(OffsetDateTime.parse("2022-01-01T10:15:30Z"), + * OffsetDateTime.parse("2022-01-01T11:15:00Z")), // + * groupBy.getGroupedDateByMinute().get(0)); 
+ * + * assertEquals(Arrays.asList(2.0, 5.0), + * groupBy.getGroupedDataByMinute().get(1)); + * assertEquals(Arrays.asList(OffsetDateTime.parse("2022-01-01T11:30:45Z"), // + * OffsetDateTime.parse("2022-01-01T10:30Z")// + * + * ), groupBy.getGroupedDateByMinute().get(1)); + * + * } + * + * + * } + */ \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupToStiffWindowlTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupToStiffWindowlTest.java new file mode 100644 index 00000000000..833cc3a831a --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/GroupToStiffWindowlTest.java @@ -0,0 +1,75 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.preprocessingpipeline.GroupToStiffWindowPipe; + +public class GroupToStiffWindowlTest { + @Test + public void testGroupToStiffedWindow() { + ArrayList inputValues = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)); + int windowSize = 2; + + double[][] result = GroupToStiffWindowPipe.groupToStiffedWindow(inputValues, windowSize); + + double[][] expected = { { 1.0, 2.0 }, { 4.0, 5.0 } }; + assertArrayEquals("Windowing is incorrect", expected, result); + } + + @Test + public void testGroupToStiffedWindow1() { + ArrayList inputValues = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)); + int windowSize = 2; + + double[][] result = GroupToStiffWindowPipe.groupToStiffedWindow(inputValues, windowSize); + // double[][] resultX = + // GroupToStiffWindowPipe.groupToStiffedWindowX(inputValues, windowSize); + + double[][] expected = 
{ { 1.0, 2.0 }, { 4.0, 5.0 }, { 7.0, 8.0 } }; + assertArrayEquals("Windowing is incorrect", expected, result); + } + + @Test + public void testGroupToStiffedWindowWithInvalidSize() { + ArrayList inputValues = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)); + int windowSize = 7; + + IllegalArgumentException exception = assertThrows(// + IllegalArgumentException.class, () -> { + GroupToStiffWindowPipe.groupToStiffedWindow(inputValues, windowSize); + }); + + assertEquals("Invalid window size", exception.getMessage()); + } + + @Test + public void testGroupToStiffedTarget() { + ArrayList inputValues = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)); + int windowSize = 2; + + double[] result = GroupToStiffWindowPipe.groupToStiffedTarget(inputValues, windowSize); + + double[] expected = { 3.0, 6.0 }; + assertArrayEquals(expected, result, 0.001); + } + + @Test + public void testGroupToStiffedTargetWithInvalidSize() { + ArrayList inputValues = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)); + int windowSize = 7; + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + GroupToStiffWindowPipe.groupToStiffedTarget(inputValues, windowSize); + }); + + assertEquals("Invalid window size", exception.getMessage()); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverageTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverageTest.java new file mode 100644 index 00000000000..87c38123aac --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/MovingAverageTest.java @@ -0,0 +1,38 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class MovingAverageTest { + + @Test + public void test() { + + double[] data 
= { 393.0, 555.0, 482.0, 297.0, 317.0, 162.0, 157.0, 208.0, 243.0, 312.0, 377.0, 393.0, 308.0, + 287.0, 229.0, 226.0, 221.0, 259.0, 277.0, 284.0, 265.0, 250.0, 218.0, 151.0, 155.0, 214.0, 184.0, 148.0, + 221.0, 249.0, 290.0, 199.0, 240.0, 264.0, 193.0, 176.0, 147.0, 232.0, 275.0, 269.0, 319.0, 247.0, 230.0, + 225.0, 228.0, 221.0, 238.0, 321.0, 325.0, 228.0, 221.0, 198.0, 227.0, 278.0, 288.0, 338.0, 304.0, 307.0, + 238.0, 195.0, 153.0, 166.0, 205.0, 263.0, 157.0, 190.0, 280.0, 275.0, 240.0, 288.0, 306.0, 285.0, 281.0, + 273.0, 285.0, 346.0, 374.0, 338.0, 334.0, 288.0, 221.0, 168.0, 160.0, 161.0, 221.0, 279.0, 257.0, 324.0, + 178.0, 167.0, 192.0, 204.0, 188.0, 174.0, 233.0, 217.0, 194.0, 203.0, 294.0, 450.0, 583.0, 1835.0, + 2672.0, 2652.0, 3003.0, 2690.0, 2747.0, 2880.0, 2761.0, 2826.0, 2881.0, 2926.0, 2971.0, 2951.0, 3342.0, + 3592.0, 3367.0, 3271.0, 3323.0, 3446.0, 5255.0, 5354.0, 5467.0, 5127.0, 3525.0, 3206.0, 3106.0, 2961.0, + 3085.0, 3116.0, 2466.0, 2771.0, 2817.0, 3004.0, 5151.0, 5386.0, 5333.0, 5471.0, 5172.0, 5177.0, 5248.0, + 5259.0, 5255.0, 5631.0, 5410.0, 5409.0, 5415.0, 5350.0, 5782.0, 6611.0, 6545.0, 6145.0, 5406.0, 5192.0, + 3745.0, 4630.0, 4149.0, 4106.0, 5432.0, 7020.0, 4766.0, 4047.0, 3858.0, 3689.0, 2505.0, 1773.0, 1257.0, + 1018.0, 1000.0, 1422.0, 960.0, 707.0, 765.0, 797.0, 641.0, 699.0, 693.0, 681.0, 627.0, 581.0, 605.0, + 608.0, 572.0, 648.0, 637.0, 711.0, 695.0, 736.0, 705.0, 696.0, 710.0, 725.0, 662.0, 630.0, 627.0, 638.0, + 608.0, 518.0, 448.0, 417.0, 353.0, 394.0, 353.0, 397.0, 421.0, 413.0, 409.0, 374.0, 353.0, 356.0, 392.0, + 419.0, 419.0, 345.0, 410.0, 342.0, 339.0, 343.0, 242.0, 230.0, 226.0, 274.0, 367.0, 325.0, 276.0, 249.0, + 281.0, 313.0, 247.0, 230.0, 233.0, 208.0, 286.0, 231.0, 208.0, 210.0, 281.0, 346.0, 324.0, 674.0, 469.0, + 283.0, 369.0, 340.0, 325.0, 379.0, 353.0, 351.0, 416.0, 432.0, 576.0, 655.0, 689.0, 734.0, 662.0, 696.0, + 686.0, 694.0, 668.0, 589.0, 489.0, 453.0, 453.0, 454.0, 494.0, 440.0, 435.0, 463.0, 490.0, 460.0, 
478.0, + 434.0, 506.0, 489.0, 545.0, 622.0, 696.0, 613.0, 562.0, 521.0, 551.0, 492.0, 475.0, 507.0, 489.0, 488.0, + 471.0, 401.0 }; + + System.out.println("Converted = " + UtilityConversion.to1DArrayList(MovingAverage.movingAverage(data))); + System.out.println("orginal = " + UtilityConversion.to1DArrayList(data)); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/NormalizePipelineTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/NormalizePipelineTest.java new file mode 100644 index 00000000000..bc4ba008bab --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/NormalizePipelineTest.java @@ -0,0 +1,27 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.preprocessingpipeline.NormalizePipe; + +public class NormalizePipelineTest { + + @Test + public void test() { + double[] data = { 1.0, 2.0 }; + double[][] data2D = { { 1.0, 2.0 }, { 3.0, 4.0 } }; + + var hyperParameters = new HyperParameters(); + var np = new NormalizePipe(hyperParameters); + np.execute(data); + var result = (double[][]) np.execute(data2D); + System.out.print(result[0][0]); + + } + + + +} + + diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/OutliersDetectionTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/OutliersDetectionTest.java new file mode 100644 index 00000000000..8fb21ad7f8f --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/OutliersDetectionTest.java @@ -0,0 +1,20 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import org.junit.Test; + +import 
io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class OutliersDetectionTest { + + @Test + public void test() { + HyperParameters hyperParameters = new HyperParameters(); + double[] data = { 1, 2, 4, 8, 1000, 4000, 9, 8, 7, 2, 3, 8, 7, 7, 9, 7 }; + var dataScaled = DataModification.scale(data, hyperParameters.getScalingMin(), hyperParameters.getScalingMax()); + var dataWithoutOutliers = FilterOutliers.filterOutlier(dataScaled); + System.out.println(UtilityConversion.to1DArrayList(DataModification.scaleBack(dataWithoutOutliers, + hyperParameters.getScalingMin(), hyperParameters.getScalingMax()))); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/PreprocessingPipe.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/PreprocessingPipe.java new file mode 100644 index 00000000000..c7af31fb23a --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/PreprocessingPipe.java @@ -0,0 +1,71 @@ +//package io.openems.edge.predictor.lstm.preprocessing; +// +//import static org.junit.Assert.assertEquals; +// +//import java.util.ArrayList; +// +//import org.junit.Test; +// +//import io.openems.edge.predictor.lstm.common.HyperParameters; +//import io.openems.edge.predictor.lstm.preprocessingpipeline.PreprocessingPipeImpl; +//import io.openems.edge.predictor.lstm.preprocessingpipeline.ScalingPipe; +//import io.openems.edge.predictor.lstm.preprocessingpipeline.TrainandTestSplitPipe; +//import io.openems.edge.predictor.lstm.utilities.UtilityConversion; +// +//public class PreprocessingPipe { +// +// @Test +// public void trainandTestSplitPipeTest() { +// double[] data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; +// double[] res1 = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, }; +// double[] res2 = { 7.0, 8.0, 9.0, 10.0 }; +// 
HyperParameters hyp = new HyperParameters(); +// hyp.setDatasplitTrain(.7); +// TrainandTestSplitPipe ttsp = new TrainandTestSplitPipe(hyp); +// assertEquals(ttsp.execute(data)[0], res1); +// assertEquals(ttsp.execute(data)[1], res2); +// +// } +// +// @Test +// +// public void scalingPipeTest() { +// double[] data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; +// +// double[] result = { 0.2, 0.2666666666666667, 0.33333333333333337, 0.4, 0.46666666666666673, 0.5333333333333334, +// 0.6000000000000001, 0.6666666666666667, 0.7333333333333334, 0.8 }; +// +// HyperParameters hyp = new HyperParameters(); +// hyp.setScalingMax(10.0); +// hyp.setScalingMin(1); +// ScalingPipe sp = new ScalingPipe(hyp); +// +// } +// +// @Test +// public void preprocessingPiplineTest() { +// double[] data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, +// 18.0 }; +// HyperParameters hyp = new HyperParameters(); +// hyp.setScalingMax(100); +// hyp.setScalingMin(1); +// hyp.setDatasplitTrain(.7); +// hyp.setWindowSizeTrend(2); +// hyp.setWindowSizeSeasonality(7); +// +// PreprocessingPipeImpl ppimplObj1 = new PreprocessingPipeImpl(hyp); +// ArrayList> temp = UtilityConversion +// .convert2DArrayTo2DArrayList((double[][]) ppimplObj1.setData(data).scale().trainTestSplit().execute()); +// // System.out.println(temp); +// +// PreprocessingPipeImpl ppimplObj2 = new PreprocessingPipeImpl(hyp); +// double[][][] normalized = (double[][][]) ppimplObj2.setData(data).groupToStiffedWindow().execute(); +// ArrayList> normalizedData = UtilityConversion.convert2DArrayTo2DArrayList(normalized[0]); +// ArrayList target = UtilityConversion.convert1DArrayTo1DArrayList(normalized[1][0]); +// +// System.out.println("Target = " + target); +// System.out.println("data = " + normalizedData); +// +// } +// +//} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/ShuffleTest.java 
b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/ShuffleTest.java new file mode 100644 index 00000000000..38bb43f0885 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/preprocessing/ShuffleTest.java @@ -0,0 +1,35 @@ +package io.openems.edge.predictor.lstmmodel.preprocessing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import org.junit.Test; + +public class ShuffleTest { + + @Test + public void testShuffle() { + double[][] originalData = { // + { 1.0, 2.0, 3.0 }, // + { 4.0, 5.0, 6.0 }, // + { 7.0, 8.0, 9.0 }, // + { 10.0, 11.0, 12.0 }, // + { 13.0, 14.0, 15.0 }// + }; + + double[] originalTarget = { 10.0, 20.0, 30.0, 40.0, 50.0 }; + + Shuffle shuffle = new Shuffle(originalData, originalTarget); + + double[][] shuffledData = shuffle.getData(); + double[] shuffledTarget = shuffle.getTarget(); + + assertNotEquals(originalData, shuffledData); + assertNotEquals(originalTarget, shuffledTarget); + + assertEquals(originalData.length, shuffledData.length); + assertEquals(originalTarget.length, shuffledTarget.length); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/BatchImplementationTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/BatchImplementationTest.java new file mode 100644 index 00000000000..59a66021824 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/BatchImplementationTest.java @@ -0,0 +1,153 @@ +package io.openems.edge.predictor.lstmmodel.train; + +import java.time.OffsetDateTime; +import java.util.ArrayList; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.common.ReadAndSaveModels; +import io.openems.edge.predictor.lstmmodel.common.ReadCsv; +import 
io.openems.edge.predictor.lstmmodel.preprocessing.DataModification; + +public class BatchImplementationTest { + /** + * Batch testing. + */ + @Test + public void trainInBatchtest() { + + HyperParameters hyperParameters; + String modelName = "ConsumptionActivePower"; + + hyperParameters = ReadAndSaveModels.read(modelName); + + int check = hyperParameters.getOuterLoopCount(); + + for (int i = check; i <= 25; i++) { + + hyperParameters.setOuterLoopCount(i); + + final String pathTrain = Integer.toString(4) + ".csv"; + final String pathValidate = Integer.toString(4) + ".csv"; + System.out.println(""); + + hyperParameters.printHyperParameters(); + hyperParameters.setLearningRateLowerLimit(0.00001); + hyperParameters.setLearningRateUpperLimit(0.001); + + System.out.println(""); + + System.out.println(pathTrain); + System.out.println(pathValidate); + + ReadCsv obj1 = new ReadCsv(pathTrain); + final ReadCsv obj2 = new ReadCsv(pathValidate); + + var validateBatchData = DataModification.getDataInBatch(obj2.getData(), 6).get(1); + var validateBatchDate = DataModification.getDateInBatch(obj2.getDates(), 6).get(1); + + // ReadAndSaveModels.adapt(hyperParameters, validateBatchData, + // validateBatchDate); + + new TrainAndValidateBatch( + DataModification.constantScaling(DataModification.removeNegatives(obj1.getData()), 1), + obj1.getDates(), + DataModification.constantScaling(DataModification.removeNegatives(validateBatchData), 1), + validateBatchDate, hyperParameters); + + hyperParameters.setEpochTrack(0); + hyperParameters.setBatchTrack(0); + hyperParameters.setOuterLoopCount(hyperParameters.getOuterLoopCount() + 1); + ReadAndSaveModels.save(hyperParameters); + + } + } + + // @Test + protected void trainInBatchtestMultivarient() { + + HyperParameters hyperParameters; + String modelName = "ConsumptionActivePower"; + + hyperParameters = ReadAndSaveModels.read(modelName); + + int check = hyperParameters.getOuterLoopCount(); + + for (int i = check; i <= 25; i++) { + + 
hyperParameters.setOuterLoopCount(i); + + final String pathTrain = Integer.toString(i + 4) + ".csv"; + final String pathValidate = Integer.toString(i + 4) + ".csv"; + System.out.println(""); + + hyperParameters.printHyperParameters(); + hyperParameters.setLearningRateLowerLimit(0.00001); + hyperParameters.setLearningRateUpperLimit(0.001); + + System.out.println(""); + + System.out.println(pathTrain); + System.out.println(pathValidate); + + ReadCsv obj1 = new ReadCsv(pathTrain); + final ReadCsv obj2 = new ReadCsv(pathValidate); + + var trainingref = this.generateRefrence(obj1.getDates()); + var validationref = this.generateRefrence(obj2.getDates()); + + var trainingData = DataModification.elementWiseMultiplication(trainingref, obj1.getData()); + var validationData = DataModification.elementWiseMultiplication(validationref, obj2.getData()); + + var validateBatchData = DataModification.getDataInBatch(validationData, 6).get(1); + var validateBatchDate = DataModification.getDateInBatch(obj2.getDates(), 6).get(1); + + // ReadAndSaveModels.adapt(hyperParameters, validateBatchData, + // validateBatchDate); + + new TrainAndValidateBatch(trainingData, obj1.getDates(), validateBatchData, validateBatchDate, + hyperParameters); + + hyperParameters.setEpochTrack(0); + hyperParameters.setBatchTrack(0); + hyperParameters.setOuterLoopCount(hyperParameters.getOuterLoopCount() + 1); + ReadAndSaveModels.save(hyperParameters); + + } + + } + + /** + * Generates a list of reference values based on the provided list of + * OffsetDateTime objects. Each reference value is calculated using the cosine + * of the angle corresponding to the time of day represented by each + * OffsetDateTime. The formula used is: - One hour corresponds to 360/24 + * degrees. - One minute corresponds to 360/(24*60) degrees. + * + * @param date an ArrayList of OffsetDateTime objects representing the date and + * time. + * @return an ArrayList of Double values representing the generated reference + * values. 
+ */ + public ArrayList generateRefrence(ArrayList date) { + ArrayList data = new ArrayList(); + + for (int i = 0; i < date.size(); i++) { + // Extract the hour and minute from the current OffsetDateTime. + int hour = date.get(i).getHour(); + int minute = date.get(i).getMinute(); + + // Calculate the degree values for the hour and minute. + double deg = 360.0 * hour / 24.0; + double degDec = 360.0 * minute / (24.0 * 60.0); + double angle = deg + degDec; + + // Calculate the cosine of the angle in radians and add 1.5 to the result. + double addVal = Math.cos(Math.toRadians(angle)); + data.add(1.5 + addVal); + } + return data; + } + +} \ No newline at end of file diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelImplementationTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelImplementationTest.java new file mode 100644 index 00000000000..3b72896221d --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelImplementationTest.java @@ -0,0 +1,59 @@ +//package io.openems.edge.predictor.lstm.train; +// +//import java.io.IOException; +// +//import org.junit.Test; +// +//import io.openems.edge.predictor.lstm.common.DeletModels; +//import io.openems.edge.predictor.lstm.common.GetObject; +//import io.openems.edge.predictor.lstm.common.HyperParameters; +//import io.openems.edge.predictor.lstm.common.ReadCsv; +//import io.openems.edge.predictor.lstm.common.SaveObject; +// +//public class MakeModelImplementationTest { +// +// @Test +// public static void itter() { +// HyperParameters hyperParameters; +// String modelName = "Consumption"; +// try { +// hyperParameters = (HyperParameters) GetObject.get(modelName); +// } catch (ClassNotFoundException | IOException e) { +// // TODO Auto-generated catch block +// System.out.println("Creating new hyperparameter object"); +// hyperParameters = HyperParameters.getInstance(); 
+// hyperParameters.setModelName(modelName); +// } +// +// // checking if the training has been completed in previous batch +// +// if (hyperParameters.getEpochTrack() == hyperParameters.getEpoch()) { +// // hyperParameters.setEpochTrack(0); +// +// } +// +// int check = hyperParameters.getOuterLoopCount(); +// for (int i = check; i <= 8; i++) { +// hyperParameters.setOuterLoopCount(i); +// System.out.println("Batch:" + i + "/" + 28); +// System.out.println("count :" + hyperParameters.getCount()); +// System.out.println("Rms Error of all train for the trend is = " + hyperParameters.getRmsErrorTrend()); +// System.out.println( +// "Rms Error of all train for the seasonality is = " + hyperParameters.getRmsErrorSeasonality()); +// +// String pathTrain = Integer.toString(i + 1) + ".csv"; +// String pathValidate = Integer.toString(9) + ".csv"; +// +// ReadCsv obj1 = new ReadCsv(pathTrain); +// final ReadCsv obj2 = new ReadCsv(pathValidate); +// +// new TrainAndValidate(obj1.getData(), obj1.getDates(), obj2.getData(), obj2.getDates(), hyperParameters); +// +// hyperParameters.setEpochTrack(0); +// hyperParameters.setOuterLoopCount(hyperParameters.getOuterLoopCount() + 1); +// SaveObject.save(hyperParameters); +// DeletModels.delet(hyperParameters); +// } +// +// } +//} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelTest.java new file mode 100644 index 00000000000..d0afe698400 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/train/MakeModelTest.java @@ -0,0 +1,35 @@ +package io.openems.edge.predictor.lstmmodel.train; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.ArrayList; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class MakeModelTest { 
+ + @Test + public void testGenerateInitialWeightMatrixOriginal() { + // Result should be + // [ + // [1.0, 1.0, 1.0], + // [1.0, 1.0, 1.0], + // [1.0, 1.0, 1.0], + // [-1.0, -1.0, -1.0], + // [-1.0, -1.0, -1.0], + // [-1.0, -1.0, -1.0], + // [0.0, 0.0, 0.0], + // [0.0, 0.0, 0.0] + // ] + + int windowSize = 3; + ArrayList> result = MakeModel.generateInitialWeightMatrix(windowSize, new HyperParameters()); + + assertNotNull(result); + assertEquals(8, result.size()); + assertEquals(windowSize, result.get(0).size()); + } +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRateTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRateTest.java new file mode 100644 index 00000000000..ee11467ea36 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/AdaptiveLearningRateTest.java @@ -0,0 +1,100 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; + +public class AdaptiveLearningRateTest { + + @Test + public void scheduleTest() { + AdaptiveLearningRate obj = new AdaptiveLearningRate(); + HyperParameters hyperParameter = new HyperParameters(); + + hyperParameter.setLearningRateUpperLimit(0.2); + hyperParameter.setLearningRateLowerLimit(0.05); + hyperParameter.setEpochTrack(10); + + double perc = (double) hyperParameter.getEpochTrack() / hyperParameter.getEpoch(); + double lr = obj.scheduler(hyperParameter); + double val = hyperParameter.getLearningRateLowerLimit() + + 0.5 * (hyperParameter.getLearningRateUpperLimit() - hyperParameter.getLearningRateLowerLimit()) + * (1 + Math.cos(perc * Math.PI)); + assertEquals(lr, val, 0.0001); + + } + + @Test + + public void adagradOptimizerTest() { + AdaptiveLearningRate obj = new AdaptiveLearningRate(); + double globalLearningRate 
= 0.001; + double localLearningRate = 0.1; + + // Test case; i = 0, gradient = 0 + int i = 0; + double gradient = 0.0; + double lr = obj.adagradOptimizer(globalLearningRate, localLearningRate, gradient, i); + assertEquals(lr, globalLearningRate, 0.0001); + + // Test Case i>0 gradient != 0 + i = 0; + gradient = 10; + lr = obj.adagradOptimizer(globalLearningRate, localLearningRate, gradient, i); + // double expected = globalLearningRate / gradient; + assertEquals(lr, globalLearningRate, 0.0001); + + // Test Case 3 i > 0 gradient = 0 , local learning rate = 0 + i = 1; + gradient = 0; + localLearningRate = 0; + lr = obj.adagradOptimizer(globalLearningRate, localLearningRate, gradient, i); + assertEquals(lr, globalLearningRate, 0.0001); + + // Test Case 4 i > 0 gradient = 0 , local learning rate != 0 + + i = 1; + gradient = 0; + localLearningRate = 0.001; + lr = obj.adagradOptimizer(globalLearningRate, localLearningRate, gradient, i); + double temp1 = globalLearningRate / localLearningRate; + double temp2 = Math.pow(temp1, 2); + double temp3 = temp2 + Math.pow(gradient, 2); + double expected = globalLearningRate / Math.pow(temp3, 0.5); + assertEquals(lr, expected, 0.0001); + + } + + @Test + // To test if the learning rate decreases with epoch + + public void adagradTestWithScheduler() { + HyperParameters hyperParameters = new HyperParameters(); + AdaptiveLearningRate obj = new AdaptiveLearningRate(); + double localLearningRate = 0.1; + double globalLearningRate = 0.1; + double gradient = 1; + double previousLocalLearningRate = localLearningRate; + double previousGlobalLearningRate = globalLearningRate; + hyperParameters.setLearningRateUpperLimit(0.1); + hyperParameters.setLearningRateLowerLimit(0.0005); + for (int i = 1; i < hyperParameters.getEpoch(); i++) { + hyperParameters.setEpochTrack(i); + localLearningRate = obj.scheduler(hyperParameters); + // System.out.println("Local Learning Rate: " + localLearningRate); + gradient = gradient + gradient * 
(hyperParameters.getEpochTrack() / hyperParameters.getEpoch()); + localLearningRate = obj.adagradOptimizer(globalLearningRate, localLearningRate, gradient, i); + globalLearningRate = localLearningRate; + assert (previousLocalLearningRate > localLearningRate); + assert (previousGlobalLearningRate > globalLearningRate); + + previousLocalLearningRate = localLearningRate; + previousGlobalLearningRate = globalLearningRate; + + } + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/CellTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/CellTest.java new file mode 100644 index 00000000000..2be5b9bbb32 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/CellTest.java @@ -0,0 +1,150 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import static org.junit.Assert.assertEquals; + +//import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.utilities.MathUtils; + +public class CellTest { + private double wi = 1; + private double wo = 1; + private double wz = 1; + private double ri = 1; + private double ro = 1; + private double rz = 1; + private double yt = 1; + private double input = 1; + private double target = 1; + private double ctMinusOne = 1; + private double ytMinusOne = 1; + + private double dropOutProb = 0.02; // do not change this + private Cell obj = new Cell(this.input, this.target, this.wi, this.wo, this.wz, this.ri, this.ro, this.rz, this.yt); + + /** + * forwardPropagationTest. 
+ */ + // @Test + public void forwardPropagationTest() { + this.obj.setCtMinusOne(this.ctMinusOne); + this.obj.setYtMinusOne(this.ytMinusOne); + this.obj.forwardPropogation(); + + double itExpected = MathUtils.sigmoid(this.wi * this.input + this.ri * this.ytMinusOne); + double otExpected = MathUtils.sigmoid(this.wo * this.input + this.ro * this.ytMinusOne); + double ztExpected = MathUtils.tanh(this.wz * this.input + this.rz * this.ytMinusOne); + + // computation without drop out + + double ctExpectedWithoutDropout = this.ctMinusOne + itExpected * ztExpected; + double ytExpectedWithoutDropout = otExpected * MathUtils.tanh(ctExpectedWithoutDropout); + final double errorExpectedWithoutDropout = ytExpectedWithoutDropout - this.target; + + // computation with drop out + double ctExpectedWithDropout = this.ctMinusOne + itExpected * ztExpected * this.dropOutProb; + double ytExpectedWithDropout = this.ytMinusOne * (1 - this.dropOutProb) + + otExpected * MathUtils.tanh(ctExpectedWithDropout) * (this.dropOutProb); + final double errorExpectedWithDropout = ytExpectedWithDropout - this.target; + + assertEquals(itExpected, this.obj.getIt(), 0.00001); + assertEquals(otExpected, this.obj.getOt(), 0.00001); + assertEquals(ztExpected, this.obj.getZt(), 0.00001); + + boolean matchCheckcT = ctExpectedWithoutDropout == this.obj.getCt() + || ctExpectedWithDropout == this.obj.getCt(); + boolean matchCheckyT = ytExpectedWithoutDropout == this.obj.getYt() + || ytExpectedWithDropout == this.obj.getYt(); + boolean matchCheckerror = errorExpectedWithoutDropout == this.obj.getError() + || errorExpectedWithDropout == this.obj.getError(); + assert (matchCheckcT && matchCheckyT && matchCheckerror); + + } + + /** + * Backpropagation. 
 */ + // @Test + public void backwardPropogationTest() { + + this.obj.setCtMinusOne(this.ctMinusOne); + this.obj.setYtMinusOne(this.ytMinusOne); + this.obj.setDlByDc(0.0); + + this.obj.forwardPropogation(); + this.obj.backwardPropogation(); + // common calculation + double itExpected = MathUtils.sigmoid(this.wi * this.input + this.ri * this.ytMinusOne); + double otExpected = MathUtils.sigmoid(this.wo * this.input + this.ro * this.ytMinusOne); + double ztExpected = MathUtils.tanh(this.wz * this.input + this.rz * this.ytMinusOne); + + // without drop out + // Forward propagation + double ctExpectedWithoutDropout = this.ctMinusOne + itExpected * ztExpected; + double ytExpectedWithoutDropout = otExpected * MathUtils.tanh(ctExpectedWithoutDropout); + double errorExpectedWithoutDropout = ytExpectedWithoutDropout - this.target; + // backward propagation + double dlByDyWithoutDropout = errorExpectedWithoutDropout; + + double dlByDoWithoutDropout = dlByDyWithoutDropout * MathUtils.tanh(ctExpectedWithoutDropout); + double dlByDcWithoutDropout = 0; + double dlByDcWithDropout = 0; + + dlByDcWithoutDropout = dlByDyWithoutDropout * otExpected * MathUtils.tanhDerivative(ctExpectedWithoutDropout) + + dlByDcWithoutDropout; + + double dlByDiWithoutDropout = dlByDcWithoutDropout * ztExpected; + double dlByDzWithoutDropout = dlByDcWithoutDropout * itExpected; + double delIWithoutDropout = dlByDiWithoutDropout + * MathUtils.sigmoidDerivative(this.wi * this.input + this.ri * this.ytMinusOne); + double delOWithoutDropout = dlByDoWithoutDropout + * MathUtils.sigmoidDerivative(this.wo * this.input + this.ro * this.ytMinusOne); + double delZWithoutDropout = dlByDzWithoutDropout + * MathUtils.tanhDerivative(this.wz * this.input + this.rz * this.ytMinusOne); + + // computation with drop out + // Forward Propagation + + double ctExpectedWithDropout = this.ctMinusOne + itExpected * ztExpected * this.dropOutProb; + double ytExpectedWithDropout = this.ytMinusOne * (1 - this.dropOutProb) + + 
otExpected * MathUtils.tanh(ctExpectedWithDropout) * (this.dropOutProb); + double errorExpectedWithDropout = ytExpectedWithDropout - this.target; + // backward propagation + double dlByDyWithDropout = errorExpectedWithDropout; + + double dlByDoWithDropout = dlByDyWithDropout * MathUtils.tanh(ctExpectedWithDropout); + + dlByDcWithDropout = dlByDyWithDropout * otExpected * MathUtils.tanhDerivative(ctExpectedWithDropout) + + dlByDcWithDropout; + + double dlByDiWithDropout = dlByDcWithDropout * ztExpected; + double dlByDzWithDropout = dlByDcWithDropout * itExpected; + double delIWithDropout = dlByDiWithDropout + * MathUtils.sigmoidDerivative(this.wi * this.input + this.ri * this.ytMinusOne); + double delOWithDropout = dlByDoWithDropout + * MathUtils.sigmoidDerivative(this.wo * this.input + this.ro * this.ytMinusOne); + double delZWithDropout = dlByDzWithDropout + * MathUtils.tanhDerivative(this.wz * this.input + this.rz * this.ytMinusOne); + + boolean dlByDydecission = this.obj.getDlByDy() == dlByDyWithDropout + || this.obj.getDlByDy() == dlByDyWithoutDropout; + boolean dlByDodecission = this.obj.getDlByDo() == dlByDoWithDropout + || this.obj.getDlByDo() == dlByDoWithoutDropout; + + boolean dlByDcdecission = this.obj.getDlByDc() == dlByDcWithDropout + || this.obj.getDlByDc() == dlByDcWithoutDropout; + boolean dlByDidecission = this.obj.getDlByDi() == dlByDiWithDropout + || this.obj.getDlByDi() == dlByDiWithoutDropout; + boolean dlByDzdecission = this.obj.getDlByDz() == dlByDzWithDropout + || this.obj.getDlByDz() == dlByDzWithoutDropout; + + boolean delIdecission = this.obj.getDelI() == delIWithDropout || this.obj.getDelI() == delIWithoutDropout; + boolean delOdecission = this.obj.getDelO() == delOWithDropout || this.obj.getDelO() == delOWithoutDropout; + boolean delZdecission = this.obj.getDelZ() == delZWithDropout || this.obj.getDelZ() == delZWithoutDropout; + + assert (dlByDydecission && dlByDodecission && dlByDcdecission && dlByDidecission && dlByDzdecission + && 
delIdecission && delOdecission && delZdecission); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/LstmTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/LstmTest.java new file mode 100644 index 00000000000..0f9d160596d --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/util/LstmTest.java @@ -0,0 +1,38 @@ +package io.openems.edge.predictor.lstmmodel.util; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class LstmTest { + @Test + public void findGlobalMinima() { + + double[] errorList = { 1, 2, 3, 4, 5, 6, 3, 4, 0, 0.05, -5, 10 }; + int val = Lstm.findGlobalMinima(UtilityConversion.to1DArrayList(errorList)); + int expectedIndex = 8; // + assertEquals(val, expectedIndex); + + ArrayList testData1 = new ArrayList<>(Arrays.asList(2.5, -3.7, 1.8, -4.2, 5.1)); + int expectedIndex1 = 2; // + int actualIndex1 = Lstm.findGlobalMinima(testData1); + assertEquals(expectedIndex1, actualIndex1); + + ArrayList testData2 = new ArrayList<>(Arrays.asList(0.0, 0.0, 0.0, 0.0)); + int expectedIndex2 = 0; // + int actualIndex2 = Lstm.findGlobalMinima(testData2); + assertEquals(expectedIndex2, actualIndex2); + + ArrayList testData3 = new ArrayList<>(Arrays.asList(5.5, -2.2, 3.3, -4.4)); + int expectedIndex3 = 1; + int actualIndex3 = Lstm.findGlobalMinima(testData3); + assertEquals(expectedIndex3, actualIndex3); + + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/DataUtilityTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/DataUtilityTest.java new file mode 100644 index 00000000000..2e60ee54077 --- /dev/null +++ 
b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/DataUtilityTest.java @@ -0,0 +1,49 @@ +package io.openems.edge.predictor.lstmmodel.utillities; + +import static org.junit.Assert.assertEquals; + +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.utilities.DataUtility; + +public class DataUtilityTest { + + private static HyperParameters hyperParameters = new HyperParameters(); + + @Test + public void test() { + + var dateTime1 = ZonedDateTime.of(2022, 1, 1, 12, 4, 0, 0, ZoneId.systemDefault()); + var res = DataUtility.getMinute(dateTime1, hyperParameters).intValue(); + assertEquals(0, res); + + var dateTime2 = ZonedDateTime.of(2022, 1, 1, 12, 8, 0, 0, ZoneId.systemDefault()); + var res1 = DataUtility.getMinute(dateTime2, hyperParameters).intValue(); + assertEquals(5, res1); + + var dateTime3 = ZonedDateTime.of(2022, 1, 1, 12, 36, 0, 0, ZoneId.systemDefault()); + var res2 = DataUtility.getMinute(dateTime3, hyperParameters).intValue(); + + assertEquals(35, res2); + + var dateTime4 = ZonedDateTime.of(2022, 1, 1, 12, 50, 0, 0, ZoneId.systemDefault()); + var res3 = DataUtility.getMinute(dateTime4, hyperParameters).intValue(); + assertEquals(50, res3); + } + + @Test + public void testCombine() { + var testData = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + var testData1 = new ArrayList<>(Arrays.asList(100.0, 200.0, 300.0, 400.0)); + var res = DataUtility.combine(testData1, testData); + var expectedResult = new ArrayList<>(Arrays.asList(100.0, 200.0, 300.0, 400.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)); + assertEquals(expectedResult, res); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/UtilityConversionTest.java 
b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/UtilityConversionTest.java new file mode 100644 index 00000000000..82a6177f900 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/utillities/UtilityConversionTest.java @@ -0,0 +1,30 @@ +package io.openems.edge.predictor.lstmmodel.utillities; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; + +import io.openems.edge.predictor.lstmmodel.utilities.UtilityConversion; + +public class UtilityConversionTest { + + @Test + public void testGetMinIndex() { + double[] arr = { 3.5, 2.0, 5.1, 1.2, 4.8 }; + assertEquals(3, UtilityConversion.getMinIndex(arr)); + } + + @Test + public void testGetMinIndexEmptyArray() { + double[] arr = {}; + assertThrows(IllegalArgumentException.class, () -> UtilityConversion.getMinIndex(arr)); + } + + @Test + public void testGetMinIndexNullArray() { + double[] arr = null; + assertThrows(IllegalArgumentException.class, () -> UtilityConversion.getMinIndex(arr)); + } + +} diff --git a/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/validation/FindOptimumIndexTest.java b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/validation/FindOptimumIndexTest.java new file mode 100644 index 00000000000..0b14ff59263 --- /dev/null +++ b/io.openems.edge.predictor.lstmmodel/test/io/openems/edge/predictor/lstmmodel/validation/FindOptimumIndexTest.java @@ -0,0 +1,32 @@ +package io.openems.edge.predictor.lstmmodel.validation; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import io.openems.edge.predictor.lstmmodel.common.HyperParameters; +import io.openems.edge.predictor.lstmmodel.validator.ValidationSeasonalityModel; + +public class FindOptimumIndexTest { + + /** + * testFindOptimumIndex. 
+ */ + // @Test + public void testFindOptimumIndex() { + ArrayList> matrix = new ArrayList<>(// + Arrays.asList(// + new ArrayList<>(Arrays.asList(1.0, 2.0, 7.0)), // + new ArrayList<>(Arrays.asList(4.0, 5.0, 8.0)), // + new ArrayList<>(Arrays.asList(7.0, 8.0, 6.0))// + )// + ); + + List> result = ValidationSeasonalityModel.findOptimumIndex(matrix, "Test", new HyperParameters()); + + assertEquals(Arrays.asList(Arrays.asList(2, 0), Arrays.asList(2, 1), Arrays.asList(1, 2)), result); + } + +}