@@ -2199,7 +2199,7 @@ function dpnntest.OneHot()
 end
 end
 
-function dpnntest.NCE()
+function dpnntest.NCE_main()
    local batchsize = 4
    local k = 10
    local inputsize = 3
@@ -2353,7 +2353,7 @@ function dpnntest.NCE()
    local linear = nn.Linear(inputsize, outputsize)
    linear.weight:copy(ncem.weight)
    linear.bias:copy(ncem.bias)
-   local mlp = nn.Sequential():add(linear):add(nn.Exp())
+   local mlp = nn.Sequential():add(linear):add(nn.Exp()):add(nn.MulConstant(1/ncem.Z[1]))
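+   -- note: assuming ncem.Z[1] holds the NCEModule's fixed normalization constant,
+   -- scaling by 1/Z makes this Linear+Exp reference produce the same scores as the module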
    mlp:cuda()
 
    local output2_ = mlp:forward(input)
@@ -2455,6 +2455,8 @@ function dpnntest.NCE_multinomial()
 end
 
 function dpnnbigtest.NCE_benchmark()
+   pcall(function() require 'cunn' end) -- make sure to import cunn before initializing large tensors, else weird segfault...
+
    local nclass = 1000000
    local hiddensize = 200
    local batchsize = 50
@@ -2483,8 +2485,6 @@ function dpnnbigtest.NCE_benchmark()
       sync = function() cutorch.synchronize() end
    end
 
-   print(torch.type(nce.unigrams))
-
    local output = nce:forward{input, target}
    local loss = crit:forward(output, target)
    local gradOutput = crit:backward(output, target)
@@ -2494,8 +2494,8 @@ function dpnnbigtest.NCE_benchmark()
    local loss = nll:forward(output, target)
    local gradOutput = nll:backward(output, target)
    local gradInput = mlp:backward(input, gradOutput)
-   sync()
 
+   sync()
    local a = torch.Timer()
    for i = 1,nloop do
       output = nce:forward{input, target}
@@ -2525,7 +2525,6 @@ function dpnnbigtest.NCE_benchmark()
    local ncebwd = a:time().real
 
    -- mlp nll
-
    local a = torch.Timer()
    for i = 1,nloop do
       output = mlp:forward(input)
@@ -2561,6 +2560,38 @@ function dpnnbigtest.NCE_benchmark()
    print("criterion:backward (nce vs nll)", critbwd, nllbwd)
    print("module:backward (nce vs linear)", ncebwd, mlpbwd)
    print("total (nce vs linear)", ncetotal, lintotal, lintotal/ncetotal)
+
+   if not (cunn and cutorch.getDeviceCount() > 1) then
+      return
+   end
+
+   nce:multicuda(1,2)
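+   -- re-run the same timed forward/backward passes with the weight distributed
+   -- across GPUs 1 and 2, to compare against the single-GPU numbers above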
+
+   local output = nce:forward{input, target}
+   local loss = crit:forward(output, target)
+   local gradOutput = crit:backward(output, target)
+   local gradInput = nce:backward({input, target}, gradOutput)
+   sync()
+
+   local a = torch.Timer()
+   for i = 1,nloop do
+      output = nce:forward{input, target}
+   end
+   sync()
+   local ncefwd2 = a:time().real
+
+   a:reset()
+   for i = 1,nloop do
+      gradInput = nce:backward({input, target}, gradOutput)
+   end
+   sync()
+   local ncebwd2 = a:time().real
+
+   local total1 = ncefwd + ncebwd
+   local total2 = ncefwd2 + ncebwd2
+   print("module:forward (1 vs 2 gpu)", ncefwd, ncefwd2)
+   print("module:backward (1 vs 2 gpu)", ncebwd, ncebwd2)
+   print("total (1 vs 2 gpu)", total1, total2, total2/total1)
 end
 
 function dpnntest.NaN()
@@ -2599,6 +2630,87 @@ function dpnntest.NaN()
    mytester:assert(not pcall(function() nan:backward(input, gradOutput) end))
 end
 
+function dpnntest.NCE_multicuda()
+   if not pcall(function() require 'torchx' end) then
+      return
+   end
+   if not pcall(function() require 'cunn' end) then
+      return
+   end
+   if cutorch.getDeviceCount() < 2 then
+      return
+   end
+   assert(torchx.version and torchx.version >= 1, "Update torchx")
+
+   local nclass = 1000
+   local hiddensize = 20
+   local batchsize = 5
+   local k = 25
+   local unigrams = torch.Tensor(nclass):uniform(0,1)
+   local noise = torch.LongTensor(batchsize, k):random(1,nclass)
+
+   local crit = nn.NCECriterion():cuda()
+   local crit2 = nn.NCECriterion():cuda()
+
+   local nce = nn.NCEModule(hiddensize, nclass, k, unigrams)
+
+   -- make it deterministic
+   nce.noiseSample = function(self, sampleidx, batchsize, k)
+      sampleidx:resize(batchsize, k)
+      sampleidx:copy(noise)
+      return sampleidx
+   end
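+   -- with noiseSample overridden, this module and its clone below draw the same
+   -- fixed noise samples, so their outputs are directly comparable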
+
+   local nce2 = nce:clone()
+   nce2:cuda()
+
+   local input = torch.randn(batchsize, hiddensize):cuda()
+   local target = torch.LongTensor(batchsize):random(1,nclass):cuda()
+
+   nce:multicuda(1, 2)
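+   -- judging by the assertions below, multicuda(1, 2) column-splits weight and
+   -- gradWeight: tensors[1] holds the first hiddensize/2 columns on GPU 1,
+   -- tensors[2] the remaining columns on GPU 2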
+
+   local output = nce:forward{input, target}
+   local loss = crit:forward(output, target)
+   local gradOutput = crit:backward(output, target)
+   nce:zeroGradParameters()
+   local gradInput = nce:backward({input, target}, gradOutput)
+
+   local output2 = nce2:forward{input, target}
+   local loss2 = crit2:forward(output2, target)
+   local gradOutput2 = crit2:backward(output2, target)
+   nce2:zeroGradParameters()
+   local gradInput2 = nce2:backward({input, target}, gradOutput2)
+
+   mytester:assertTensorEq(output[1], output2[1], 0.00001)
+   mytester:assertTensorEq(output[2], output2[2], 0.00001)
+   mytester:assertTensorEq(output[3], output2[3], 0.00001)
+   mytester:assertTensorEq(output[4], output2[4], 0.00001)
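+   -- all four tensors returned by NCEModule:forward must agree between the
+   -- multi-GPU module and the single-GPU clone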
+
+   mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.00001)
+   mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.00001)
+
+   nce2:updateParameters(0.1)
+   nce:updateParameters(0.1)
+
+   mytester:assertTensorEq(nce2.bias, nce.bias, 0.000001)
+   mytester:assertTensorEq(nce2.gradBias, nce.gradBias, 0.000001)
+   mytester:assertTensorEq(nce2.weight[{{},{1,hiddensize/2}}]:float(), nce.weight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.weight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.weight.tensors[2]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1,hiddensize/2}}]:float(), nce.gradWeight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.gradWeight.tensors[2]:float(), 0.000001)
+
+   -- test momentum
+   nce2:updateGradParameters(0.9)
+   nce:updateGradParameters(0.9)
+
+   mytester:assertTensorEq(nce2.gradBias, nce.gradBias, 0.000001)
+   mytester:assertTensorEq(nce2.momGradParams[1][{{},{1,hiddensize/2}}]:float(), nce.momGradParams[1].tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.momGradParams[1][{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.momGradParams[1].tensors[2]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1,hiddensize/2}}]:float(), nce.gradWeight.tensors[1]:float(), 0.000001)
+   mytester:assertTensorEq(nce2.gradWeight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.gradWeight.tensors[2]:float(), 0.000001)
+end
+
 function dpnn.test(tests)
    mytester = torch.Tester()
    mytester:add(dpnntest)