Skip to content

Commit

Permalink
Merge pull request #16 from PolicyEngine/reweight-function
Browse files Browse the repository at this point in the history
Added a reweight function and testing
  • Loading branch information
nikhilwoodruff authored Jul 23, 2024
2 parents 92bc9de + 7543011 commit ba9870c
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 59 deletions.
21 changes: 0 additions & 21 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,6 @@ jobs:
uses: "lgeiger/black-action@master"
with:
args: ". -l 79 --check"
# Pull-request job: rebuilds the changelog and runs the repository's
# version-check scripts against the PR head.
check-version:
  name: Check version
  runs-on: ubuntu-latest
  steps:
    - uses: actions/checkout@v3
      with:
        # Depth 0 fetches the full history rather than a shallow clone.
        fetch-depth: 0
        # Check out the PR's head repository/branch (works for forks too).
        repository: ${{ github.event.pull_request.head.repo.full_name }}
        ref: ${{ github.event.pull_request.head.ref }}
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: 3.9
    - name: Build changelog
      # The requirement specifier must be quoted: unquoted, the shell parses
      # ">" as an output redirection, so pip would install an unpinned
      # yaml-changelog and write its output to a file named "=0.1.7".
      run: pip install "yaml-changelog>=0.1.7" && make changelog
    - name: Make scripts executable
      run: chmod -R +x .github/
    - name: Preview changelog update
      run: ".github/get-changelog-diff.sh"
    - name: Check version number has been properly updated
      run: .github/is-version-number-acceptable.sh
Test:
runs-on: ${{ matrix.os }}
continue-on-error: true
Expand Down
31 changes: 0 additions & 31 deletions .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,6 @@ jobs:
uses: "lgeiger/black-action@master"
with:
args: ". -l 79 --check"
# Push job: rebuilds the changelog and commits the result back to the
# repository.
versioning:
name: Update versioning
# Run only in the upstream repository, and skip pushes whose head commit
# message is 'Update reweight' -- the same message this job's own commit step
# uses below, so this presumably stops the job from re-triggering itself.
if: |
(github.repository == 'PolicyEngine/reweight')
&& !(github.event.head_commit.message == 'Update reweight')
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3
with:
# NOTE(review): these read pull_request fields inside a push workflow;
# confirm they resolve as intended for push events.
repository: ${{ github.event.pull_request.head.repo.full_name }}
ref: ${{ github.event.pull_request.head.ref }}
# Token supplied from the POLICYENGINE_GITHUB repository secret.
token: ${{ secrets.POLICYENGINE_GITHUB }}
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Build changelog
run: pip install yaml-changelog && make changelog
- name: Make scripts executable
run: chmod -R +x .github/
- name: Preview changelog update
run: ".github/get-changelog-diff.sh"
- name: Update changelog
# Stages everything in the working tree and commits it as the bot user.
uses: EndBug/add-and-commit@v9
with:
add: "."
committer_name: Github Actions[bot]
author_name: Github Actions[bot]
message: Update reweight
github_token: ${{ secrets.POLICYENGINE_GITHUB }}
Test:
runs-on: ${{ matrix.os }}
if: |
Expand Down
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ reweight/tests/__pycache__
#############################
docs/_build

# Pycache folder #
# Pycache folders #
##################
reweight/__pycache__
**/__pycache__/
1 change: 1 addition & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

3 changes: 2 additions & 1 deletion reweight/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__version__ = "0.1.0"
__version__ = "0.3.0"
from .logic.reweight import reweight
1 change: 1 addition & 0 deletions reweight/logic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .reweight import reweight
72 changes: 72 additions & 0 deletions reweight/logic/reweight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pandas as pd
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


