Skip to content

Commit

Permalink
Change int the batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
pooran-c committed Oct 7, 2024
1 parent 3eade01 commit 1883875
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -129,34 +127,17 @@ protected Prediction createNewPrediction(ChannelAddress channelAddress) {
var hyperParameters = ReadAndSaveModels.read(channelAddress.getChannelId());
var nowDate = ZonedDateTime.now();

Check warning on line 128 in io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java

View check run for this annotation

Codecov / codecov/patch

io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java#L127-L128

Added lines #L127 - L128 were not covered by tests

var seasonalityPredictionFuture = CompletableFuture.supplyAsync(() -> {
return this.predictSeasonality(channelAddress, nowDate, hyperParameters);
});

var trendPredictionFuture = CompletableFuture.supplyAsync(() -> {
return this.predictTrend(channelAddress, nowDate, hyperParameters);
});

/*
* Combine predictions only after both seasonalityPredictionFuture and
* trendPredictionFuture are complete
*/
var predicted = seasonalityPredictionFuture.thenCombine(trendPredictionFuture, (seasonality, trend) -> {
return combine(seasonality, trend);
});
var seasonalityPrediction = this.predictSeasonality(channelAddress, nowDate, hyperParameters);
var trendPrediction = this.predictTrend(channelAddress, nowDate, hyperParameters);

Check warning on line 131 in io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java

View check run for this annotation

Codecov / codecov/patch

io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java#L130-L131

Added lines #L130 - L131 were not covered by tests

var predicted = combine(trendPrediction, seasonalityPrediction);
var till = nowDate.withMinute(getMinute(nowDate, hyperParameters)).withSecond(0).withNano(0);

Check warning on line 134 in io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java

View check run for this annotation

Codecov / codecov/patch

io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java#L133-L134

Added lines #L133 - L134 were not covered by tests

try {
return Prediction.from(//
Prediction.getValueRange(this.sum, channelAddress), //
Interval.DUODCIMUS, //
till, //
DOUBLELIST_TO_INTARRAY.apply(predicted.get()));
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
return null;
}
return Prediction.from(//
Prediction.getValueRange(this.sum, channelAddress), //

Check warning on line 137 in io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java

View check run for this annotation

Codecov / codecov/patch

io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java#L136-L137

Added lines #L136 - L137 were not covered by tests
Interval.DUODCIMUS, //
till, //
DOUBLELIST_TO_INTARRAY.apply(predicted));

Check warning on line 140 in io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java

View check run for this annotation

Codecov / codecov/patch

io.openems.edge.predictor.lstmmodel/src/io/openems/edge/predictor/lstmmodel/LstmModelImpl.java#L140

Added line #L140 was not covered by tests
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public OffsetDateTime getLastTrainedDate() {
* of range errors during training.</li>
* </ul>
*/
private int batchSize = 10;
private int batchSize = 1;

/**
* Counter for tracking batches.
Expand Down

0 comments on commit 1883875

Please sign in to comment.