|
428 | 428 | ],
|
429 | 429 | "source": [
|
430 | 430 | "import torch\n",
|
431 |
| - "import torchvision\n", |
432 | 431 | "\n",
|
433 | 432 | "torch.hub._validate_not_a_forked_repo=lambda a,b,c: True\n",
|
434 | 433 | "\n",
|
|
558 | 557 | "from PIL import Image\n",
|
559 | 558 | "from torchvision import transforms\n",
|
560 | 559 | "import matplotlib.pyplot as plt\n",
|
561 |
| - "import json \n", |
| 560 | + "import json\n", |
562 | 561 | "\n",
|
563 | 562 | "fig, axes = plt.subplots(nrows=2, ncols=2)\n",
|
564 | 563 | "\n",
|
|
571 | 570 | " transforms.ToTensor(),\n",
|
572 | 571 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
573 | 572 | " ])\n",
|
574 |
| - " input_tensor = preprocess(img) \n", |
| 573 | + " input_tensor = preprocess(img)\n", |
575 | 574 | " plt.subplot(2,2,i+1)\n",
|
576 | 575 | " plt.imshow(img)\n",
|
577 | 576 | " plt.axis('off')\n",
|
578 | 577 | "\n",
|
579 |
| - "# loading labels \n", |
580 |
| - "with open(\"./data/imagenet_class_index.json\") as json_file: \n", |
| 578 | + "# loading labels\n", |
| 579 | + "with open(\"./data/imagenet_class_index.json\") as json_file:\n", |
581 | 580 | " d = json.load(json_file)"
|
582 | 581 | ]
|
583 | 582 | },
|
|
614 | 613 | " preprocess = rn50_preprocess()\n",
|
615 | 614 | " input_tensor = preprocess(img)\n",
|
616 | 615 | " input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model\n",
|
617 |
| - " \n", |
| 616 | + "\n", |
618 | 617 | " # move the input and model to GPU for speed if available\n",
|
619 | 618 | " if torch.cuda.is_available():\n",
|
620 | 619 | " input_batch = input_batch.to('cuda')\n",
|
|
624 | 623 | " output = model(input_batch)\n",
|
625 | 624 | " # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes\n",
|
626 | 625 | " sm_output = torch.nn.functional.softmax(output[0], dim=0)\n",
|
627 |
| - " \n", |
| 626 | + "\n", |
628 | 627 | " ind = torch.argmax(sm_output)\n",
|
629 | 628 | " return d[str(ind.item())], sm_output[ind] #([predicted class, description], probability)\n",
|
630 | 629 | "\n",
|
|
633 | 632 | " input_data = input_data.to(\"cuda\")\n",
|
634 | 633 | " if dtype=='fp16':\n",
|
635 | 634 | " input_data = input_data.half()\n",
|
636 |
| - " \n", |
| 635 | + "\n", |
637 | 636 | " print(\"Warm up ...\")\n",
|
638 | 637 | " with torch.no_grad():\n",
|
639 | 638 | " for _ in range(nwarmup):\n",
|
|
695 | 694 | "for i in range(4):\n",
|
696 | 695 | " img_path = './data/img%d.JPG'%i\n",
|
697 | 696 | " img = Image.open(img_path)\n",
|
698 |
| - " \n", |
| 697 | + "\n", |
699 | 698 | " pred, prob = predict(img_path, resnet50_model)\n",
|
700 | 699 | " print('{} - Predicted: {}, Probablility: {}'.format(img_path, pred, prob))\n",
|
701 | 700 | "\n",
|
702 | 701 | " plt.subplot(2,2,i+1)\n",
|
703 |
| - " plt.imshow(img);\n", |
704 |
| - " plt.axis('off');\n", |
| 702 | + " plt.imshow(img)\n", |
| 703 | + " plt.axis('off')\n", |
705 | 704 | " plt.title(pred[1])"
|
706 | 705 | ]
|
707 | 706 | },
|
|
0 commit comments