Skip to content

Commit daa10f3

Browse files
Merge pull request #14 from CodeWithKyrian/add-zero-shot-object-detection-pipeline
Add Zero Shot Object Detection Pipeline and OwlVit models
2 parents f794e9e + e4abee3 commit daa10f3

21 files changed

+303
-8
lines changed

bin/transformers

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ $application = new Application();
1212
try {
1313
$application->setName('Transformers PHP CLI');
1414

15-
$application->add(new Codewithkyrian\Transformers\Commands\InitCommand());
15+
// $application->add(new Codewithkyrian\Transformers\Commands\InitCommand());
1616
$application->add(new Codewithkyrian\Transformers\Commands\DownloadModelCommand());
1717

1818
$application->run();

examples/images/astronaut.png

415 KB
Loading

examples/images/beach.png

182 KB
Loading

examples/pipelines/image-to-text.php

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
//$captioner = pipeline('image-to-text', 'Xenova/vit-gpt2-image-captioning');
1414
$captioner = pipeline('image-to-text', 'Xenova/trocr-small-handwritten');
1515

16-
$streamer = StdOutStreamer::make($captioner->tokenizer);
16+
//$streamer = StdOutStreamer::make($captioner->tokenizer);
1717

18-
//$url = __DIR__. '/../images/cats.jpg';
18+
$url = __DIR__. '/../images/beach.png';
1919
//$url = __DIR__. '/../images/handwriting.jpg';
2020
//$url = __DIR__. '/../images/handwriting3.png';
2121
$url = __DIR__. '/../images/handwriting4.jpeg';
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Pipelines;
6+
7+
use function Codewithkyrian\Transformers\Utils\memoryUsage;
8+
use function Codewithkyrian\Transformers\Utils\timeUsage;
9+
10+
require_once './bootstrap.php';

// '-1' removes PHP's memory limit for this run.
ini_set('memory_limit', '-1');

$detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

// Alternative sample input. It was previously assigned and then immediately
// overwritten by the pair below, so it is kept here as an opt-in alternative.
//$url = __DIR__. '/../images/astronaut.png';
//$candidateLabels = ['human face', 'rocket', 'helmet', 'american flag'];

$url = __DIR__. '/../images/beach.png';
$candidateLabels = ['hat', 'book', 'sunglasses', 'camera'];

// Keep at most the 4 best-scoring detections with a score threshold of 0.05.
$output = $detector($url, $candidateLabels, topK: 4, threshold: 0.05);

// Dump the detections together with the time and memory usage helpers' output.
dd($output, timeUsage(), memoryUsage());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\FeatureExtractors;
7+
8+
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput;
9+
use Codewithkyrian\Transformers\Processors\Processor;
10+
11+
class OwlViTFeatureExtractor extends ImageFeatureExtractor
{
    /**
     * Post-processes raw object detection outputs by handing them, unchanged,
     * to {@see Processor::postProcessObjectDetection()}.
     *
     * @param ObjectDetectionOutput $outputs The model outputs to post-process.
     * @param float $threshold Score threshold forwarded to the processor.
     * @param array|null $targetSizes Sizes of the original images, or null when none are supplied.
     * @param bool $isZeroShot Whether the outputs come from zero-shot object detection.
     *
     * @return array The post-processed results exactly as returned by the processor.
     */
    public function postProcessObjectDetection(
        ObjectDetectionOutput $outputs,
        float $threshold = 0.5,
        ?array $targetSizes = null,
        bool $isZeroShot = false,
    ): array {
        $postProcessed = Processor::postProcessObjectDetection(
            $outputs,
            $threshold,
            $targetSizes,
            $isZeroShot,
        );

        return $postProcessed;
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\FeatureExtractors;
7+
8+
/**
 * Image processor for Owlv2 checkpoints.
 *
 * Declares no members of its own: everything is inherited from
 * OwlViTFeatureExtractor.
 */
class Owlv2ImageProcessor extends OwlViTFeatureExtractor
{
}

src/Models/Auto/AutoModel.php

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class AutoModel extends PretrainedMixin
2121

2222
'detr' => \Codewithkyrian\Transformers\Models\Pretrained\DETRModel::class,
2323
'yolos' => \Codewithkyrian\Transformers\Models\Pretrained\YOLOSModel::class,
24+
'owlvit' => \Codewithkyrian\Transformers\Models\Pretrained\OwlViTModel::class,
25+
'owlv2' => \Codewithkyrian\Transformers\Models\Pretrained\Owlv2Model::class,
2426
];
2527

2628
const ENCODER_DECODER_MODEL_MAPPING = [
@@ -48,7 +50,9 @@ class AutoModel extends PretrainedMixin
4850
AutoModelForMaskedLM::MODEL_CLASS_MAPPING,
4951
AutoModelForQuestionAnswering::MODEL_CLASS_MAPPING,
5052
AutoModelForImageClassification::MODEL_CLASS_MAPPING,
51-
AutoModelForVision2Seq::MODEL_CLASS_MAPPING
53+
AutoModelForVision2Seq::MODEL_CLASS_MAPPING,
54+
AutoModelForObjectDetection::MODEL_CLASS_MAPPING,
55+
AutoModelForZeroShotObjectDetection::MODEL_CLASS_MAPPING,
5256
];
5357

5458

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Auto;
7+
8+
/**
 * Auto-model entry point for zero-shot object detection.
 *
 * Holds the lookup table from model type key to the detection model class,
 * plus the list of lookup tables that this auto class exposes.
 */
class AutoModelForZeroShotObjectDetection extends PretrainedMixin
{
    /** Model type key => zero-shot object detection model class. */
    public const MODEL_CLASS_MAPPING = [
        'owlvit' => \Codewithkyrian\Transformers\Models\Pretrained\OwlViTForObjectDetection::class,
        'owlv2' => \Codewithkyrian\Transformers\Models\Pretrained\Owlv2ForObjectDetection::class,
    ];

    /** All mapping tables of this auto class; only the one above. */
    public const MODEL_CLASS_MAPPINGS = [
        self::MODEL_CLASS_MAPPING,
    ];
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput;
9+
10+
/**
 * OwlViT model variant whose invocation result is wrapped in an
 * ObjectDetectionOutput.
 */
class OwlViTForObjectDetection extends OwlViTPretrainedModel
{
    /**
     * Runs the parent model and converts its result into an ObjectDetectionOutput.
     *
     * @param array $modelInputs Inputs forwarded untouched to the parent model.
     *
     * @return ObjectDetectionOutput The parent's result passed through ObjectDetectionOutput::fromOutput().
     */
    public function __invoke(array $modelInputs): ObjectDetectionOutput
    {
        $rawOutput = parent::__invoke($modelInputs);

        return ObjectDetectionOutput::fromOutput($rawOutput);
    }
}

src/Models/Pretrained/OwlViTModel.php

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
 * Plain OwlViT model class.
 *
 * Declares no members of its own: everything is inherited from
 * OwlViTPretrainedModel.
 */
class OwlViTModel extends OwlViTPretrainedModel
{
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
 * Common parent of the OwlViT model classes.
 *
 * Declares no members of its own: everything is inherited from PretrainedModel.
 */
class OwlViTPretrainedModel extends PretrainedModel
{
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput;
9+
10+
/**
 * Owlv2 model variant whose invocation result is wrapped in an
 * ObjectDetectionOutput.
 */
class Owlv2ForObjectDetection extends Owlv2PretrainedModel
{
    /**
     * Runs the parent model and converts its result into an ObjectDetectionOutput.
     *
     * @param array $modelInputs Inputs forwarded untouched to the parent model.
     *
     * @return ObjectDetectionOutput The parent's result passed through ObjectDetectionOutput::fromOutput().
     */
    public function __invoke(array $modelInputs): ObjectDetectionOutput
    {
        $rawOutput = parent::__invoke($modelInputs);

        return ObjectDetectionOutput::fromOutput($rawOutput);
    }
}

src/Models/Pretrained/Owlv2Model.php

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
 * Plain Owlv2 model class.
 *
 * Declares no members of its own: everything is inherited from
 * Owlv2PretrainedModel.
 */
class Owlv2Model extends Owlv2PretrainedModel
{
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
 * Common parent of the Owlv2 model classes.
 *
 * Declares no members of its own: everything is inherited from PretrainedModel.
 */
class Owlv2PretrainedModel extends PretrainedModel
{
}

src/Pipelines/Task.php

+12-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Codewithkyrian\Transformers\Models\Auto\AutoModelForSequenceClassification;
1515
use Codewithkyrian\Transformers\Models\Auto\AutoModelForTokenClassification;
1616
use Codewithkyrian\Transformers\Models\Auto\AutoModelForVision2Seq;
17+
use Codewithkyrian\Transformers\Models\Auto\AutoModelForZeroShotObjectDetection;
1718
use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel;
1819
use Codewithkyrian\Transformers\PretrainedTokenizers\AutoTokenizer;
1920
use Codewithkyrian\Transformers\PretrainedTokenizers\PretrainedTokenizer;
@@ -43,6 +44,7 @@ enum Task: string
4344
case ZeroShotImageClassification = 'zero-shot-image-classification';
4445

4546
case ObjectDetection = 'object-detection';
47+
case ZeroShotObjectDetection = 'zero-shot-object-detection';
4648

4749

4850
public function pipeline(PretrainedModel $model, ?PretrainedTokenizer $tokenizer, ?Processor $processor): Pipeline
@@ -78,6 +80,8 @@ public function pipeline(PretrainedModel $model, ?PretrainedTokenizer $tokenizer
7880
self::ZeroShotImageClassification => new ZeroShotImageClassificationPipeline($this, $model, $tokenizer, $processor),
7981

8082
self::ObjectDetection => new ObjectDetectionPipeline($this, $model, $tokenizer, $processor),
83+
84+
self::ZeroShotObjectDetection => new ZeroShotObjectDetectionPipeline($this, $model, $tokenizer, $processor),
8185
};
8286
}
8387

@@ -112,6 +116,8 @@ public function defaultModelName(): string
112116
self::ZeroShotImageClassification => 'Xenova/clip-vit-base-patch32', // Original: 'openai/clip-vit-base-patch32'
113117

114118
self::ObjectDetection => 'Xenova/detr-resnet-50', // Original: 'facebook/detr-resnet-50',
119+
120+
self::ZeroShotObjectDetection => 'Xenova/owlvit-base-patch32', // Original: 'google/owlvit-base-patch32',
115121
};
116122
}
117123

@@ -153,6 +159,8 @@ public function autoModel(
153159
self::ZeroShotImageClassification => AutoModel::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $output),
154160

155161
self::ObjectDetection => AutoModelForObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $output),
162+
163+
self::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $output),
156164
};
157165
}
158166

