Implementation of linear models (OLS/LASSO/Ridge) in base PyTorch

njelicic/pytorch-linear-models

pytorch-linear-models

Implementation of linear models (OLS, LASSO and Ridge) in base PyTorch, i.e. without torch.nn. The repo aims to follow the scikit-learn API (fit/predict) so the models can be dropped into existing projects.
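As an illustration of what "base PyTorch without torch.nn" can mean, here is a minimal sketch, not the repo's code: plain tensors, autograd and a hand-written full-batch gradient-descent loop behind a scikit-learn-style fit/predict interface. The class name PenalizedLinearRegression, the parameters alpha, lr and n_iter, the history attribute and the mean-squared-error scaling are all assumptions made for this sketch.

```python
import torch


class PenalizedLinearRegression:
    """Hypothetical sketch: OLS / LASSO / Ridge fitted by full-batch gradient
    descent using only torch tensors and autograd (no torch.nn)."""

    def __init__(self, penalty=None, alpha=0.1, lr=0.05, n_iter=1000):
        if penalty not in (None, "l1", "l2"):
            raise ValueError("penalty must be None, 'l1' or 'l2'")
        self.penalty = penalty
        self.alpha = alpha      # regularization strength (hypothetical name)
        self.lr = lr            # gradient-descent step size (hypothetical)
        self.n_iter = n_iter    # number of full-batch steps (hypothetical)
        self.history = []       # penalized training loss, one value per step

    def _penalty(self, w):
        # The intercept is deliberately left unpenalized.
        if self.penalty == "l1":
            return self.alpha * w.abs().sum()
        if self.penalty == "l2":
            return self.alpha * (w ** 2).sum()
        return 0.0

    def fit(self, X, y):
        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.float32).reshape(-1)
        self.coef_ = torch.zeros(X.shape[1], requires_grad=True)
        self.intercept_ = torch.zeros(1, requires_grad=True)
        for _ in range(self.n_iter):
            pred = X @ self.coef_ + self.intercept_
            loss = ((y - pred) ** 2).mean() + self._penalty(self.coef_)
            loss.backward()  # autograd fills .grad on coef_ and intercept_
            with torch.no_grad():  # manual update, outside the autograd graph
                for p in (self.coef_, self.intercept_):
                    p -= self.lr * p.grad
                    p.grad.zero_()
            self.history.append(loss.item())
        return self

    def predict(self, X):
        X = torch.as_tensor(X, dtype=torch.float32)
        with torch.no_grad():
            return X @ self.coef_ + self.intercept_
```

One consequence of this design choice: gradient descent on the L1 term uses the subgradient that autograd returns for abs(), so LASSO coefficients shrink but rarely land exactly on zero, unlike coordinate-descent solvers such as scikit-learn's Lasso.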

Usage: Regression

from regression import LinearRegression

clf = LinearRegression(penalty=None)  # Penalty can be one of: None for OLS, 'l1' for LASSO or 'l2' for Ridge
clf.fit(X_train, y_train)             # Fit the model like any sklearn model
clf.plot_history()                    # Plot loss over time
clf.predict(X_test)                   # Make predictions on new data
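For reference, the three penalty options correspond to the standard objectives below, where X is the n×d feature matrix, y the targets, w the coefficients and b the intercept. The 1/n scaling of the squared error, the symbol λ for the regularization strength and leaving b unpenalized are assumptions; this README does not say how the repo scales its loss or what it calls the strength parameter.

```latex
\begin{aligned}
\text{OLS } (\texttt{penalty=None}):\quad & \min_{w,b}\ \frac{1}{n}\,\lVert y - Xw - b\mathbf{1}\rVert_2^2 \\
\text{LASSO } (\texttt{penalty='l1'}):\quad & \min_{w,b}\ \frac{1}{n}\,\lVert y - Xw - b\mathbf{1}\rVert_2^2 + \lambda\,\lVert w\rVert_1 \\
\text{Ridge } (\texttt{penalty='l2'}):\quad & \min_{w,b}\ \frac{1}{n}\,\lVert y - Xw - b\mathbf{1}\rVert_2^2 + \lambda\,\lVert w\rVert_2^2
\end{aligned}
```

Setting λ = 0 in either penalized objective recovers OLS.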

Usage: Classification

from classification import LogisticRegression

clf = LogisticRegression(penalty=None)  # Penalty can be one of: None for no penalty, 'l1' for an L1 (LASSO-style) penalty or 'l2' for an L2 (Ridge-style) penalty
clf.fit(X_train, y_train)               # Fit the model like any sklearn model
clf.plot_history()                      # Plot loss over time
clf.predict(X_test)                     # Make predictions on new data
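Logistic regression can be trained with the same base-PyTorch recipe; only the data term changes from squared error to the log-loss (binary cross-entropy), log(1 + e^z) - y·z for a logit z and a label y in {0, 1}. The following is a minimal hypothetical sketch for binary labels, not the repo's implementation: the README does not say whether the repo supports multiclass targets or how it names its hyperparameters, so PenalizedLogisticRegression, alpha, lr, n_iter, history and the single-column predict_proba are assumptions.

```python
import torch


class PenalizedLogisticRegression:
    """Hypothetical sketch of binary logistic regression with an optional
    L1/L2 penalty, trained by full-batch gradient descent on plain tensors
    and autograd (no torch.nn)."""

    def __init__(self, penalty=None, alpha=0.01, lr=0.1, n_iter=1000):
        if penalty not in (None, "l1", "l2"):
            raise ValueError("penalty must be None, 'l1' or 'l2'")
        self.penalty = penalty
        self.alpha = alpha      # regularization strength (hypothetical name)
        self.lr = lr            # gradient-descent step size (hypothetical)
        self.n_iter = n_iter    # number of full-batch steps (hypothetical)
        self.history = []       # penalized training loss, one value per step

    def _penalty(self, w):
        if self.penalty == "l1":
            return self.alpha * w.abs().sum()
        if self.penalty == "l2":
            return self.alpha * (w ** 2).sum()
        return 0.0

    def _logits(self, X):
        return X @ self.coef_ + self.intercept_

    def fit(self, X, y):
        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.float32).reshape(-1)  # labels in {0, 1}
        self.coef_ = torch.zeros(X.shape[1], requires_grad=True)
        self.intercept_ = torch.zeros(1, requires_grad=True)
        for _ in range(self.n_iter):
            z = self._logits(X)
            # Mean log-loss, computed stably as log(1 + e^z) - y*z.
            log_loss = (torch.logaddexp(torch.zeros_like(z), z) - y * z).mean()
            loss = log_loss + self._penalty(self.coef_)
            loss.backward()
            with torch.no_grad():
                for p in (self.coef_, self.intercept_):
                    p -= self.lr * p.grad
                    p.grad.zero_()
            self.history.append(loss.item())
        return self

    def predict_proba(self, X):
        # Returns P(y = 1) only, unlike scikit-learn's two-column predict_proba.
        X = torch.as_tensor(X, dtype=torch.float32)
        with torch.no_grad():
            return torch.sigmoid(self._logits(X))

    def predict(self, X):
        return (self.predict_proba(X) >= 0.5).long()
```

Writing the loss with torch.logaddexp rather than taking log(sigmoid(z)) directly avoids overflow and log(0) for large-magnitude logits.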

Requirements:

  • torch==1.7.0
  • numpy==1.18.5
  • seaborn==0.10.0