Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor load_ and save_ methods #33

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "1f67e044e69ce7594dfe4916317f0142d499c0dc"
project_hash = "4fa1b4dfd1f1d683237a41089874685acce20b29"

[[deps.ADTypes]]
git-tree-sha1 = "99a6f5d0ce1c7c6afdb759daa30226f71c54f6b0"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
5 changes: 4 additions & 1 deletion src/ISOKANN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using SparseArrays: sparse
import Distances
import ProgressMeter
import ChainRulesCore
import Chemfiles
import Flux
import StatsBase, Zygote, Optimisers, Flux, JLD2
import LsqFit
Expand Down Expand Up @@ -74,7 +75,7 @@ export SimulationData
export addcoords
export getxs, getys
export exit_rates
export load_trajectory, save_trajectory
export load_coords, save_coords
export atom_indices
export localpdistinds, pdists, restricted_localpdistinds
export data_from_trajectory, mergedata
Expand Down Expand Up @@ -120,4 +121,6 @@ include("reactionpath.jl")
include("makie.jl")
include("bonito.jl")

include("fileio.jl")

end
130 changes: 130 additions & 0 deletions src/fileio.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@

"""
save_iso(path::String, iso::Iso)

Save the complete Iso object to a JLD2 file """
save_iso(path::String, iso::Iso) = JLD2.save(path, "iso", cpu(iso))

"""
load(path::String, iso::Iso)

Load the Iso object from a JLD2 file
Note that it will be loaded to the CPU, even if it was saved on the GPU.
"""
function load_iso(path::String)
iso = JLD2.load(path, "iso")
if CUDA.has_cuda()
return gpu(iso)
else
return iso
end
end

load_coords(filename, steps) = load_coords_chemfiles(filename, steps)
save_coords(filename, coords, topology) = save_coords_chemfiles(filename, coords, topology=topology)

@deprecate readchemfile load_coords_chemfiles
@deprecate writechemfile(filename, data::Array{<:Any,2}; source) save_coords_chemfiles(filename, coords, topology=source)
@deprecate load_trajectory(filename; top, stride, atom_indices) load_coords_mdcoords(filename, topology=top; stride, atom_indices)

### Chemfiles

import Chemfiles

function load_coords_chemfiles(topology::String, frames=:)
traj = Chemfiles.Trajectory(topology, 'r')
try
load_coords_chemfiles(traj, frames)
finally
close(traj)
end
end

function load_coords_chemfile(traj::Chemfiles.Trajectory, frames)
frame = Chemfiles.read_step(traj, 0)
xs = Array{Float32}(undef, length(Chemfiles.positions(frame)), length(frames))
read = fill(length(frames), false)
for i in frames
Chemfiles.read_step!(traj, i - 1, frame)
try
xs[:, i] .= Chemfiles.positions(frame).coords |> vec
read[i] = true
catch
end

end
xs = xs[:, read]
xs ./= 10 # convert from Angstrom to nm
return xs
end

load_coords_chemfiles(traj::Chemfiles.Trajectory, frames::Colon=:) =
load_coords_chemfiles(traj::Chemfiles.Trajectory, Base.OneTo(length(traj)))

load_coords_chemfiles(traj::Chemfiles.Trajectory, frame::Int) =
load_coords_chemfiles(traj, frame:frame) |> vec

function save_coords_chemfiles(filename, coords::Array{<:Any,2}; topology)
coords = cpu(coords)
trajectory = Chemfiles.Trajectory(topology, 'r')
try
frame = Chemfiles.read(trajectory)
trajectory = Chemfiles.Trajectory(filename, 'w', uppercase(split(filename, ".")[end]))
for i in 1:size(coords, 2)
Chemfiles.positions(frame) .= reshape(coords[:, i], 3, :) .* 10 # convert from nm to Angstrom
write(trajectory, frame)
end
finally
close(trajectory)
end
end

## MDTraj

"""
load_coords_mdcoords(filename; topology=nothing, kwargs...)

wrapper around Python's `mdtraj.load()`.
Returns a (3 * natom, nframes) shaped array.
"""
function load_coords_mdcoords(filename; topology::Union{Nothing,String}=nothing, stride=nothing, atom_indices=nothing)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")