@@ -185,7 +193,8 @@ public function autoTokenizer(
185193
self::TokenClassification,
186194
self::Ner,
187195
self::ImageToText,
188-
self::ZeroShotImageClassification => AutoTokenizer::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, null, $output),
196+
self::ZeroShotImageClassification,
197+
self::ZeroShotObjectDetection => AutoTokenizer::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, null, $output),
189198
};
190199
}
191200

@@ -202,7 +211,8 @@ public function autoProcessor(
202211
self::ImageToText,
203212
self::ImageClassification,
204213
self::ZeroShotImageClassification,
205-
self::ObjectDetection => AutoProcessor::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $output),
214+
self::ObjectDetection,
215+
self::ZeroShotObjectDetection => AutoProcessor::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $output),
206216

207217

208218
self::SentimentAnalysis,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Pipelines;
7+
8+
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput;
9+
use Codewithkyrian\Transformers\Utils\Tensor;
10+
use function Codewithkyrian\Transformers\Utils\getBoundingBox;
11+
use function Codewithkyrian\Transformers\Utils\prepareImages;
12+
13+
/**
 * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
 * objects when you provide an image and a set of candidate labels.
 *
 * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
 * ```php
 * $detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
 * $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
 * $candidateLabels = ['human face', 'rocket', 'helmet', 'american flag'];
 * $output = $detector($url, $candidateLabels);
 * // [
 * //     [
 * //         'score' => 0.24392342567443848,
 * //         'label' => 'human face',
 * //         'box' => ['xmin' => 180, 'ymin' => 67, 'xmax' => 274, 'ymax' => 175],
 * //     ],
 * //     ...
 * // ]
 * ```
 *
 * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
 * ```php
 * $detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
 * $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
 * $candidateLabels = ['hat', 'book', 'sunglasses', 'camera'];
 * $output = $detector($url, $candidateLabels, topK: 4, threshold: 0.05);
 * // [
 * //     [
 * //         'score' => 0.1606510728597641,
 * //         'label' => 'sunglasses',
 * //         'box' => ['xmin' => 347, 'ymin' => 229, 'xmax' => 429, 'ymax' => 264],
 * //     ],
 * //     ...
 * // ]
 * ```
 */
