Analyzing How a Model Improves

In this tutorial we'll use AlexNet as the example model. We can load it from torchvision.

By:

  • Xiaochen Zhang
  • Lai Wei
from torchvision.models.alexnet import AlexNet
import torch

Our example model

model = AlexNet()

Sample data

Create a sample batch that stands in for 2 normalized images of size 224 x 224 with 3 channels. The values are drawn uniformly from [-2, 0).

samp = (torch.rand(2,3,224,224)-1)*2
samp.mean(), samp.std()
(tensor(-1.0012), tensor(0.5775))
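
The printed mean and standard deviation follow from the formula: `torch.rand` draws from Uniform[0, 1), so `(u - 1) * 2` is uniform on [-2, 0), whose mean is -1 and whose standard deviation is 2/sqrt(12), roughly 0.5774. A quick check in plain Python (standard library only; the sample size and seed are arbitrary choices for this sketch):

```python
import math
import random
import statistics

# u ~ Uniform[0, 1)  =>  (u - 1) * 2 ~ Uniform[-2, 0)
low, high = -2.0, 0.0
expected_mean = (low + high) / 2
expected_std = (high - low) / math.sqrt(12)

# Empirical check with a seeded generator
rng = random.Random(0)
values = [(rng.random() - 1) * 2 for _ in range(100_000)]

print(expected_mean, round(expected_std, 4))
print(statistics.mean(values), statistics.stdev(values))
```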

Torch Ember

The essence of Torch Ember is to place trackers within modules.

It decorates each module's forward function to capture the following:

  • Which variables go into and come out of the module
  • The order in which modules are called, including the containment relationships between sub-modules
  • The statistics we want for further analysis, e.g.
    • Min, max, mean, std and zero coverage of the input / output tensors
    • Min, max, mean, std and zero coverage of the model weights at this iteration
    • Min, max, mean, std and zero coverage of the model weight gradients at this iteration
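
The idea of decorating a forward function can be sketched without torch. The following is a minimal illustration, not torchember's implementation: `Scale`, `summarize` and `place_tracker` are hypothetical names, and plain lists stand in for tensors.

```python
import functools

class Scale:
    """Stand-in for a module: multiplies every input value by a factor."""
    def __init__(self, factor):
        self.factor = factor

    def forward(self, xs):
        return [x * self.factor for x in xs]

def summarize(values):
    # A small subset of the statistics listed above
    return {"min": min(values), "max": max(values), "mean": sum(values) / len(values)}

def place_tracker(module, name, records):
    """Wrap module.forward so each call appends input/output stats to records."""
    original = module.forward

    @functools.wraps(original)
    def tracked(xs):
        records.append({"module": name, "ttype": "input", **summarize(xs)})
        out = original(xs)
        records.append({"module": name, "ttype": "output", **summarize(out)})
        return out

    module.forward = tracked  # the instance attribute shadows the class method
    return module

records = []
layer = place_tracker(Scale(2.0), "scale", records)
result = layer.forward([1.0, 2.0, 3.0])
for row in records:
    print(row)
```

The wrapper leaves the module's behaviour unchanged; it only observes what goes in and what comes out.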

class moduleTrack[source]

moduleTrack(module, name=None, root_module=False)

get_stats[source]

get_stats(tensor)

The default statistic method. It captures the shape, mean, std, max and min of the tensor and returns them as a dictionary.
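
To make the shape of such a dictionary concrete, here is a plain-Python sketch that works on a flattened list of values. The function name `get_stats_sketch` is hypothetical; the keys `cnt_zero` and `zero_pct` are borrowed from the column names of the stats table shown later on this page, and the real `get_stats` may compute things differently.

```python
import statistics

def get_stats_sketch(values, shape):
    """Hypothetical plain-Python analogue of a default statistic function.

    `values` is the flattened content of a tensor, `shape` its dimensions.
    """
    n_zero = sum(1 for v in values if v == 0)
    return {
        "shape": list(shape),
        "mean": statistics.fmean(values),
        "std": statistics.stdev(values),  # sample std (n - 1), like torch.std by default
        "max": max(values),
        "min": min(values),
        "cnt_zero": n_zero,
        "zero_pct": n_zero / len(values),
    }

stats = get_stats_sketch([0.0, 1.0, 0.0, 3.0], shape=(2, 2))
print(stats)
```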

class torchEmber[source]

torchEmber(model, verbose=False)

Track a model !!

Start tracking a model

te = torchEmber(model,verbose=True)
start analyzing model
[ARMING][START]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).1(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).4(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).7(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).9(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).10(Conv2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).11(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).features(Sequential).12(MaxPool2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).avgpool(AdaptiveAvgPool2d)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).0(Dropout)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).1(Linear)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).2(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).3(Dropout)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).4(Linear)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).5(ReLU)]2020-03-22 20:00:26
[BUILD FORWARD][model(AlexNet).classifier(Sequential).6(Linear)]2020-03-22 20:00:26
[ARMING][SUCCESS]2020-03-22 20:00:26
[INFO][20200322_200026]Creating meta data

Remove the trackers we placed

model = model.disarm()
[DISARM][model(AlexNet)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).10(Conv2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).11(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).features(Sequential).12(MaxPool2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).avgpool(AdaptiveAvgPool2d)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).0(Dropout)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).1(Linear)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).2(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).3(Dropout)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).4(Linear)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).5(ReLU)]2020-03-21 14:53:38
[DISARM][model(AlexNet).classifier(Sequential).6(Linear)]2020-03-21 14:53:38
[DISARM][DONE]2020-03-21 14:53:38

Or like this

te.disarm()
[DISARM][model(AlexNet)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).10(Conv2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).11(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).features(Sequential).12(MaxPool2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).avgpool(AdaptiveAvgPool2d)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).0(Dropout)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).1(Linear)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).2(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).3(Dropout)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).4(Linear)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).5(ReLU)]2020-03-21 14:53:41
[DISARM][model(AlexNet).classifier(Sequential).6(Linear)]2020-03-21 14:53:41
[DISARM][DONE]2020-03-21 14:53:41

Okay, now refresh the trackers (disarm, then arm again)

te.rearm()
[DISARM][model(AlexNet)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).10(Conv2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).11(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).features(Sequential).12(MaxPool2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).avgpool(AdaptiveAvgPool2d)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).0(Dropout)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).1(Linear)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).2(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).3(Dropout)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).4(Linear)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).5(ReLU)]2020-03-21 14:53:42
[DISARM][model(AlexNet).classifier(Sequential).6(Linear)]2020-03-21 14:53:42
[DISARM][DONE]2020-03-21 14:53:42
[ARMING][START]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).1(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).4(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).7(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).9(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).10(Conv2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).11(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).features(Sequential).12(MaxPool2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).avgpool(AdaptiveAvgPool2d)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).0(Dropout)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).1(Linear)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).2(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).3(Dropout)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).4(Linear)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).5(ReLU)]2020-03-21 14:53:42
[BUILD FORWARD][model(AlexNet).classifier(Sequential).6(Linear)]2020-03-21 14:53:42
[ARMING][SUCCESS]2020-03-21 14:53:42
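
The log above shows `rearm()` printing the DISARM lines first and the ARMING lines second. A minimal sketch of that arm / disarm / rearm cycle, assuming the tracker works by shadowing `forward` on the instance (`Double` and `TrackerSketch` are hypothetical stand-ins, not torchember classes):

```python
class Double:
    """Stand-in for a module."""
    def forward(self, x):
        return 2 * x

class TrackerSketch:
    """Hypothetical tracker that can be armed, disarmed and rearmed."""
    def __init__(self, module):
        self.module = module
        self.calls = 0
        self.arm()

    def arm(self):
        original = self.module.forward  # bound class method

        def tracked(x):
            self.calls += 1
            return original(x)

        self.module.forward = tracked  # shadow it on the instance

    def disarm(self):
        # Removing the instance attribute exposes the class's forward again
        self.module.__dict__.pop("forward", None)

    def rearm(self):
        # Disarm first so the wrapper is never stacked on top of itself
        self.disarm()
        self.arm()

m = Double()
t = TrackerSketch(m)
m.forward(1)  # counted
t.disarm()
m.forward(1)  # not counted
t.rearm()
m.forward(1)  # counted again
print(t.calls)
```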

Run the forward pass for a few iterations (2 epochs of 3 training batches, then 2 epochs of 2 validation batches); nothing strange happens

# change log file location to init-00_phase-train_epoch-{epoch}.log
te.mark(phase="train") 
for epoch in range(2):
    # change log file location to init-00_phase-train_epoch-{epoch}.log
    te.mark(epoch=epoch) 
    for batch in range(3):
        te.add_extra(n_batch=batch)
        # forward
        model(samp)
        # track model's weight
        te.record_weight()
        
        
# change log file location to init-00_phase-valid_epoch-{epoch}.log
te.mark(phase="valid") 
for epoch in range(2):
    # change log file location to init-00_phase-valid_epoch-{epoch}.log
    te.mark(epoch=epoch) 
    for batch in range(2):
        te.add_extra(n_batch=batch)
        model(samp)
te.after_train()
!ls -l ~/.torchember/log/
!ls -l ~/.torchember/log/AlexNet_20200320_225811_latest
-rw-r--r--  1 salvor  staff  14896 Mar 20 22:58 /Users/salvor/.torchember/log/AlexNet_20200320_225811_latest
!ls -l ~/.torchember/log/AlexNet_20200321_144938
total 328
-rw-r--r--  1 salvor  staff  52419 Mar 21 14:49 init-00_phase-train_epoch-0.log
-rw-r--r--  1 salvor  staff  52418 Mar 21 14:49 init-00_phase-train_epoch-1.log
-rw-r--r--  1 salvor  staff  14813 Mar 21 14:49 init-00_phase-valid_epoch-0.log
-rw-r--r--  1 salvor  staff  43349 Mar 21 14:51 init-00_phase-valid_epoch-1.log
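
The file names listed above read like key-value marks joined by underscores. One way such names could be composed is sketched below; this is a guess at the scheme for illustration only, and `log_name` is a hypothetical helper rather than part of torchember.

```python
def log_name(marks):
    """Join key-value marks into a file name like the ones listed above."""
    return "_".join(f"{key}-{value}" for key, value in marks.items()) + ".log"

marks = {"init": "00", "phase": "train"}
for epoch in range(2):
    marks["epoch"] = epoch  # a later mark overwrites the earlier value
    print(log_name(marks))

marks["phase"] = "valid"
print(log_name(marks))
```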

Track weight/grad separately instead of using log_model()

Track the weights of the model (at this point in time)

te.record_weight()

Track gradients for the model

model(samp).mean().backward()
te.record_grad()

For a conv layer:

  • grad_0 is the gradient of the 1st parameter tensor (weight),
  • grad_1 is the gradient of the 2nd (bias)

Module tree json

This file will be stored at $HOME/.torchember/data/structure_<modelname>_<date>_<time>.json

te.mod_tree()
{'name': 'model(AlexNet)',
 'short': 'model(AlexNet)',
 'children': [{'name': 'model(AlexNet).features(Sequential)',
   'short': 'features(Sequential)',
   'children': [{'name': 'model(AlexNet).features(Sequential).0(Conv2d)',
     'short': '0(Conv2d)'},
    {'name': 'model(AlexNet).features(Sequential).1(ReLU)',
     'short': '1(ReLU)'},
    {'name': 'model(AlexNet).features(Sequential).2(MaxPool2d)',
     'short': '2(MaxPool2d)'},
    {'name': 'model(AlexNet).features(Sequential).3(Conv2d)',
     'short': '3(Conv2d)'},
    {'name': 'model(AlexNet).features(Sequential).4(ReLU)',
     'short': '4(ReLU)'},
    {'name': 'model(AlexNet).features(Sequential).5(MaxPool2d)',
     'short': '5(MaxPool2d)'},
    {'name': 'model(AlexNet).features(Sequential).6(Conv2d)',
     'short': '6(Conv2d)'},
    {'name': 'model(AlexNet).features(Sequential).7(ReLU)',
     'short': '7(ReLU)'},
    {'name': 'model(AlexNet).features(Sequential).8(Conv2d)',
     'short': '8(Conv2d)'},
    {'name': 'model(AlexNet).features(Sequential).9(ReLU)',
     'short': '9(ReLU)'},
    {'name': 'model(AlexNet).features(Sequential).10(Conv2d)',
     'short': '10(Conv2d)'},
    {'name': 'model(AlexNet).features(Sequential).11(ReLU)',
     'short': '11(ReLU)'},
    {'name': 'model(AlexNet).features(Sequential).12(MaxPool2d)',
     'short': '12(MaxPool2d)'}]},
  {'name': 'model(AlexNet).avgpool(AdaptiveAvgPool2d)',
   'short': 'avgpool(AdaptiveAvgPool2d)'},
  {'name': 'model(AlexNet).classifier(Sequential)',
   'short': 'classifier(Sequential)',
   'children': [{'name': 'model(AlexNet).classifier(Sequential).0(Dropout)',
     'short': '0(Dropout)'},
    {'name': 'model(AlexNet).classifier(Sequential).1(Linear)',
     'short': '1(Linear)'},
    {'name': 'model(AlexNet).classifier(Sequential).2(ReLU)',
     'short': '2(ReLU)'},
    {'name': 'model(AlexNet).classifier(Sequential).3(Dropout)',
     'short': '3(Dropout)'},
    {'name': 'model(AlexNet).classifier(Sequential).4(Linear)',
     'short': '4(Linear)'},
    {'name': 'model(AlexNet).classifier(Sequential).5(ReLU)',
     'short': '5(ReLU)'},
    {'name': 'model(AlexNet).classifier(Sequential).6(Linear)',
     'short': '6(Linear)'}]}]}
te.mt_log
['enter model(AlexNet)',
 'enter model(AlexNet).features(Sequential)',
 'enter model(AlexNet).features(Sequential).0(Conv2d)',
 'exit model(AlexNet).features(Sequential).0(Conv2d)',
 'enter model(AlexNet).features(Sequential).1(ReLU)',
 'exit model(AlexNet).features(Sequential).1(ReLU)',
 'enter model(AlexNet).features(Sequential).2(MaxPool2d)',
 'exit model(AlexNet).features(Sequential).2(MaxPool2d)',
 'enter model(AlexNet).features(Sequential).3(Conv2d)',
 'exit model(AlexNet).features(Sequential).3(Conv2d)',
 'enter model(AlexNet).features(Sequential).4(ReLU)',
 'exit model(AlexNet).features(Sequential).4(ReLU)',
 'enter model(AlexNet).features(Sequential).5(MaxPool2d)',
 'exit model(AlexNet).features(Sequential).5(MaxPool2d)',
 'enter model(AlexNet).features(Sequential).6(Conv2d)',
 'exit model(AlexNet).features(Sequential).6(Conv2d)',
 'enter model(AlexNet).features(Sequential).7(ReLU)',
 'exit model(AlexNet).features(Sequential).7(ReLU)',
 'enter model(AlexNet).features(Sequential).8(Conv2d)',
 'exit model(AlexNet).features(Sequential).8(Conv2d)',
 'enter model(AlexNet).features(Sequential).9(ReLU)',
 'exit model(AlexNet).features(Sequential).9(ReLU)',
 'enter model(AlexNet).features(Sequential).10(Conv2d)',
 'exit model(AlexNet).features(Sequential).10(Conv2d)',
 'enter model(AlexNet).features(Sequential).11(ReLU)',
 'exit model(AlexNet).features(Sequential).11(ReLU)',
 'enter model(AlexNet).features(Sequential).12(MaxPool2d)',
 'exit model(AlexNet).features(Sequential).12(MaxPool2d)',
 'exit model(AlexNet).features(Sequential)',
 'enter model(AlexNet).avgpool(AdaptiveAvgPool2d)',
 'exit model(AlexNet).avgpool(AdaptiveAvgPool2d)',
 'enter model(AlexNet).classifier(Sequential)',
 'enter model(AlexNet).classifier(Sequential).0(Dropout)',
 'exit model(AlexNet).classifier(Sequential).0(Dropout)',
 'enter model(AlexNet).classifier(Sequential).1(Linear)',
 'exit model(AlexNet).classifier(Sequential).1(Linear)',
 'enter model(AlexNet).classifier(Sequential).2(ReLU)',
 'exit model(AlexNet).classifier(Sequential).2(ReLU)',
 'enter model(AlexNet).classifier(Sequential).3(Dropout)',
 'exit model(AlexNet).classifier(Sequential).3(Dropout)',
 'enter model(AlexNet).classifier(Sequential).4(Linear)',
 'exit model(AlexNet).classifier(Sequential).4(Linear)',
 'enter model(AlexNet).classifier(Sequential).5(ReLU)',
 'exit model(AlexNet).classifier(Sequential).5(ReLU)',
 'enter model(AlexNet).classifier(Sequential).6(Linear)',
 'exit model(AlexNet).classifier(Sequential).6(Linear)',
 'exit model(AlexNet).classifier(Sequential)',
 'exit model(AlexNet)']
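
The enter/exit log and the module tree carry the same nesting information. As an illustration of that relationship (not necessarily how torchember builds `mod_tree`), a stack is enough to fold such a log back into a nested dictionary. `tree_from_log` is a hypothetical helper, and the log below is a shortened version of the one above.

```python
def tree_from_log(log):
    """Rebuild a nested {'name', 'children'} dict from 'enter X' / 'exit X' lines."""
    root = None
    stack = []
    for line in log:
        action, name = line.split(" ", 1)
        if action == "enter":
            node = {"name": name}
            if stack:
                # The module on top of the stack is the parent
                stack[-1].setdefault("children", []).append(node)
            else:
                root = node
            stack.append(node)
        elif action == "exit":
            finished = stack.pop()
            assert finished["name"] == name, "unbalanced enter/exit"
    return root

log = [
    "enter model(AlexNet)",
    "enter model(AlexNet).features(Sequential)",
    "enter model(AlexNet).features(Sequential).0(Conv2d)",
    "exit model(AlexNet).features(Sequential).0(Conv2d)",
    "exit model(AlexNet).features(Sequential)",
    "enter model(AlexNet).avgpool(AdaptiveAvgPool2d)",
    "exit model(AlexNet).avgpool(AdaptiveAvgPool2d)",
    "exit model(AlexNet)",
]
tree = tree_from_log(log)
print(tree)
```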

Check latest tensor stats

te.t.latest_df
shape mean std max min cnt_zero zero_pct module ts ttype tname n_batch
0 [2, 3, 224, 224] -1.001179 0.577458 -0.000004 -1.999994 0 0.000000 model(AlexNet) 2020-03-22 20:00:38 input_dt x 1
1 [2, 3, 224, 224] -1.001179 0.577458 -0.000004 -1.999994 0 0.000000 model(AlexNet).features(Sequential) 2020-03-22 20:00:38 input_dt input 1
2 [2, 3, 224, 224] -1.001179 0.577458 -0.000004 -1.999994 0 0.000000 model(AlexNet).features(Sequential).0(Conv2d) 2020-03-22 20:00:38 input_dt input 1
3 [2, 64, 55, 55] -0.030966 0.705571 2.591979 -2.716868 0 0.000000 model(AlexNet).features(Sequential).0(Conv2d) 2020-03-22 20:00:38 output_dt output 1
4 [2, 64, 55, 55] -0.030966 0.705571 2.591979 -2.716868 0 0.000000 model(AlexNet).features(Sequential).1(ReLU) 2020-03-22 20:00:38 input_dt input 1
5 [2, 64, 55, 55] 0.273783 0.391649 2.591979 0.000000 196719 0.508055 model(AlexNet).features(Sequential).1(ReLU) 2020-03-22 20:00:38 output_dt output 1
6 [2, 64, 55, 55] 0.273783 0.391649 2.591979 0.000000 196719 0.508055 model(AlexNet).features(Sequential).2(MaxPool2d) 2020-03-22 20:00:38 input_dt input 1
7 [2, 64, 27, 27] 0.562680 0.524378 2.591979 0.000000 24868 0.266504 model(AlexNet).features(Sequential).2(MaxPool2d) 2020-03-22 20:00:38 output_dt output 1
8 [2, 64, 27, 27] 0.562680 0.524378 2.591979 0.000000 24868 0.266504 model(AlexNet).features(Sequential).3(Conv2d) 2020-03-22 20:00:38 input_dt input 1
9 [2, 192, 27, 27] 0.006303 0.424667 1.430273 -1.550920 0 0.000000 model(AlexNet).features(Sequential).3(Conv2d) 2020-03-22 20:00:38 output_dt output 1
10 [2, 192, 27, 27] 0.006303 0.424667 1.430273 -1.550920 0 0.000000 model(AlexNet).features(Sequential).4(ReLU) 2020-03-22 20:00:38 input_dt input 1
11 [2, 192, 27, 27] 0.179221 0.242090 1.430273 0.000000 137386 0.490776 model(AlexNet).features(Sequential).4(ReLU) 2020-03-22 20:00:38 output_dt output 1
12 [2, 192, 27, 27] 0.179221 0.242090 1.430273 0.000000 137386 0.490776 model(AlexNet).features(Sequential).5(MaxPool2d) 2020-03-22 20:00:38 input_dt input 1
13 [2, 192, 13, 13] 0.269740 0.293569 1.430273 0.000000 24218 0.373182 model(AlexNet).features(Sequential).5(MaxPool2d) 2020-03-22 20:00:38 output_dt output 1
14 [2, 192, 13, 13] 0.269740 0.293569 1.430273 0.000000 24218 0.373182 model(AlexNet).features(Sequential).6(Conv2d) 2020-03-22 20:00:38 input_dt input 1
15 [2, 384, 13, 13] -0.002234 0.210906 0.704697 -0.743322 0 0.000000 model(AlexNet).features(Sequential).6(Conv2d) 2020-03-22 20:00:38 output_dt output 1
16 [2, 384, 13, 13] -0.002234 0.210906 0.704697 -0.743322 0 0.000000 model(AlexNet).features(Sequential).7(ReLU) 2020-03-22 20:00:38 input_dt input 1
17 [2, 384, 13, 13] 0.082723 0.123925 0.704697 0.000000 66994 0.516164 model(AlexNet).features(Sequential).7(ReLU) 2020-03-22 20:00:38 output_dt output 1
18 [2, 384, 13, 13] 0.082723 0.123925 0.704697 0.000000 66994 0.516164 model(AlexNet).features(Sequential).8(Conv2d) 2020-03-22 20:00:38 input_dt input 1
19 [2, 256, 13, 13] -0.001783 0.082520 0.273276 -0.314945 0 0.000000 model(AlexNet).features(Sequential).8(Conv2d) 2020-03-22 20:00:38 output_dt output 1
20 [2, 256, 13, 13] -0.001783 0.082520 0.273276 -0.314945 0 0.000000 model(AlexNet).features(Sequential).9(ReLU) 2020-03-22 20:00:38 input_dt input 1
21 [2, 256, 13, 13] 0.032140 0.046769 0.273276 0.000000 43549 0.503294 model(AlexNet).features(Sequential).9(ReLU) 2020-03-22 20:00:38 output_dt output 1
22 [2, 256, 13, 13] 0.032140 0.046769 0.273276 0.000000 43549 0.503294 model(AlexNet).features(Sequential).10(Conv2d) 2020-03-22 20:00:38 input_dt input 1
23 [2, 256, 13, 13] 0.001767 0.034403 0.192531 -0.116374 0 0.000000 model(AlexNet).features(Sequential).10(Conv2d) 2020-03-22 20:00:38 output_dt output 1
24 [2, 256, 13, 13] 0.001767 0.034403 0.192531 -0.116374 0 0.000000 model(AlexNet).features(Sequential).11(ReLU) 2020-03-22 20:00:38 input_dt input 1
25 [2, 256, 13, 13] 0.014330 0.020668 0.192531 0.000000 40435 0.467305 model(AlexNet).features(Sequential).11(ReLU) 2020-03-22 20:00:38 output_dt output 1
26 [2, 256, 13, 13] 0.014330 0.020668 0.192531 0.000000 40435 0.467305 model(AlexNet).features(Sequential).12(MaxPool2d) 2020-03-22 20:00:38 input_dt input 1
27 [2, 256, 6, 6] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).features(Sequential).12(MaxPool2d) 2020-03-22 20:00:38 output_dt output 1
28 [2, 256, 6, 6] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).features(Sequential) 2020-03-22 20:00:38 output_dt output 1
29 [2, 256, 6, 6] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).avgpool(AdaptiveAvgPool2d) 2020-03-22 20:00:38 input_dt input 1
30 [2, 256, 6, 6] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).avgpool(AdaptiveAvgPool2d) 2020-03-22 20:00:38 output_dt output 1
31 [2, 9216] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).classifier(Sequential) 2020-03-22 20:00:38 input_dt input 1
32 [2, 9216] 0.021939 0.024657 0.192531 0.000000 5918 0.321072 model(AlexNet).classifier(Sequential).0(Dropout) 2020-03-22 20:00:38 input_dt input 1
33 [2, 9216] 0.022088 0.041603 0.385062 0.000000 12155 0.659451 model(AlexNet).classifier(Sequential).0(Dropout) 2020-03-22 20:00:38 output_dt output 1
34 [2, 9216] 0.022088 0.041603 0.385062 0.000000 12155 0.659451 model(AlexNet).classifier(Sequential).1(Linear) 2020-03-22 20:00:38 input_dt input 1
35 [2, 4096] 0.000140 0.027931 0.099764 -0.136028 0 0.000000 model(AlexNet).classifier(Sequential).1(Linear) 2020-03-22 20:00:38 output_dt output 1
36 [2, 4096] 0.000140 0.027931 0.099764 -0.136028 0 0.000000 model(AlexNet).classifier(Sequential).2(ReLU) 2020-03-22 20:00:38 input_dt input 1
37 [2, 4096] 0.011170 0.016329 0.099764 0.000000 4077 0.497681 model(AlexNet).classifier(Sequential).2(ReLU) 2020-03-22 20:00:38 output_dt output 1
38 [2, 4096] 0.011170 0.016329 0.099764 0.000000 4077 0.497681 model(AlexNet).classifier(Sequential).3(Dropout) 2020-03-22 20:00:38 input_dt input 1
39 [2, 4096] 0.011146 0.025565 0.197697 0.000000 6121 0.747192 model(AlexNet).classifier(Sequential).3(Dropout) 2020-03-22 20:00:38 output_dt output 1
40 [2, 4096] 0.011146 0.025565 0.197697 0.000000 6121 0.747192 model(AlexNet).classifier(Sequential).4(Linear) 2020-03-22 20:00:38 input_dt input 1
41 [2, 4096] 0.000314 0.018151 0.066405 -0.062668 0 0.000000 model(AlexNet).classifier(Sequential).4(Linear) 2020-03-22 20:00:38 output_dt output 1
42 [2, 4096] 0.000314 0.018151 0.066405 -0.062668 0 0.000000 model(AlexNet).classifier(Sequential).5(ReLU) 2020-03-22 20:00:38 input_dt input 1
43 [2, 4096] 0.007441 0.010761 0.066405 0.000000 4077 0.497681 model(AlexNet).classifier(Sequential).5(ReLU) 2020-03-22 20:00:38 output_dt output 1
44 [2, 4096] 0.007441 0.010761 0.066405 0.000000 4077 0.497681 model(AlexNet).classifier(Sequential).6(Linear) 2020-03-22 20:00:38 input_dt input 1
45 [2, 1000] -0.000065 0.011815 0.033582 -0.034286 0 0.000000 model(AlexNet).classifier(Sequential).6(Linear) 2020-03-22 20:00:38 output_dt output 1
46 [2, 1000] -0.000065 0.011815 0.033582 -0.034286 0 0.000000 model(AlexNet).classifier(Sequential) 2020-03-22 20:00:38 output_dt output 1
47 [2, 1000] -0.000065 0.011815 0.033582 -0.034286 0 0.000000 model(AlexNet) 2020-03-22 20:00:38 output_dt output 1
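
A table like this is easiest to use with a filter. As a small sketch, the rows below are copied from the table above (module, ttype, zero_pct), and the 0.6 threshold is an arbitrary value chosen for the illustration; it picks out the outputs that are mostly zeros.

```python
# A few rows copied from the table above: (module, ttype, zero_pct)
rows = [
    ("model(AlexNet).features(Sequential).1(ReLU)", "output_dt", 0.508055),
    ("model(AlexNet).features(Sequential).2(MaxPool2d)", "output_dt", 0.266504),
    ("model(AlexNet).classifier(Sequential).0(Dropout)", "output_dt", 0.659451),
    ("model(AlexNet).classifier(Sequential).3(Dropout)", "output_dt", 0.747192),
    ("model(AlexNet).classifier(Sequential).6(Linear)", "output_dt", 0.000000),
]

threshold = 0.6  # arbitrary cut-off for this illustration
sparse = [
    module
    for module, ttype, zero_pct in rows
    if ttype == "output_dt" and zero_pct > threshold
]
print(sparse)
```

Here only the two Dropout outputs pass the threshold, which is expected since dropout zeroes activations on top of the zeros the preceding ReLU already produced.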

Redefine what you want to record

The default statistic function keeps track of the shape, mean, std, max and min of a tensor.

Here, "tensor" can mean any of the following:

  • module input tensors
  • module output tensors
  • module weight
  • gradient of module weight

If you have other metrics you'd like to follow, you can redefine the statistic tracking function

Redefine the statistic function for weight tensors / weight gradient tensors

@te.set_metric("weight")
def weight_stats(tensor):
    return {"num":tensor.numel(),"row_max":list(row.max().item() for row in tensor)}

Redefine the statistic function for inputs or outputs

@te.set_metric("in")
def input_stats(tensor):
    return {"num":tensor.numel(),"row_min":list(row.min().item() for row in tensor)}

@te.set_metric("out")
def output_stats(tensor):
    return {"num":tensor.numel(),"row_min":list(row.min().item() for row in tensor)}
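
A decorator that takes a key, as `set_metric` does above, is commonly built as a small registry. The sketch below shows that general pattern under the assumption that one function is kept per tensor kind; `MetricRegistry` and `record` are hypothetical names and torchember's internals may differ.

```python
class MetricRegistry:
    """Hypothetical registry: one statistic function per tensor kind."""
    def __init__(self):
        self.metrics = {}

    def set_metric(self, kind):
        def decorator(func):
            self.metrics[kind] = func  # replaces whatever was registered before
            return func
        return decorator

    def record(self, kind, values):
        return self.metrics[kind](values)

registry = MetricRegistry()

@registry.set_metric("in")
def input_stats(values):
    return {"num": len(values), "min": min(values)}

# Registering a second function under the same key overrides the first
@registry.set_metric("in")
def input_range(values):
    return {"range": max(values) - min(values)}

print(registry.record("in", [3, 1, 4]))
```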

Let's run one more forward pass

model(samp)
tensor([[ 0.0014,  0.0031,  0.0140,  ...,  0.0199, -0.0074, -0.0077],
        [ 0.0081, -0.0028,  0.0206,  ...,  0.0053, -0.0158, -0.0009]],
       grad_fn=<AddmmBackward>)

The latest stats changed

te.t.latest_df
shape mean std max min cnt_zero zero_pct module ts ttype tname num row_min
0 [64, 3, 11, 11] 0.000199 0.030295 0.052484 -0.052486 0.0 0.0 model(AlexNet).features(Sequential).0(Conv2d) 2020-03-21 14:54:21 grad grad_0 NaN NaN
1 [64] 0.005669 0.031589 0.052216 -0.051510 0.0 0.0 model(AlexNet).features(Sequential).0(Conv2d) 2020-03-21 14:54:21 grad grad_1 NaN NaN
2 [192, 64, 5, 5] 0.000022 0.014421 0.025000 -0.025000 0.0 0.0 model(AlexNet).features(Sequential).3(Conv2d) 2020-03-21 14:54:21 grad grad_0 NaN NaN
3 [192] -0.001904 0.014245 0.023686 -0.024896 0.0 0.0 model(AlexNet).features(Sequential).3(Conv2d) 2020-03-21 14:54:21 grad grad_1 NaN NaN
4 [384, 192, 3, 3] 0.000005 0.013876 0.024056 -0.024056 0.0 0.0 model(AlexNet).features(Sequential).6(Conv2d) 2020-03-21 14:54:21 grad grad_0 NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ...
59 NaN NaN NaN NaN NaN NaN NaN model(AlexNet).classifier(Sequential).5(ReLU) 2020-03-21 14:54:44 output output 8192.0 [0.0, 0.0]
60 NaN NaN NaN NaN NaN NaN NaN model(AlexNet).classifier(Sequential).6(Linear) 2020-03-21 14:54:44 input input 8192.0 [0.0, 0.0]
61 NaN NaN NaN NaN NaN NaN NaN model(AlexNet).classifier(Sequential).6(Linear) 2020-03-21 14:54:44 output output 2000.0 [-0.035990722477436066, -0.03543027490377426]
62 NaN NaN NaN NaN NaN NaN NaN model(AlexNet).classifier(Sequential) 2020-03-21 14:54:44 output output 2000.0 [-0.035990722477436066, -0.03543027490377426]
63 NaN NaN NaN NaN NaN NaN NaN model(AlexNet) 2020-03-21 14:54:44 output output 2000.0 [-0.035990722477436066, -0.03543027490377426]

64 rows × 13 columns

Placing tracker on variables

This part is still experimental

w = list(model.features.parameters())[0]
from types import BuiltinMethodType,BuiltinFunctionType
x1 = torch.rand(5,6)
x2 = torch.rand(5,6)
x3 = x1*6+x2
x2.numel()
30
x1.abs_()
tensor([[0.1462, 0.6524, 0.6635, 0.0931, 0.8485, 0.3402],
        [0.6705, 0.0846, 0.6348, 0.3046, 0.7542, 0.6418],
        [0.6934, 0.4078, 0.9792, 0.1871, 0.7833, 0.6145],
        [0.6606, 0.6178, 0.2674, 0.4398, 0.4242, 0.2114],
        [0.9054, 0.9068, 0.6374, 0.8210, 0.7212, 0.4652]])
from types import MethodType
import inspect
def TorchTensorEmber(x):
    class TensorEmber(x.__class__):
        def __init__(self,x):
            self.host_ = x
            attrs = dir(x)
            for attr in attrs:
                self.super_attr(attr)
            
        def super_attr(self,attr):
            if inspect.isbuiltin(getattr(self.host_,attr))==False: return 
            def func(self,*args,**kwargs):
                print(attr)
                return getattr(super(),attr)(*args,**kwargs)
            func.__name__ = attr
            setattr(self,attr, MethodType(func,self))
            return func
            
    return TensorEmber(x)
x2 = TorchTensorEmber(x2)
x2.add(x1)
add
tensor([[1.0196, 1.1548, 1.1521, 0.1822, 1.7265, 0.4464],
        [1.2865, 0.4544, 0.9891, 0.8650, 1.1334, 1.2300],
        [1.3343, 0.8323, 1.9395, 1.1801, 1.5499, 0.7846],
        [1.1385, 1.2144, 0.6191, 0.6455, 0.9545, 0.8413],
        [1.3088, 1.7986, 1.2820, 1.6781, 1.5974, 0.9433]])
x2+x1
tensor([[1.0196, 1.1548, 1.1521, 0.1822, 1.7265, 0.4464],
        [1.2865, 0.4544, 0.9891, 0.8650, 1.1334, 1.2300],
        [1.3343, 0.8323, 1.9395, 1.1801, 1.5499, 0.7846],
        [1.1385, 1.2144, 0.6191, 0.6455, 0.9545, 0.8413],
        [1.3088, 1.7986, 1.2820, 1.6781, 1.5974, 0.9433]])
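
In the experiment above, `x2.add(x1)` prints `add` but `x2+x1` prints nothing. That is consistent with how Python resolves operators: `a + b` looks up `__add__` on the type of `a`, so a method attached to the instance with `setattr` is bypassed. A plain-Python demonstration (the `Number` and `LoggedNumber` classes are hypothetical stand-ins for a tensor):

```python
from types import MethodType

calls = []

class Number:
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return Number(self.value + other.value)

def logged_add(self, other):
    calls.append("instance __add__")
    return Number(self.value + other.value)

a, b = Number(1), Number(2)

# Attached to the instance: only found by explicit attribute access
a.__add__ = MethodType(logged_add, a)
a.__add__(b)  # logged
a + b         # not logged: the operator uses type(a).__add__

# Defined on a subclass: the operator sees it
class LoggedNumber(Number):
    def __add__(self, other):
        calls.append("class __add__")
        return super().__add__(other)

total = LoggedNumber(1) + b
print(calls, total.value)
```

To intercept operators as well as named methods, the logging methods would therefore have to live on the class rather than be set on each instance.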

LSTM

Experiment on LSTM (rich structure of input and output)

from torch import nn

rnn = nn.LSTM(input_size=20, hidden_size=20,batch_first = True)
te_rnn = torchEmber(rnn)
start analyzing model
for i in range(3):
    x = torch.rand(2,10,20)
    h = torch.zeros(1,2,20)
    c = torch.zeros(1,2,20)
    x,(h,c) = rnn(x,(h,c))
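
An LSTM call takes and returns nested tuples such as `x, (h, c)` rather than a single tensor, so a tracker has to walk the structure before it can compute statistics on each tensor. A minimal sketch of such a walk, with strings standing in for tensors; the `output_1_0` naming scheme is an assumption made for this example, not torchember's.

```python
def flatten_named(value, prefix):
    """Yield (name, leaf) pairs from arbitrarily nested tuples/lists."""
    if isinstance(value, (tuple, list)):
        for i, item in enumerate(value):
            yield from flatten_named(item, f"{prefix}_{i}")
    else:
        yield prefix, value

# Strings stand in for tensors; the nesting mirrors x, (h, c)
lstm_output = ("x", ("h", "c"))
named = dict(flatten_named(lstm_output, "output"))
print(named)
```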

Placing tracker on optimizer

This part is still experimental

net = AlexNet()
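class amr_model():
    def __init__(self, tm):
        # keep a handle on the tracked model
        self.model = tm.model

def parse_net(module):
    # collect the parameters owned directly by this module,
    # i.e. not by any of its children
    Ps = set(module.parameters())
    for child in module.children():
        Ps -= set(child.parameters())
        parse_net(child)
    # attach before returning so every module gets `weights_owned`
    setattr(module, "weights_owned", list(Ps))
    return list(Ps)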
parse_net(net)
dic={}
i = 0
for m in net.modules():
    dic.update({f'level_{i}_{m.__class__.__name__}':m.weights_owned})
    i+=1
dic.keys()
dict_keys(['level_0_AlexNet', 'level_1_Sequential', 'level_2_Conv2d', 'level_3_ReLU', 'level_4_MaxPool2d', 'level_5_Conv2d', 'level_6_ReLU', 'level_7_MaxPool2d', 'level_8_Conv2d', 'level_9_ReLU', 'level_10_Conv2d', 'level_11_ReLU', 'level_12_Conv2d', 'level_13_ReLU', 'level_14_MaxPool2d', 'level_15_AdaptiveAvgPool2d', 'level_16_Sequential', 'level_17_Dropout', 'level_18_Linear', 'level_19_ReLU', 'level_20_Dropout', 'level_21_Linear', 'level_22_ReLU', 'level_23_Linear'])
list(net.modules())[2].weights_owned,list(net.modules())[2].__class__.__name__
([Parameter containing:
  tensor([[[[-1.6327e-02,  1.1162e-02, -2.8405e-02,  ..., -1.9573e-02,
             -3.3804e-02, -4.4825e-02],
            [-1.9221e-03, -4.2759e-02,  4.0985e-03,  ...,  2.0986e-02,
              4.7981e-02,  4.7343e-02],
            [ 3.3014e-03, -3.7866e-02, -2.6860e-02,  ..., -3.3356e-02,
             -3.9897e-02,  4.1397e-03],
            ...,
            [ 2.3858e-02, -3.1886e-02,  1.0154e-02,  ...,  4.4415e-02,
              2.4320e-02,  4.3755e-03],
            [ 2.0721e-02, -5.1307e-02, -2.7957e-02,  ..., -1.1665e-02,
             -3.4113e-03, -1.0425e-02],
            [ 1.0254e-02,  1.2262e-02,  4.2601e-02,  ..., -1.3578e-02,
             -2.5370e-02,  4.4870e-02]],
  
           [[-4.5430e-02,  2.3310e-02,  2.2527e-02,  ...,  3.9389e-02,
              1.9963e-02, -3.9786e-02],
            [-1.8303e-02,  4.1754e-02, -4.3693e-02,  ...,  3.3112e-02,
              4.2441e-02,  6.2147e-03],
            [ 3.7984e-02,  6.0208e-03,  5.0205e-02,  ..., -3.9633e-04,
              4.4289e-02, -2.9487e-02],
            ...,
            [ 2.1567e-02,  5.0117e-02,  3.9698e-02,  ...,  4.2929e-02,
              3.9644e-02, -3.8859e-02],
            [ 3.8407e-02, -2.2986e-02, -3.8284e-04,  ..., -4.5617e-02,
              1.2778e-02, -3.4898e-02],
            [-1.9761e-03, -2.5815e-02, -3.1854e-02,  ..., -1.3739e-02,
              1.5286e-02,  1.3604e-03]],
  
           [[ 1.8892e-02,  4.9961e-02, -2.5605e-02,  ..., -5.9871e-03,
              3.9408e-03, -8.3244e-03],
            [ 4.6460e-03,  1.0863e-02, -3.9094e-02,  ...,  4.0489e-02,
             -5.9131e-03,  2.6736e-02],
            [-4.1610e-02,  4.8054e-02,  4.3286e-02,  ..., -1.4439e-02,
              2.7459e-02, -2.4571e-02],
            ...,
            [-4.4373e-02, -4.4802e-02,  3.6243e-02,  ...,  5.1976e-02,
             -3.6861e-02, -2.0367e-02],
            [ 4.6666e-02, -5.8455e-03, -1.8954e-02,  ..., -9.6904e-03,
             -1.2530e-02, -4.5702e-02],
            [ 4.3149e-02,  7.4010e-03, -3.8539e-02,  ...,  4.6101e-02,
             -8.7086e-03, -4.9541e-03]]],
  
  
          [[[-8.0408e-03, -1.7279e-02,  6.7536e-05,  ..., -2.3744e-02,
             -1.2311e-02,  4.7851e-02],
            [ 2.9197e-02,  1.6269e-02,  1.7624e-02,  ..., -1.2846e-02,
              7.2723e-04, -5.2063e-03],
            [-4.6501e-02, -2.9172e-02,  4.1413e-02,  ...,  3.2842e-02,
              7.6062e-03, -3.4498e-02],
            ...,
            [-3.5862e-02,  4.0498e-02, -5.2404e-02,  ...,  9.7134e-03,
              1.9845e-02,  2.3917e-02],
            [ 1.4307e-03,  1.4190e-02, -1.5650e-02,  ...,  7.7413e-03,
              1.1957e-02,  5.4209e-03],
            [ 4.8249e-02,  2.5972e-02, -1.3158e-02,  ..., -3.9730e-02,
              4.4385e-02, -1.5572e-02]],
  
           [[ 4.9841e-02, -4.0689e-02,  4.2824e-02,  ..., -3.6120e-02,
             -3.3183e-02,  1.9866e-02],
            [ 4.8186e-02, -1.4457e-02,  4.9148e-02,  ..., -4.5886e-02,
             -2.5225e-02, -3.5466e-02],
            [-1.0808e-02,  4.4364e-02, -1.7445e-02,  ...,  1.2834e-02,
              4.2694e-03,  1.9969e-02],
            ...,
            [ 3.2471e-02,  4.5638e-02,  1.5253e-02,  ...,  4.8775e-02,
              1.0568e-03,  1.5566e-02],
            [ 4.6002e-02,  6.0134e-03, -1.5038e-02,  ...,  6.4702e-03,
              1.7146e-03,  3.7487e-02],
            [ 3.8742e-02,  1.7727e-02,  3.6878e-02,  ...,  1.8696e-02,
              1.3658e-03,  1.4744e-02]],
  
           [[ 2.4391e-02,  2.2451e-02, -3.9846e-02,  ..., -1.4928e-02,
             -2.0701e-02, -1.4936e-02],
            [-3.2523e-02, -5.3568e-03,  9.4021e-03,  ..., -2.4463e-02,
             -4.0409e-02,  1.6398e-03],
            [ 2.8257e-02,  2.9855e-02, -4.0830e-02,  ...,  3.1422e-02,
              2.6279e-02, -2.4028e-02],
            ...,
            [ 1.2100e-02, -2.9761e-02,  2.3151e-02,  ...,  2.1168e-02,
              1.8855e-02,  1.5994e-02],
            [ 1.6805e-02, -1.1173e-02,  1.3692e-03,  ..., -4.6544e-02,
              5.3142e-03, -9.1379e-03],
            [-3.3663e-02, -3.5115e-02,  2.6473e-02,  ..., -4.2573e-02,
              1.2499e-02,  4.6951e-02]]],
  
  
          [[[ 1.3344e-02, -3.8256e-02, -2.8484e-02,  ...,  7.1430e-03,
             -3.4000e-02, -4.2351e-02],
            [ 1.0588e-02, -5.0056e-02,  4.5984e-02,  ...,  1.0427e-02,
              1.7341e-02, -2.1410e-02],
            [-6.5497e-03,  4.7154e-03, -2.9450e-02,  ..., -9.9096e-03,
              1.5492e-02, -2.3772e-02],
            ...,
            [ 7.1360e-03, -3.0012e-02,  4.1119e-02,  ..., -3.4346e-02,
             -5.2393e-02, -1.5507e-02],
            [ 4.8934e-04, -5.1632e-02, -4.8675e-02,  ...,  9.3624e-03,
              2.6448e-02,  1.5598e-02],
            [-1.6510e-02, -4.1269e-02, -3.4156e-02,  ..., -9.1055e-03,
              3.1496e-02, -4.0468e-02]],
  
           [[ 4.2905e-02, -3.8384e-02, -4.5313e-04,  ...,  4.7658e-02,
             -4.2466e-02,  3.1180e-02],
            [ 4.9251e-02,  5.0616e-02, -1.6698e-02,  ...,  3.3522e-02,
              4.2196e-03, -4.6683e-02],
            [-2.1000e-02, -8.2991e-03,  6.3627e-03,  ..., -1.9115e-02,
             -1.1024e-02,  1.4501e-03],
            ...,
            [-1.5806e-02,  2.8734e-02, -4.2988e-02,  ...,  2.5495e-02,
              1.3237e-02, -2.2480e-02],
            [-2.0212e-02, -2.5854e-02,  7.2619e-03,  ...,  4.5810e-02,
              3.9343e-02,  3.7333e-02],
            [-2.5210e-02, -4.5021e-02, -3.5563e-02,  ..., -1.0309e-02,
             -5.9101e-03,  4.4068e-02]],
  
           [[-9.3068e-03,  1.9526e-02, -4.6282e-02,  ...,  1.8824e-02,
              2.0614e-02, -3.5049e-02],
            [-3.0746e-02,  3.2606e-02,  2.0441e-02,  ..., -1.1884e-02,
             -3.1358e-02, -4.0053e-03],
            [ 3.5490e-02,  4.2098e-02, -2.9597e-02,  ..., -1.1405e-02,
              3.6126e-02,  2.5460e-02],
            ...,
            [ 1.9456e-03,  1.1244e-02,  1.0540e-02,  ..., -2.3625e-02,
              4.4521e-02, -5.6920e-03],
            [ 4.8741e-02,  4.3648e-02,  4.1463e-02,  ..., -2.1848e-02,
             -7.4721e-03, -4.7804e-02],
            [-6.5712e-03, -4.3555e-02, -8.6372e-03,  ...,  4.9737e-02,
              2.3035e-02, -2.7911e-02]]],
  
  
          ...,
  
  
          [[[ 4.9561e-02, -3.1964e-03,  4.8666e-02,  ...,  9.8746e-03,
             -3.2994e-02, -6.6449e-03],
            [-1.8285e-02, -1.5940e-02, -4.2757e-02,  ...,  2.7441e-02,
              3.7931e-02,  3.5887e-02],
            [ 4.1755e-02,  3.2321e-02, -1.7707e-02,  ...,  3.7276e-02,
             -4.3789e-02,  2.8113e-02],
            ...,
            [ 5.1050e-02, -1.6813e-02, -2.5724e-02,  ...,  2.4554e-02,
             -2.4526e-02, -2.4625e-02],
            [ 2.9677e-02,  1.0142e-02,  1.7238e-02,  ...,  1.1841e-02,
             -4.1555e-02, -3.4420e-02],
            [ 5.1649e-02, -1.0868e-03,  3.1674e-02,  ...,  6.7830e-03,
              3.2504e-02, -2.3810e-02]],
  
           [[ 2.2200e-02, -3.1003e-02, -2.2400e-02,  ..., -9.4456e-03,
             -1.1570e-02, -4.4553e-03],
            [ 3.0843e-02,  3.6150e-02,  3.5662e-02,  ..., -3.1607e-02,
              5.0240e-02, -1.2652e-03],
            [-9.3393e-03,  1.2245e-02, -4.8486e-02,  ...,  2.3058e-02,
             -3.1741e-02,  1.9298e-02],
            ...,
            [ 4.4604e-02, -2.6501e-03, -4.0699e-03,  ..., -4.8826e-02,
              3.9019e-02, -4.9581e-02],
            [-2.6170e-02,  2.7402e-02,  1.2372e-02,  ..., -3.6543e-03,
              4.9457e-02,  3.8373e-02],
            [-1.4711e-02, -3.9906e-03,  1.7440e-02,  ...,  5.7326e-03,
             -2.6910e-02, -2.0366e-02]],
  
           [[ 1.3091e-02,  6.2031e-03, -7.0316e-03,  ...,  3.3780e-03,
              1.7800e-02,  4.0473e-02],
            [ 4.1367e-02,  1.1167e-02,  2.8024e-02,  ..., -2.8959e-02,
              2.9537e-02, -4.3138e-02],
            [ 4.1558e-02, -1.2385e-03, -3.5784e-02,  ..., -3.3275e-02,
              4.4693e-02, -5.1954e-02],
            ...,
            [ 3.1279e-02, -4.1827e-02,  1.8772e-03,  ...,  4.9399e-02,
              4.8427e-02, -7.4510e-03],
            [ 5.0078e-02,  2.9156e-02,  9.5356e-03,  ...,  3.2966e-02,
             -2.7987e-02, -4.1791e-02],
            [ 1.4730e-02,  2.0569e-02,  2.1631e-02,  ..., -3.3185e-02,
              2.6271e-02, -3.9775e-02]]],
  
  
          [[[-1.2064e-02,  1.0438e-02, -9.4433e-03,  ...,  2.9094e-02,
              5.7960e-03, -4.9238e-02],
            [ 6.0487e-03,  7.2256e-03,  3.3526e-02,  ...,  1.0666e-02,
              1.9697e-02, -4.2439e-02],
            [-1.9803e-02, -2.4073e-02, -1.7229e-02,  ..., -2.7883e-02,
              2.2040e-02, -2.1412e-02],
            ...,
            [ 4.0054e-03,  3.8719e-02, -1.9259e-02,  ...,  6.1719e-03,
             -1.1971e-02, -1.1393e-02],
            [ 3.3959e-03,  4.5294e-02,  3.5897e-02,  ...,  4.2523e-03,
              1.2300e-02, -2.3434e-02],
            [ 4.1054e-02,  4.7831e-02,  9.3155e-03,  ...,  2.3216e-02,
             -4.9854e-02, -5.2689e-03]],
  
           [[ 1.3912e-02, -4.4741e-03, -3.4478e-02,  ...,  4.1369e-02,
              4.7080e-02, -3.1518e-02],
            [-7.2721e-03, -4.3725e-02,  5.1467e-02,  ..., -3.3723e-02,
              3.2725e-02, -1.9429e-02],
            [ 5.4907e-03,  3.5316e-03, -1.7232e-03,  ..., -4.7159e-02,
              3.1791e-02,  1.6001e-02],
            ...,
            [ 2.0772e-02,  2.5290e-02, -3.1089e-02,  ...,  1.9733e-02,
             -2.0684e-02,  3.7561e-03],
            [ 3.4498e-03,  4.5156e-03,  5.1393e-02,  ..., -4.7893e-02,
             -4.6476e-02,  3.9432e-02],
            [-4.8598e-02,  1.2690e-02, -4.0148e-02,  ..., -4.1108e-02,
             -2.5094e-02,  3.7583e-02]],
  
           [[ 4.6840e-02,  2.8525e-02, -3.9433e-03,  ..., -1.7610e-02,
             -4.6216e-02,  4.3124e-02],
            [-5.1405e-02, -1.5895e-02,  1.0340e-03,  ..., -3.0895e-03,
              1.4551e-02,  1.9688e-02],
            [-9.2064e-03,  3.2544e-02,  4.3838e-02,  ...,  1.8025e-02,
             -1.1587e-03, -8.4455e-03],
            ...,
            [ 1.8635e-02, -4.7795e-02,  2.4993e-02,  ...,  9.8532e-04,
              2.7919e-02,  5.1903e-02],
            [ 2.1100e-02,  1.9092e-02, -4.3302e-02,  ..., -2.2819e-02,
              1.1515e-02,  4.0907e-02],
            [ 3.9619e-02,  3.0630e-03, -5.0501e-03,  ...,  2.5631e-02,
             -3.5143e-02, -1.3551e-02]]],
  
  
          [[[ 2.9818e-02,  3.4398e-03, -4.3275e-02,  ...,  4.5394e-02,
             -2.6625e-02,  1.1323e-03],
            [-5.1105e-02,  2.8720e-02, -2.2189e-02,  ..., -8.4933e-03,
              1.4328e-02, -5.2264e-04],
            [-1.4094e-02,  4.8694e-02, -4.8202e-02,  ...,  4.9510e-03,
              4.7436e-02, -3.1204e-02],
            ...,
            [ 4.1559e-03,  2.5767e-02,  7.3608e-03,  ...,  5.1680e-03,
              3.3126e-03, -4.7476e-02],
            [ 3.3597e-02, -2.6299e-02,  2.4095e-02,  ..., -3.7296e-02,
             -3.5387e-02, -2.0719e-02],
            [ 2.7173e-03,  5.0040e-02,  3.0475e-02,  ..., -4.7845e-03,
             -1.6346e-02,  3.9454e-02]],
  
           [[ 2.8541e-02, -4.2643e-03, -4.3947e-02,  ..., -2.5512e-02,
             -2.2716e-02,  1.3897e-02],
            [-3.8172e-02, -1.5592e-02, -3.5201e-02,  ...,  3.6502e-02,
             -4.8578e-02,  3.1697e-04],
            [ 5.2122e-02, -3.9069e-02, -1.6774e-02,  ..., -1.0005e-02,
              1.6590e-02, -3.3140e-02],
            ...,
            [-1.6394e-02, -1.9741e-02, -4.0651e-02,  ..., -3.0367e-02,
              2.9636e-02,  3.8146e-02],
            [-4.1929e-03,  2.1044e-02, -1.3101e-02,  ..., -2.5063e-02,
             -3.3674e-02,  2.5253e-02],
            [ 2.7324e-02,  3.5350e-02, -4.4975e-02,  ..., -4.1634e-02,
             -2.9090e-03,  1.0550e-02]],
  
           [[ 3.3394e-02,  3.0214e-02, -3.1053e-02,  ...,  4.3145e-02,
             -3.7860e-02, -3.0976e-02],
            [ 5.4121e-03,  4.2933e-02,  3.9807e-02,  ...,  3.3977e-02,
             -3.5963e-02,  1.9917e-02],
            [ 3.5320e-02, -4.8012e-02,  1.5429e-03,  ..., -2.8106e-02,
             -3.7762e-02,  2.2052e-02],
            ...,
            [ 4.8067e-02, -4.7444e-02,  1.8627e-02,  ..., -3.7919e-02,
              1.8663e-02, -3.6427e-02],
            [-4.3965e-02,  3.4070e-02, -1.7858e-02,  ..., -3.1441e-03,
             -4.5198e-02, -1.0859e-02],
            [ 5.0178e-02, -3.5188e-02,  5.1515e-02,  ...,  3.7071e-02,
             -2.6922e-02, -7.7885e-03]]]], requires_grad=True),
  Parameter containing:
  tensor([ 0.0252, -0.0287,  0.0137,  0.0356,  0.0411,  0.0133, -0.0525,  0.0304,
          -0.0248,  0.0077, -0.0507,  0.0340, -0.0133, -0.0309, -0.0256,  0.0461,
          -0.0134, -0.0271,  0.0416, -0.0015,  0.0206, -0.0290, -0.0284,  0.0104,
          -0.0041, -0.0270, -0.0391,  0.0178,  0.0073, -0.0088,  0.0420,  0.0303,
           0.0270, -0.0516,  0.0154,  0.0333, -0.0039,  0.0077, -0.0396, -0.0333,
          -0.0352,  0.0499,  0.0002,  0.0204,  0.0301, -0.0254,  0.0192, -0.0297,
           0.0301, -0.0097,  0.0062, -0.0397, -0.0210,  0.0056,  0.0242,  0.0510,
           0.0214, -0.0434,  0.0008, -0.0223, -0.0087,  0.0394,  0.0039, -0.0094],
         requires_grad=True)],
 'Conv2d')
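Printing the raw parameters produces a very long dump. A more compact check, sketched here under the assumption that only shapes and a few summary statistics are needed, reports one row per parameter. The helper name `summarize` is hypothetical, and the standalone `conv` layer uses the same constructor arguments as the first convolution in torchvision's AlexNet instead of being pulled from `net`.

```python
import torch
import torch.nn as nn

# Stand-in for list(net.modules())[2]: same constructor arguments as
# the first convolution in torchvision's AlexNet (assumption for illustration).
conv = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)

def summarize(params):
    """Hypothetical helper: one summary row per parameter instead of the full tensor."""
    rows = []
    for p in params:
        t = p.detach()  # statistics only, no gradient tracking needed
        rows.append({
            "shape": tuple(t.shape),
            "numel": t.numel(),
            "mean": t.mean().item(),
            "std": t.std().item(),
            "min": t.min().item(),
            "max": t.max().item(),
        })
    return rows

rows = summarize(conv.parameters(recurse=False))
for row in rows:
    print(row)
```

The same helper could be applied to any module's `weights_owned` list, since it only expects an iterable of tensors.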