Skip to content

Commit

Permalink
All out on pylint; CI; requirements.txt (#56)
Browse files Browse the repository at this point in the history
* pylint passes at 100%
* Factoring out loss functions
* Adding GH Actions CI
* Adding requirements.txt and removing environment.yml
* Renaming split_label to partition_label
  • Loading branch information
matsen authored Jun 4, 2020
1 parent 3615d6e commit 64df674
Show file tree
Hide file tree
Showing 11 changed files with 208 additions and 229 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: build and test

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

jobs:
build:
if: "!contains(github.event.commits[0].message, '[skip ci]')"
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Install torchdms
run: |
pip install .
- name: Test
run: |
make test
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Lint with pylint
run: |
pylint **/*.py
- name: Check format with black
run: |
black --check torchdms
15 changes: 6 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
# torchdms

![build and test](https://github.com/matsengrp/torchdms/workflows/build%20and%20test/badge.svg)
[![Docker Repository on Quay](https://quay.io/repository/matsengrp/torchdms/status "Docker Repository on Quay")](https://quay.io/repository/matsengrp/torchdms)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


## What is this?

Pytorch - Deep Mutational Scanning (`torchdms`) is a small Python package made to train neural networks on amino-acid substitution data, predicting some chosen functional score(s).
PyTorch - Deep Mutational Scanning (`torchdms`) is a Python package made to train neural networks on amino-acid substitution data, predicting some chosen functional score(s).
We use the binary encoding of variants using [BinaryMap Object](https://jbloomlab.github.io/dms_variants/dms_variants.binarymap.html) as input to feed-forward networks.


## How do I install it?

To install the API and command-line scripts at the moment, it suggested you clone the repository, create a conda environment from `environment.yaml`, and run the tests to make sure everything is working properly.

git clone [email protected]:matsengrp/torchdms.git
conda env create -f environment.yaml
conda activate dms
pytest

Install with `pip install -e .`
cd torchdms
pip install -r requirements.txt
pip install .
make test


## CLI
Expand Down
22 changes: 0 additions & 22 deletions environment.yml

This file was deleted.

11 changes: 11 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
black
click
click-config-file
dms_variants
docformatter
flake8
matplotlib
pylint
pytest
scipy
torch==1.4.0
70 changes: 38 additions & 32 deletions torchdms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,37 @@ def __init__(
]
self.val_loss_record = sys.float_info.max

def loss_of_targets_and_prediction(
self, loss_fn, targets, predictions, per_target_loss_decay
):
"""Return loss on the valid predictions, i.e. the ones that are not
NaN."""
valid_target_indices = torch.isfinite(targets)
valid_targets = targets[valid_target_indices].to(self.device)
valid_predict = predictions[valid_target_indices].to(self.device)
return loss_fn(valid_targets, valid_predict, per_target_loss_decay)

def complete_loss(self, loss_fn, targets, predictions, loss_decays):
"""Compute our total (across targets) loss with regularization.
Here we compute loss separately for each target, before summing
the results. This allows for us to take advantage of the samples
which may contain missing information for a subset of the
targets.
"""
per_target_loss = [
self.loss_of_targets_and_prediction(
loss_fn,
targets[:, target_idx],
predictions[:, target_idx],
per_target_loss_decay,
)
for target_idx, per_target_loss_decay in zip(
range(targets.shape[1]), loss_decays
)
]
return sum(per_target_loss) + self.model.regularization_loss()

def train(
self, epoch_count, loss_fn, patience=10, min_lr=1e-5, loss_weight_span=None
):
Expand Down Expand Up @@ -92,34 +123,6 @@ def loss_decays_of_target_extrema(extremum_pairs_across_targets):
scheduler = ReduceLROnPlateau(optimizer, patience=patience, verbose=True)
self.model.to(self.device)

def loss_of_targets_and_prediction(targets, predictions, per_target_loss_decay):
"""Return loss on the valid predictions, i.e. the ones that are not
NaN."""
valid_target_indices = torch.isfinite(targets)
valid_targets = targets[valid_target_indices].to(self.device)
valid_predict = predictions[valid_target_indices].to(self.device)
return loss_fn(valid_targets, valid_predict, per_target_loss_decay)

def complete_loss(targets, predictions, loss_decays):
"""Compute our total (across targets) loss with regularization.
Here we compute loss separately for each target, before
summing the results. This allows for us to take advantage of
the samples which may contain missing information for a
subset of the targets.
"""
per_target_loss = [
loss_of_targets_and_prediction(
targets[:, target_idx],
predictions[:, target_idx],
per_target_loss_decay,
)
for target_idx, per_target_loss_decay in zip(
range(target_count), loss_decays
)
]
return sum(per_target_loss) + self.model.regularization_loss()

def step_model():
per_epoch_loss = 0.0
for _ in range(batch_count):
Expand All @@ -133,8 +136,8 @@ def step_model():
samples = batch["samples"].to(self.device)
predictions = self.model(samples)

loss = complete_loss(
batch["targets"], predictions, per_stratum_loss_decays
loss = self.complete_loss(
loss_fn, batch["targets"], predictions, per_stratum_loss_decays
)
per_batch_loss += loss.item()

Expand All @@ -153,8 +156,11 @@ def step_model():

val_samples = self.val_data.samples.to(self.device)
val_predictions = self.model(val_samples)
val_loss = complete_loss(
self.val_data.targets.to(self.device), val_predictions, val_loss_decay
val_loss = self.complete_loss(
loss_fn,
self.val_data.targets.to(self.device),
val_predictions,
val_loss_decay,
).item()
if val_loss < self.val_loss_record:
print(f"\nvalidation loss record: {val_loss}")
Expand Down
70 changes: 27 additions & 43 deletions torchdms/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def cli(ctx, dry_run):
"appended in_test column.",
)
@click.option(
"--split-by",
"--partition-by",
type=str,
required=False,
default=None,
Expand All @@ -125,7 +125,7 @@ def prep(
per_stratum_variants_for_test,
skip_stratum_if_count_is_smaller_than,
export_dataframe,
split_by,
partition_by,
):
"""Prepare data for training.
Expand All @@ -139,57 +139,41 @@ def prep(
click.echo(f"LOG: Targets: {targets}")
click.echo(f"LOG: Loading substitution data for: {in_path}")
aa_func_scores, wtseq = from_pickle_file(in_path)
click.echo(f"LOG: Successfully loaded data")
click.echo("LOG: Successfully loaded data")

total_variants = len(aa_func_scores.iloc[:, 1])
click.echo(f"LOG: There are {total_variants} total variants in this dataset")

if split_by is None and "library" in aa_func_scores.columns:
if partition_by is None and "library" in aa_func_scores.columns:
click.echo(
f"WARNING: you have a 'library' column but haven't specified a split via '--split-by'"
"WARNING: you have a 'library' column but haven't specified a partition "
"via '--partition-by'"
)

if split_by in aa_func_scores.columns:
for split_label, per_split_label_df in aa_func_scores.groupby(split_by):
click.echo(f"LOG: Partitioning data via '{split_label}'")
test_partition, val_partition, partitioned_train_data = partition(
per_split_label_df.copy(),
per_stratum_variants_for_test,
skip_stratum_if_count_is_smaller_than,
export_dataframe,
split_label,
)

prep_by_stratum_and_export(
test_partition,
val_partition,
partitioned_train_data,
wtseq,
targets,
out_prefix,
str(ctx.params),
split_label,
)

else:
test_partition, val_partition, partitioned_train_data = partition(
aa_func_scores,
def prep_by_stratum_and_export_of_partition_label_and_df(partition_label, df):
split_df = partition(
df,
per_stratum_variants_for_test,
skip_stratum_if_count_is_smaller_than,
export_dataframe,
partition_label,
)

prep_by_stratum_and_export(
test_partition,
val_partition,
partitioned_train_data,
wtseq,
targets,
out_prefix,
str(ctx.params),
None,
split_df, wtseq, targets, out_prefix, str(ctx.params), partition_label,
)

if partition_by in aa_func_scores.columns:
for partition_label, per_partition_label_df in aa_func_scores.groupby(
partition_by
):
click.echo(f"LOG: Partitioning data via '{partition_label}'")
prep_by_stratum_and_export_of_partition_label_and_df(
partition_label, per_partition_label_df.copy()
)
else:
prep_by_stratum_and_export_of_partition_label_and_df(None, aa_func_scores)

click.echo(
"LOG: Successfully finished prep and dumped BinaryMapDataset "
f"object to {out_prefix}"
Expand Down Expand Up @@ -357,7 +341,7 @@ def evaluate(ctx, model_path, data_path, out, device):
click.echo(f"LOG: loading testing data from {data_path}")
data = from_pickle_file(data_path)

click.echo(f"LOG: evaluating test data with given model")
click.echo("LOG: evaluating test data with given model")
evaluation = build_evaluation_dict(model, data.test, device)

click.echo(f"LOG: pickle dump evalution data dictionary to {out}")
Expand Down Expand Up @@ -417,10 +401,10 @@ def scatter(ctx, model_path, data_path, out, device):
click.echo(f"LOG: loading testing data from {data_path}")
data = from_pickle_file(data_path)

click.echo(f"LOG: evaluating test data with given model")
click.echo("LOG: evaluating test data with given model")
evaluation = build_evaluation_dict(model, data.test, device)

click.echo(f"LOG: plotting scatter correlation")
click.echo("LOG: plotting scatter correlation")
plot_test_correlation(evaluation, model, out)

click.echo(f"LOG: scatter plot finished and dumped to {out}")
Expand Down Expand Up @@ -448,7 +432,7 @@ def contour(ctx, model_path, start, end, nticks, out):
if not isinstance(model, VanillaGGE):
raise TypeError("Model must be a VanillaGGE")

click.echo(f"LOG: plotting contour")
click.echo("LOG: plotting contour")
latent_space_contour_plot_2d(model, out, start, end, nticks)

click.echo(f"LOG: Contour finished and dumped to {out}")
Expand All @@ -474,7 +458,7 @@ def beta(ctx, model_path, data_path, out):
f"LOG: loaded data, evaluating beta coeff for wildtype seq: {data.test.wtseq}"
)

click.echo(f"LOG: plotting beta coefficients")
click.echo("LOG: plotting beta coefficients")
beta_coefficients(model, data.test, out)

click.echo(f"LOG: Beta coefficients plotted and dumped to {out}")
Expand Down
Loading

0 comments on commit 64df674

Please sign in to comment.