class ZeroShotObjectDetectionPipeline extends Pipeline
{
    /**
     * Detects the given candidate labels in one image or in a batch of images.
     *
     * @param array|string $inputs A single image reference (the examples use a path/URL string),
     *                             or an array of them. An array input yields one result list per image.
     * @param mixed ...$args The candidate labels, given either as the first positional argument
     *                       or as the named argument `candidateLabels`, followed by optional
     *                       named options:
     *                       - `threshold` (default `0.1`): score threshold handed to post-processing.
     *                       - `topK` (default `null`): when set, keep only that many of the
     *                         highest-scoring detections per image.
     *                       - `percentage` (default `false`): when true, no image size is handed
     *                         to post-processing and `getBoundingBox()` receives `false` as its
     *                         second argument (presumably yielding relative coordinates — confirm
     *                         against `getBoundingBox()`).
     *
     * @return array For a string input: a list of `['score' => ..., 'label' => ..., 'box' => ...]`
     *               entries sorted by descending score. For an array input: one such list per image.
     *
     * @throws \InvalidArgumentException When no candidate labels are supplied, or they are
     *                                   neither a string nor an array.
     */
    public function __invoke(array|string $inputs, ...$args): array
    {
        // Labels may arrive positionally or as `candidateLabels:`; a bare $args[0]
        // lookup would be undefined for the named form.
        $candidateLabels = $args['candidateLabels'] ?? $args[0] ?? null;

        if (is_string($candidateLabels)) {
            // A lone label must become a list; indexing a string below would
            // otherwise pick out single characters.
            $candidateLabels = [$candidateLabels];
        } elseif (is_array($candidateLabels)) {
            // Class indices from post-processing are used as positional offsets,
            // so make sure the labels are a 0-based list.
            $candidateLabels = array_values($candidateLabels);
        } else {
            throw new \InvalidArgumentException(
                'Candidate labels must be provided as the first argument (or as `candidateLabels`) and be a string or an array of strings.'
            );
        }

        $threshold = $args['threshold'] ?? 0.1;
        $topK = $args['topK'] ?? null;
        $percentage = $args['percentage'] ?? false;

        $isBatched = is_array($inputs);

        $preparedImages = prepareImages($inputs);

        // Run tokenization once; the same text inputs are reused for every image.
        $textInputs = $this->tokenizer->tokenize($candidateLabels, padding: true, truncation: true);

        // Run processor
        $modelInputs = ($this->processor)($preparedImages);

        $toReturn = [];
        foreach ($preparedImages as $i => $image) {
            // Without a target size, post-processing is not given the original image dimensions.
            $imageSize = $percentage ? null : [[$image->height(), $image->width()]];

            // Take this image's pixel values and add a leading dimension of size 1.
            $pixelValues = Tensor::fromNdArray($modelInputs['pixel_values'][$i])->unsqueeze(0);

            // Run model with both text and pixel inputs
            /** @var ObjectDetectionOutput $output */
            $output = $this->model->__invoke(array_merge($textInputs, ['pixel_values' => $pixelValues]));

            // Perform post-processing; only one image was passed, so take the first entry.
            $processed = $this->processor->featureExtractor->postProcessObjectDetection($output, $threshold, $imageSize, true)[0];

            $result = [];

            foreach ($processed['boxes'] as $j => $box) {
                $result[] = [
                    'score' => $processed['scores'][$j],
                    'label' => $candidateLabels[$processed['classes'][$j]],
                    'box' => getBoundingBox($box, !$percentage),
                ];
            }

            // Sort by score, highest first
            usort($result, static fn($a, $b) => $b['score'] <=> $a['score']);

            if ($topK !== null) {
                $result = array_slice($result, 0, $topK);
            }

            $toReturn[] = $result;
        }

        return $isBatched ? $toReturn : $toReturn[0];
    }
}

src/Processors/OwlViTProcessor.php

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Processors;
7+
8+
/**
 * Processor class for OwlViT.
 *
 * Declares no members of its own: everything is inherited from Processor.
 */
class OwlViTProcessor extends Processor
{
}

0 commit comments

Comments
 (0)