Skip to content

Commit d04d780

Browse files
Add files via upload
1 parent e221515 commit d04d780

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+5772
-0
lines changed

Diff for: wandb/run-20211016_231600-fn82rqkz/files/code/0.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from torch.nn import *
2+
import torch
3+
import torch.nn as nn
4+
import torchvision
5+
import torchvision.transforms as transforms
6+
import torch.nn.functional as F
7+
import matplotlib.pyplot as plt
8+
import pytorch_lightning as pl
9+
from pytorch_lightning import Trainer
10+
from tqdm import tqdm
11+
from torch.utils.data import DataLoader,Dataset
12+
import wandb,os
13+
PROJECT_NAME = 'Intel-Image-Classification-Learning-PyTorch-Lightning'
14+
criterion = MSELoss()
15+
class Model(Module):
16+
def __init__(self):
17+
self.activation = ReLU()
18+
self.linear1 = Linear(3*5*5,256)
19+
self.linear2 = Linear(256,512)
20+
self.linear3 = Linear(512,1024)
21+
self.linear4 = Linear(1024,len(labels))
22+
23+
def training_step(self,batch,batch_idx):
24+
images,labels = batch
25+
preds = self.activation(self.linear1(images))
26+
preds = self.activation(self.linear2(preds))
27+
preds = self.activation(self.linear3(preds))
28+
preds = self.activation(self.linear4(preds))
29+
loss = criterion(preds,labels)
30+
wandb.log({'Loss':loss.item()})
31+
return {'train_loss':loss.item()}
32+
33+
def validation_step(self,batch,batch_idx):
34+
images,labels = batch
35+
preds = self.activation(self.linear1(images))
36+
preds = self.activation(self.linear2(preds))
37+
preds = self.activation(self.linear3(preds))
38+
preds = self.activation(self.linear4(preds))
39+
loss = criterion(preds,labels)
40+
wandb.log({'Val Loss':loss.item()})
41+
return {'val_loss':loss.item()}
42+
43+
def train_dataloader(self):
44+
dataset = torchvision.datasets.MNIST('./data/',download=True,)
45+
data_loader = DataLoader(dataset,batch_size=32,shuffle=True)
46+
return data_loader
47+
48+
def val_dataloader(self):
49+
dataset = torchvision.datasets.MNIST('./data/',download=True,train=False)
50+
data_loader = DataLoader(dataset,batch_size=32,shuffle=True)
51+
return data_loader
52+
53+
wandb.init(project=PROJECT_NAME,name='baseline')
54+
trainer = Trainer()
55+
trainer.fit(Model())
56+
wandb.finish()

0 commit comments

Comments
 (0)