Skip to content

Commit

Permalink
Save progress
Browse files Browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Dec 9, 2024
1 parent 999ba7c commit ead5f68
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 46 deletions.
56 changes: 31 additions & 25 deletions experiments/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchvision.transforms.v2.functional import to_tensor
from concept_erasure.quadratic import QuadraticFitter
from concept_erasure.leace import LeaceFitter
# from concept_erasure.alf_qleace import AlfQLeaceFitter
from concept_erasure.alf_qleace import AlfQLeaceFitter
from torch import Tensor
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm
Expand Down Expand Up @@ -101,28 +101,6 @@ def get_cifar10(normalize: bool = False):
X_test: Tensor = torch.stack(list(map(to_tensor, test_images))).to(device)
Y_test = torch.tensor(test_labels).to(device)

if normalize:
X_flat = X.reshape(X.shape[0], -1)

mean = X_flat.mean(dim=0, keepdim=True)
X_centered = X_flat - mean

cov = (X_centered.T @ X_centered) / (X_centered.shape[0] - 1)

scaling = torch.sqrt(torch.diagonal(cov))
scaling = torch.where(scaling > 0, scaling, torch.ones_like(scaling))

def normalize_data(data: Tensor) -> Tensor:
data_flat = data.reshape(data.shape[0], -1)
data_centered = data_flat - mean
data_normalized = data_centered / scaling
return data_normalized.reshape(data.shape)

X = normalize_data(X)
X_train = normalize_data(X_train)
X_val = normalize_data(X_val)
X_test = normalize_data(X_test)

return X_train, Y_train, X_val, Y_val, X_test, Y_test, k, X, Y


Expand Down Expand Up @@ -161,11 +139,35 @@ def normalize_data(data: Tensor) -> Tensor:
state_path.parent.mkdir(exist_ok=True)
state = {} if not state_path.exists() else torch.load(state_path)

def normalize(X, X_train, X_val, X_test):
    """Standardize every flattened feature to zero mean and unit variance.

    The mean and per-feature (unbiased) standard deviation are computed from
    ``X`` and applied to all four splits, so every split is scaled by the
    same statistics. Constant features (zero variance) are left centered but
    unscaled to avoid division by zero.

    NOTE(review): the statistics come from ``X``, which upstream appears to
    contain the full (non-test?) dataset rather than only the training split
    — confirm this leakage across splits is intentional.

    Args:
        X, X_train, X_val, X_test: tensors of shape ``(N, ...)`` sharing the
            same trailing feature dimensions; each is flattened to ``(N, D)``
            for scaling and reshaped back to its original shape.

    Returns:
        Tuple ``(X, X_train, X_val, X_test)`` of normalized tensors.
    """
    X_flat = X.reshape(X.shape[0], -1)
    mean = X_flat.mean(dim=0, keepdim=True)

    # Per-feature unbiased std equals sqrt(diag(cov)); computing it directly
    # avoids materializing the full D x D covariance matrix (the original
    # built a 3072 x 3072 matrix just to read its diagonal).
    scaling = X_flat.std(dim=0, keepdim=True)
    # Guard constant features against division by zero.
    scaling = torch.where(scaling > 0, scaling, torch.ones_like(scaling))

    def normalize_data(data: Tensor) -> Tensor:
        """Apply the shared mean/scaling to one split, preserving its shape."""
        data_flat = data.reshape(data.shape[0], -1)
        return ((data_flat - mean) / scaling).reshape(data.shape)

    return (
        normalize_data(X),
        normalize_data(X_train),
        normalize_data(X_val),
        normalize_data(X_test),
    )

if args.eraser != "control" and (args.eraser not in state or args.nocache):
cls = {
"leace": LeaceFitter,
"qleace": QuadraticFitter,
# "qleace2": AlfQLeaceFitter,
"qleace2": AlfQLeaceFitter,
}[args.eraser]

fitter = cls(
Expand Down Expand Up @@ -251,6 +253,10 @@ def erase(x: Tensor, y: Tensor, eraser):
else none_transform
)

# TODO(Lucia): normalize erased data — currently only the control run is supported
if args.normalize:
X, X_train, X_val, X_test = normalize(X, X_train, X_val, X_test)