def reweight(
    initial_weights,
    estimate_matrix,
    target_names,
    target_values,
    epochs=1000,
    epoch_step=100,
):
    """
    Adjust survey weights so that weighted estimates approach the targets.

    Main reweighting function, suitable for PolicyEngine UK use
    (PolicyEngine US use and testing are still to come).

    To avoid the need for equivalisation factors, the loss is the mean of
    the squared *relative* errors: ((predicted - actual) / actual) ** 2.
    The weights are optimised in log space with Adam (default settings), so
    the returned weights are always positive.

    Parameters:
        initial_weights (torch.Tensor): The initial weights given to survey
            data, which are to be adjusted by this function. Must be
            strictly positive, because their logarithm is taken.
        estimate_matrix (torch.Tensor): A matrix of estimates, obtained from
            e.g. a PolicyEngine Microsimulation instance. It is
            left-multiplied by the weights, so its first dimension must
            match the number of weights.
        target_names (iterable): The names of a set of target statistics
            treated as ground truth. Currently not used by the
            optimisation; accepted for interface compatibility.
        target_values (torch.Tensor): The values of these target statistics.
            Must be non-zero, because errors are divided by them.
        epochs: The number of iterations that the optimisation loop should
            run for.
        epoch_step: The interval (in epochs) at which to print the loss
            during the optimisation loop. A falsy value (0 or None)
            disables printing.

    Returns:
        torch.Tensor: the reweighted set of weights, detached from the
        autograd graph.
    """
    # Optimise log-weights so exp() keeps every weight positive. Detaching
    # first guarantees a leaf tensor that the optimiser accepts, even when
    # the caller's weights already track gradients.
    log_weights = torch.log(initial_weights.detach())
    log_weights.requires_grad_()

    optimizer = torch.optim.Adam([log_weights])

    # TensorBoard writer for the per-epoch loss curve.
    writer = SummaryWriter()
    try:
        for epoch in range(epochs):
            # Estimate the targets: exp(log_weights) @ estimate_matrix
            targets_estimate = torch.exp(log_weights) @ estimate_matrix

            # Mean squared relative error against the targets
            relative_error = (
                targets_estimate - target_values
            ) / target_values
            loss = torch.mean(relative_error**2)

            writer.add_scalar("Loss/train", loss, epoch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print the loss whenever the one-indexed epoch number is
            # divisible by epoch_step; the guard avoids a modulo-by-zero
            # when printing is disabled.
            if epoch_step and (epoch + 1) % epoch_step == 0:
                print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    finally:
        # close() flushes pending events and releases the event file, even
        # if the optimisation loop raised.
        writer.close()

    return torch.exp(log_weights.detach())
28 changes: 28 additions & 0 deletions reweight/tests/test_uk_prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,31 @@ def test_uk_microsimulation():

# Create a Microsimulation instance
sim = Microsimulation()


def test_uk_prototype():
    """Smoke-test ``reweight`` end to end on PolicyEngine UK FRS inputs."""
    from policyengine_uk import Microsimulation
    from reweight import reweight
    import torch

    # Built first, as before, although the instance is not used below.
    microsim = Microsimulation()

    from policyengine_uk.data import RawFRS_2021_22

    raw_frs = RawFRS_2021_22()
    raw_frs.download()

    from policyengine_uk.data.datasets.frs.calibration.calibrate import (
        generate_model_variables,
    )

    # Six values come back; only four of them feed the reweighting call.
    model_variables = generate_model_variables("frs_2021", 2025)
    (
        starting_weights,
        _weight_adjustment,
        estimates_df,
        target_labels,
        target_array,
        _equivalisation_factors,
    ) = model_variables

    estimates = estimates_df.to_numpy()
    sim_matrix = torch.tensor(estimates, dtype=torch.float32)
    reweight(starting_weights, sim_matrix, target_labels, target_array)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="reweight",
version="0.2.0",
version="0.3.0",
author="PolicyEngine",
author_email="[email protected]",
long_description=readme,
Expand Down
52 changes: 49 additions & 3 deletions test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,54 @@
"# square error, and then average to get MSE."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from reweight.logic import reweight"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 100, Loss: 48.30685043334961\n",
"Epoch 200, Loss: 40.58155059814453\n",
"Epoch 300, Loss: 34.585235595703125\n",
"Epoch 400, Loss: 29.832853317260742\n",
"Epoch 500, Loss: 25.99891471862793\n",
"Epoch 600, Loss: 22.858182907104492\n",
"Epoch 700, Loss: 20.250896453857422\n",
"Epoch 800, Loss: 18.061073303222656\n",
"Epoch 900, Loss: 16.202829360961914\n",
"Epoch 1000, Loss: 14.611446380615234\n"
]
},
{
"data": {
"text/plain": [
"tensor([1120.1953, 89.4442, 3851.2649, ..., 730.9640, 832.0632,\n",
" 4155.2686])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sim_matrix = torch.tensor(values_df.to_numpy(), dtype=torch.float32)\n",
"\n",
"reweight.reweight(household_weights, sim_matrix, targets, targets_array)"
]
},
{
"cell_type": "code",
"execution_count": 29,
Expand All @@ -86,8 +134,6 @@
" # Initialize a TensorBoard writer\n",
" writer = SummaryWriter()\n",
"\n",
" #TODO: Write stuff here\n",
"\n",
" #Create a Torch tensor of log weights\n",
" log_weights = torch.log(household_weights)\n",
" log_weights.requires_grad_()\n",
Expand Down Expand Up @@ -119,7 +165,7 @@
" # Update weights\n",
" optimizer.step()\n",
"\n",
" # Print loss for every 1000 epochs\n",
" # Print loss for every 100 epochs\n",
" if epoch % 100 == 0:\n",
" print(f\"Epoch {epoch}, Loss: {loss.item()}\")\n",
"\n",
Expand Down

0 comments on commit ba9870c

Please sign in to comment.