-
Notifications
You must be signed in to change notification settings - Fork 335
Updating FasterRCNN to use Task API #2012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ariG23498
wants to merge
50
commits into
keras-team:master
Choose a base branch
from
ariG23498:aritra/port-rcnn
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 2 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
9bc256f
chore: initial commit
ariG23498 a8ad7c4
review comments
ariG23498 d523a32
Merge branch 'master' into aritra/port-rcnn
ariG23498 ed3337c
chore: train test step modification
ariG23498 301bb1d
Merge branch 'master' into aritra/port-rcnn
ariG23498 005f70d
review nits
ariG23498 da5a01e
chore: adding test
ariG23498 ea88f2c
Merge branch 'master' into aritra/port-rcnn
ariG23498 5c7048f
Merge branch 'master' into aritra/port-rcnn
ariG23498 ac005b8
chore: reformat compute loss
ariG23498 613e29f
chore: faster rcnn call and predict work
ariG23498 dcb648a
resolved conflicts
ariG23498 5bf2bc9
chore: porting roi align to keras core
ariG23498 7d6ef6f
chore: port roi sampler to keras core
ariG23498 f1e3e17
chore: port rpn label encoder to keras core
ariG23498 6478cbf
chore: adding tests and fix lint
ariG23498 7741edc
fix: lint
ariG23498 13a26e6
chore: adding copyright to faster rcnn presets script
ariG23498 0bc4cfa
Merge branch 'master' into aritra/port-rcnn
ariG23498 3b42ecc
chore: removing tf imports
ariG23498 be9178b
fix imports
ariG23498 c3b0cfa
Merge branch 'master' into aritra/port-rcnn
ariG23498 54fd49c
Merge branch 'master' into aritra/port-rcnn
ariG23498 e59d2b4
fix: style
ariG23498 001162c
chore: making the model functional in init
ariG23498 4889192
Merge branch 'master' into aritra/port-rcnn
ariG23498 4da5ff1
Merge branch 'master' into aritra/port-rcnn
ariG23498 6a51562
Merge branch 'master' into aritra/port-rcnn
ariG23498 36da548
Merge branch 'master' into aritra/port-rcnn
ariG23498 711c031
Merge branch 'master' into aritra/port-rcnn
ariG23498 9aab0e9
chore: adding static image shapes to backbone in tests
ariG23498 49815d1
fix: parameterised input shape in test
ariG23498 6061f01
fix: reshape
ariG23498 ef279a9
fix: format and output dict
ariG23498 134f897
chore: masking sample weights for box labels -1
ariG23498 e190e1b
chore: fixing sample weights and decode predictions
ariG23498 70f205c
Merge branch 'master' into aritra/port-rcnn
ariG23498 821b7aa
chore: porting roi gen to keras 3 ops
ariG23498 324f7fc
Merge branch 'master' into aritra/port-rcnn
ariG23498 9227255
chore: port roi gen to keras 3
ariG23498 345764f
chore: removing asserts for keras 3
ariG23498 3a714e7
Merge branch 'master' into aritra/port-rcnn
ariG23498 9e7eea0
chore: adding faster rcnn to kokoro build script
ariG23498 af47e3f
chore: changing a bunch of things and keeping it commited for reference
ariG23498 fd20746
Merge branch 'master' into aritra/port-rcnn
ariG23498 2f5c0a2
chore: update roi align
ariG23498 9c85dfc
chore: adding init and compute loss
ariG23498 e26a8ef
chore: format
ariG23498 5a1f5a7
chore: demo.py
ariG23498 7d873f6
Merge branch 'master' into aritra/port-rcnn
ariG23498 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
import tensorflow as tf | ||
from absl.testing import parameterized | ||
|
||
from keras_cv.backend import keras | ||
from keras_cv.models import ResNet18V2Backbone | ||
from keras_cv.models.object_detection.__test_utils__ import ( | ||
_create_bounding_box_dataset, | ||
) | ||
from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN | ||
from keras_cv.tests.test_case import TestCase | ||
|
||
|
||
class FasterRCNNTest(TestCase): | ||
# TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples | ||
# of 128, perhaps by adding a flag to the anchor generator for whether to | ||
# include anchors centered outside of the image. (RetinaNet does use those, | ||
# while FasterRCNN doesn't). For more context on why this is the case, see | ||
# https://github.com/keras-team/keras-cv/pull/1882 | ||
@parameterized.parameters( | ||
((2, 640, 384, 3),), | ||
((2, 512, 512, 3),), | ||
((2, 128, 128, 3),), | ||
) | ||
def test_faster_rcnn_infer(self, batch_shape): | ||
model = FasterRCNN( | ||
num_classes=80, | ||
bounding_box_format="xyxy", | ||
backbone=ResNet18V2Backbone(), | ||
) | ||
images = tf.random.normal(batch_shape) | ||
outputs = model(images, training=False) | ||
# 1000 proposals in inference | ||
self.assertAllEqual([2, 1000, 81], outputs[1].shape) | ||
self.assertAllEqual([2, 1000, 4], outputs[0].shape) | ||
|
||
@parameterized.parameters( | ||
((2, 640, 384, 3),), | ||
((2, 512, 512, 3),), | ||
((2, 128, 128, 3),), | ||
) | ||
def test_faster_rcnn_train(self, batch_shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lets add a model.fit() to test training |
||
model = FasterRCNN( | ||
num_classes=80, | ||
bounding_box_format="xyxy", | ||
backbone=ResNet18V2Backbone(), | ||
) | ||
images = tf.random.normal(batch_shape) | ||
outputs = model(images, training=True) | ||
self.assertAllEqual([2, 1000, 81], outputs[1].shape) | ||
self.assertAllEqual([2, 1000, 4], outputs[0].shape) | ||
|
||
def test_invalid_compile(self): | ||
model = FasterRCNN( | ||
num_classes=80, | ||
bounding_box_format="yxyx", | ||
backbone=ResNet18V2Backbone(), | ||
) | ||
with self.assertRaisesRegex(ValueError, "only accepts"): | ||
model.compile(rpn_box_loss="binary_crossentropy") | ||
with self.assertRaisesRegex(ValueError, "only accepts"): | ||
model.compile( | ||
rpn_classification_loss=keras.losses.BinaryCrossentropy( | ||
from_logits=False | ||
) | ||
) | ||
|
||
@pytest.mark.large # Fit is slow, so mark these large. | ||
def test_faster_rcnn_with_dictionary_input_format(self): | ||
faster_rcnn = FasterRCNN( | ||
num_classes=20, | ||
bounding_box_format="xywh", | ||
backbone=ResNet18V2Backbone(), | ||
) | ||
|
||
images, boxes = _create_bounding_box_dataset("xywh") | ||
dataset = tf.data.Dataset.from_tensor_slices( | ||
{"images": images, "bounding_boxes": boxes} | ||
).batch(5, drop_remainder=True) | ||
|
||
faster_rcnn.compile( | ||
optimizer=keras.optimizers.Adam(), | ||
box_loss="Huber", | ||
classification_loss="SparseCategoricalCrossentropy", | ||
rpn_box_loss="Huber", | ||
rpn_classification_loss="BinaryCrossentropy", | ||
) | ||
|
||
faster_rcnn.fit(dataset, epochs=1) | ||
faster_rcnn.evaluate(dataset) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.