base_model = model_cls(
num_classes=k,
num_features=num_features,
Expand Down Expand Up @@ -283,7 +289,7 @@ def erase(x: Tensor, y: Tensor, eraser):

results = []
for seed in range(args.num_seeds):
wandb_name = f'{args.eraser} {args.name} w={args.width} d={args.depth} s={seed} {args.net} act={args.act} lr={args.lr} b1={args.b1} n={args.normalize}'
wandb_name = f'{args.eraser} {args.name} w={args.width} d={args.depth} s={seed} {args.net} act={args.act} lr={args.lr:.3f} b1={args.b1} n={args.normalize} es={args.early_stop_epochs}'

run = (
wandb.init(
Expand Down
105 changes: 85 additions & 20 deletions experiments/polyapprox_mlp.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,105 @@
from pathlib import Path

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms.v2.functional import to_tensor
from polyapprox.ols import ols
from mdl.mlp_probe import MlpProbe
import lovely_tensors as lt

lt.monkey_patch()

def get_cifar10_mean():
nontest = CIFAR10("/home/lucia/cifar10", download=True)
def plot(ols_results, filename='polyapprox_mlp_fvu'):
    """Plot FVU (fraction of variance unexplained) across training checkpoints.

    Each key in ``ols_results`` is a checkpoint filename whose trailing
    ``-<number>.pth`` segment encodes the checkpoint index and whose first
    space-separated token names the eraser; each value carries an ``.fvu``
    attribute. One line subplot per eraser is written to ``<filename>.pdf``.

    Args:
        ols_results: mapping of checkpoint name -> OLS result with ``.fvu``.
        filename: output PDF path stem (extension is appended here).
    """
    rows = []
    for key, result in ols_results.items():
        # Strongly negative FVU indicates a degenerate fit; drop it.
        if result.fvu < -0.01:
            print(f"{key} has FVU {result.fvu}. Skipping.")
            continue

        stem = key[:-4]  # strip ".pth" — assumes a 4-char extension; confirm
        chunks = stem.split("-")
        rows.append(
            {
                "fvu": result.fvu,
                "checkpoint": int(chunks[-1]),
                "eraser": chunks[0].split(" ")[0],
            }
        )

    df = pd.DataFrame(rows, columns=["fvu", "checkpoint", "eraser"])
    df = df.sort_values(by="checkpoint")

    erasers = df.eraser.unique()
    fig = make_subplots(rows=len(erasers), cols=1)
    # Renamed loop variable: the original shadowed the `eraser` list above.
    for row_idx, eraser_name in enumerate(erasers, start=1):
        sub = df[df.eraser == eraser_name]
        fig.add_trace(
            go.Scatter(x=sub.checkpoint, y=sub.fvu, mode="lines", name=eraser_name),
            row=row_idx,
            col=1,
        )

    fig.update_layout(title="FVU over checkpoints")
    # Fix: honor the `filename` parameter instead of a hard-coded output path,
    # so plot(linear_results, filename=...) writes to a distinct file.
    fig.write_image(f"{filename}.pdf", format="pdf")

# Load each MLP checkpoint ols
out_path = Path("polyapprox_mlp.pth")
ckpts = list(Path("probe-ckpts").glob("*.pth"))
ols_results = []
ols_results = {} if not out_path.exists() else torch.load(out_path)
plot(ols_results)

for ckpt in ckpts:
if 'normalize' not in ckpt.name:
if "normalize" not in ckpt.name:
continue

if ckpt.name in ols_results:
print(f"Skipping {ckpt.name} because it already exists")
continue

probe = MlpProbe(num_features=32*32*3, num_classes=10, hidden_size=128, num_layers=1)
print(f"Processing {ckpt.name}")

probe = MlpProbe(
num_features=32 * 32 * 3, num_classes=10, hidden_size=128, num_layers=1
)
probe.load_state_dict(torch.load(ckpt))
probe.eval()

ols_results.append(ols(
probe.net[0].weight.data.double().numpy(), probe.net[0].bias.data.double().numpy(),
probe.net[2].weight.data.double().numpy(), probe.net[2].bias.data.double().numpy(),
act="relu", order="quadratic",
return_fvu=True
))
ols_results[ckpt.name] = ols(
probe.net[0].weight.data.double().numpy(),
probe.net[0].bias.data.double().numpy(),
probe.net[2].weight.data.double().numpy(),
probe.net[2].bias.data.double().numpy(),
act="relu",
order="quadratic",
return_fvu=True,
)

torch.save(ols_results, out_path)
plot(ols_results)

# def polyapprox_linear(ckpts):
# linear_results = {}
# for ckpt in ckpts:
# if "normalize" not in ckpt.name:
# continue

# if ckpt.name in ols_results:
# print(f"Skipping {ckpt.name} because it already exists")
# continue

# print(f"Processing {ckpt.name}")

torch.save(ols_results, "polyapprox_mlp.pth")

# probe = MlpProbe(
# num_features=32 * 32 * 3, num_classes=10, hidden_size=128, num_layers=1
# )
# probe.load_state_dict(torch.load(ckpt))
# probe.eval()

# ols_results[ckpt.name] = ols(
# probe.net[0].weight.data.double().numpy(),
# probe.net[0].bias.data.double().numpy(),
# probe.net[2].weight.data.double().numpy(),
# probe.net[2].bias.data.double().numpy(),
# act="relu",
# order="quadratic",
# return_fvu=True,
# )
# print(f"FVU: {ols_results[ckpt.name].fvu}")
# # exit()
# torch.save(linear_results, out_path)
# plot(linear_results, filename="polyapprox_mlp_linear")
4 changes: 3 additions & 1 deletion experiments/sweep_eraser.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,18 @@ def parse_args():
'control': 5e-4,
'leace': 5e-4,
'qleace': 5e-4,
'qleace2': 5e-4, # guessing
},
'b1': {
'control': 0.99,
'leace': 0.95,
'qleace': 0.95,
'qleace2': 0.95, # guessing
},
'mup_width': 128,
'mup_depth': 2,
'widths': [64, 128, 256, 512, 1024, 2048],
'depths': [1, 2, 3, 4, 6, 8] # Loses coherence at 16, 1 breaks probe
'depths': [1, 2, 3, 4, 6, 8] # # Loses coherence at 16, 1 breaks probe
},
# 'linear': {
# # Unused
Expand Down

0 comments on commit ead5f68

Please sign in to comment.