Skip to content

Commit

Permalink
second_attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
abhiyagupta committed Sep 19, 2024
1 parent d57aa25 commit 7bdf55b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
17 changes: 16 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
data
*.zip
*.zip
mnist/MNIST/raw/t10k-images-idx3-ubyte
.gitignore
mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
.gitignore
mnist/MNIST/raw/t10k-labels-idx1-ubyte
.gitignore
mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
.gitignore
mnist/MNIST/raw/train-images-idx3-ubyte
.gitignore
mnist/MNIST/raw/train-images-idx3-ubyte.gz
.gitignore
mnist/MNIST/raw/train-labels-idx1-ubyte
.gitignore
mnist/MNIST/raw/train-labels-idx1-ubyte.gz
23 changes: 14 additions & 9 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@
from model import Net


def test_epoch(model, data_loader):
# write code to test this epoch
def test_epoch(model, device, data_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in data_loader:
output = model(data.to(device))
test_loss += F.nll_loss(output, target.to(device), reduction='sum').item() # sum up batch loss
pred = output.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.to(device)).sum().item()
test_loss += F.nll_loss(output, target.to(device), reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.to(device).view_as(pred)).sum().item()

test_loss /= len(data_loader.dataset)
accuracy = 100.0 * correct / len(data_loader.dataset)
out = {"Test loss": test_loss, "Accuracy": accuracy}
test_loss /= len(data_loader.dataset)
accuracy = 100. * correct / len(data_loader.dataset)
out = {'Test loss': test_loss, 'Accuracy': accuracy}
print(out)
return out

Expand All @@ -45,7 +44,13 @@ def main():
)

args, unknown = parser.parse_known_args()
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
kwargs.update({'num_workers': 1, 'pin_memory': True, 'shuffle': True},)

kwargs = {
"batch_size": args.test_batch_size,
Expand All @@ -67,7 +72,7 @@ def main():
model.load_state_dict(torch.load(checkpoint_path))

# Run the test epoch and collect evaluation results
eval_results = test_epoch(model, test_loader)
eval_results = test_epoch(model, device, test_loader)

# # Save evaluation results to a JSON file
# results_path = Path(args.save_dir) / "eval_results.json"
Expand Down
22 changes: 19 additions & 3 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,30 @@ def infer(model, dataset, save_dir, num_samples=5):
results_dir.mkdir(parents=True, exist_ok=True)

indices = random.sample(range(len(dataset)), num_samples)
for idx in indices:
image, _ = dataset[idx]

for idx, i in enumerate(indices):
image, _ = dataset[i]
with torch.no_grad():
output = model(image.unsqueeze(0))
pred = output.argmax(dim=1, keepdim=True).item()

img = Image.fromarray(image.squeeze().numpy() * 255).convert("L")
img.save(results_dir / f"{pred}.png")
# Ensure unique filenames
filename = f"{pred}_{idx}.png"
img.save(results_dir / filename)

print(f"Saved {num_samples} inference result images.")

# for idx in indices:
# image, _ = dataset[idx]

# with torch.no_grad():
# output = model(image.unsqueeze(0))
# pred = output.argmax(dim=1, keepdim=True).item()

# img = Image.fromarray(image.squeeze().numpy() * 255).convert("L")
# img.save(results_dir / f"{pred}_{i}.png")
# i +=1


def main():
Expand Down
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from model import Net
import time

def train_epoch(epoch, args, model, train_loader, optimizer):
model.train()
Expand Down Expand Up @@ -48,7 +49,7 @@ def main():
parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
parser.add_argument("--save_dir", default="/opt/mount", help="checkpoint will be saved in this directory")
parser.add_argument('--epochs', type=int, default=1, metavar='N', help='number of epochs to train (default: 1)')
parser.add_argument('--save_model', action='store_true', default=False, help='save the trained model to state_dict')
parser.add_argument('--save_model', action='store_true', default=True, help='save the trained model to state_dict')

args = parser.parse_args()
torch.manual_seed(args.seed)
Expand Down Expand Up @@ -78,5 +79,8 @@ def main():
os.makedirs(os.path.join(args.save_dir, "model"), exist_ok=True)
torch.save(model.state_dict(), os.path.join(args.save_dir, "model", "mnist_cnn.pt"))

# time.sleep(10000)
#/opt/mount/model/mnist.pt

if __name__ == "__main__":
main()

0 comments on commit 7bdf55b

Please sign in to comment.