if isnothing(topology)
if filename[end-2:end] == "pdb"
topology = filename
else
error("must supply topology file (.pdb) to the topology argument")
end
end

if !isnothing(atom_indices)
atom_indices = atom_indices .- 1
end

traj = mdtraj.load(filename; topology, stride, atom_indices)
xs = permutedims(PyArray(py"$traj.xyz"o), (3, 2, 1))
xs = reshape(xs, :, size(xs, 3))
return xs::Matrix{Float32}
end

"""
save_coords_mdcoords(filename, coords::AbstractMatrix; topology::String)

save the trajectory given in `coords` to `filename` with the topology provided by the file `topology`
"""
function save_coords_mdcoords(filename, coords::AbstractMatrix, topology::String)
coords = cpu(coords)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")
traj = mdtraj.load(topology, stride=-1)
xyz = reshape(coords, 3, :, size(coords, 2))
traj = mdtraj.Trajectory(PyReverseDims(xyz), traj.topology)
traj.save(filename)
end

function atom_indices(filename::String, selector::String)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")
traj = mdtraj.load(filename, stride=-1)
inds = traj.topology.select(selector) .+ 1
return inds::Vector{Int}
end
8 changes: 4 additions & 4 deletions src/iso.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ Save the coordinates of the specified observation indices from the data of of `i

Save the coordinates of the specified matrix of coordinates to a file, using the molecule in `iso` as a template.
"""
function savecoords(path::String, iso::Iso, inds=1:numobs(iso.data))
function save_coords(path::String, iso::Iso, inds=1:numobs(iso.data))
coords = getcoords(iso.data)[:,inds]
savecoords(path, iso, coords)
save_coords(path, coords, iso)
end

function savecoords(path::String, iso::Iso, coords::AbstractMatrix)
savecoords(path, iso.data.sim, coords)
function save_coords(path::String, coords::AbstractMatrix, iso::Iso)
OpenMM.save_coords_openmm(path, coords, iso.data.sim)
end

"""
Expand Down
51 changes: 50 additions & 1 deletion src/makie.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,55 @@ function onlychanges(a::Observable)
end
return o
end
onlychanges(x) = x

observe(x) = x isa Observable ? x : Observable(x)

function visgradients(iso::Iso, x=getcoords(iso.data))
dx = mapreduce(hcat, eachcol(x)) do c
ISOKANN.dchidx(iso, c)
end

plotmol(x, iso.data.sim.pysim, grad=dx, showatoms=false)
end



function plotmol(c, pysim, color=1; grad=nothing, kwargs...)
c = observe(c)
color = observe(color)

fig = Figure()
frameselector = SliderGrid(fig[1, 1],
(label="Frame", range=@lift(1:size($c, 2)), startvalue=1))

i = frameselector.sliders[1].value
col = @lift $c[:, $i]


ax = LScene(fig[2, 1], show_axis=false)
plotmol!(ax, col, pysim, color; kwargs...)

if !isnothing(grad)
grad = @lift($(observe(grad))[:, $i])
plotgrad!(ax, @lift(reshape($col, 3, :)), @lift(reshape($grad, 3, :)),
arrowsize=0.01, lengthscale=0.2, linecolor=:red, linewidth=0.005)
end

return fig
end

function plotgrad!(ax, c::Observable{T}, dc::Observable{T}; kwargs...) where {T<:AbstractMatrix}

x = @lift vec($c[1, :])
y = @lift vec($c[2, :])
z = @lift vec($c[3, :])
u = @lift vec($dc[1, :])
v = @lift vec($dc[2, :])
w = @lift vec($dc[3, :])

arrows!(ax, x, y, z, u, v, w; kwargs...)
end

function plotmol!(ax, c, pysim, color; showbonds=true, showatoms=true, showbackbone=true, alpha=1.0, linewidth=4)
z = zeros(3, 0)
Expand Down Expand Up @@ -54,11 +103,11 @@ function plotmol!(ax, c, pysim, color; showbonds=true, showatoms=true, showbackb

color = onlychanges(color)


meshscatter!(ax, onlychanges(a), markersize=0.1, color=@lift($color .* ones(size($a, 2))), colorrange=(0.0, 1.0), colormap=:roma,)
lines!(ax, onlychanges(p); linewidth, color=color, colorrange=(0.0, 1.0), colormap=:roma,)
linesegments!(ax, onlychanges(b); color=color, colorrange=(0.0, 1.0), colormap=:roma, alpha)


ax
end

Expand Down
89 changes: 0 additions & 89 deletions src/molutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,95 +109,6 @@ function split_first_dimension(A, d)
reshape(A, (d, div(s1, d), s2...))
end

"""
load_trajectory(filename; top=nothing, kwargs...)

wrapper around Python's `mdtraj.load()`.
Returns a (3 * natom, nframes) shaped array.
"""
function load_trajectory(filename; top::Union{Nothing,String}=nothing, stride=nothing, atom_indices=nothing)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")

if isnothing(top)
if filename[end-2:end] == "pdb"
top = filename
else
error("must supply topology file (.pdb) to the top argument")
end
end

if !isnothing(atom_indices)
atom_indices = atom_indices .- 1
end

traj = mdtraj.load(filename; top, stride, atom_indices)
xs = permutedims(PyArray(py"$traj.xyz"o), (3, 2, 1))
xs = reshape(xs, :, size(xs, 3))
return xs::Matrix{Float32}
end

"""
save_trajectory(filename, coords::AbstractMatrix; top::String)

save the trajectory given in `coords` to `filename` with the topology provided by the file `top`
"""
function save_trajectory(filename, coords::AbstractMatrix; top::String)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")
traj = mdtraj.load(top, stride=-1)
xyz = reshape(coords, 3, :, size(coords, 2))
traj = mdtraj.Trajectory(PyReverseDims(xyz), traj.topology)
traj.save(filename)
end

function atom_indices(filename::String, selector::String)
mdtraj = pyimport_conda("mdtraj", "mdtraj", "conda-forge")
traj = mdtraj.load(filename, stride=-1)
inds = traj.top.select(selector) .+ 1
return inds::Vector{Int}
end

import Chemfiles

function readchemfile(source::String, steps=:)
traj = Chemfiles.Trajectory(source, 'r')
try
readchemfile(traj, steps)
finally
close(traj)
end
end

function readchemfile(traj::Chemfiles.Trajectory, frames)
frame = Chemfiles.read_step(traj, 0)
xs = Array{Float32}(undef, length(Chemfiles.positions(frame)), length(frames))
for (i, s) in enumerate(frames)
Chemfiles.read_step!(traj, s - 1, frame)
xs[:, i] .= Chemfiles.positions(frame).data |> vec
end
xs ./= 10 # convert from Angstrom to nm
return xs
end

readchemfile(traj::Chemfiles.Trajectory, frames::Colon=:) =
readchemfile(traj::Chemfiles.Trajectory, Base.OneTo(length(traj)))

readchemfile(traj::Chemfiles.Trajectory, frame::Int) =
readchemfile(traj, frame:frame) |> vec

function writechemfile(filename, data::Array{<:Any,2}; source)
trajectory = Chemfiles.Trajectory(source, 'r')
try
frame = Chemfiles.read(trajectory)
trajectory = Chemfiles.Trajectory(filename, 'w', uppercase(split(filename, ".")[end]))
for i in 1:size(data, 2)
Chemfiles.positions(frame) .= reshape(data[:, i], 3, :) .* 10 # convert from nm to Angstrom
write(trajectory, frame)
end
finally
close(trajectory)
end
end

mutable struct LazyTrajectory <: AbstractMatrix{Float32}
path::String
traj::Chemfiles.Trajectory
Expand Down
3 changes: 2 additions & 1 deletion src/simulators/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,9 @@ function potential(sim::OpenMMSimulation, x)
v = v.value_in_unit(v.unit)
end

@deprecate savecoords(path, sim::OpenMMSimulation, coords) save_traj_openmm(path, coords, sim)

function savecoords(path, sim::OpenMMSimulation, coords::AbstractArray{T}) where {T}
function save_coords_openmm(path, coords::AbstractArray{T}, sim::OpenMMSimulation) where {T}
coords = ISOKANN.cpu(coords)
s = sim.pysim
p = py"pdbfile.PDBFile"
Expand Down