diff --git a/examples/app/components/serve/gradio/app.py b/examples/app/components/serve/gradio/app.py index b0f81bf26c6f9..eef7000305b85 100644 --- a/examples/app/components/serve/gradio/app.py +++ b/examples/app/components/serve/gradio/app.py @@ -14,7 +14,7 @@ class AnimeGANv2UI(ServeGradio): inputs = gr.inputs.Image(type="pil") outputs = gr.outputs.Image(type="pil") - elon = "https://upload.wikimedia.org/wikipedia/commons/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg" + elon = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg/330px-Elon_Musk_Royal_Society_%28crop2%29.jpg" img = Image.open(requests.get(elon, stream=True).raw) img.save("elon.jpg") examples = [["elon.jpg"]] diff --git a/src/lightning/app/CHANGELOG.md b/src/lightning/app/CHANGELOG.md index 6018fbf69935b..4ea16d1b3bcd0 100644 --- a/src/lightning/app/CHANGELOG.md +++ b/src/lightning/app/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Allow customize `gradio` components with lightning colors ([#17054](https://github.com/Lightning-AI/lightning/pull/17054)) ### Changed diff --git a/src/lightning/app/components/serve/gradio_server.py b/src/lightning/app/components/serve/gradio_server.py index 54ad83d0bff1f..dc9f2d7847415 100644 --- a/src/lightning/app/components/serve/gradio_server.py +++ b/src/lightning/app/components/serve/gradio_server.py @@ -24,6 +24,12 @@ import gradio else: gradio = ModuleType("gradio") + gradio.themes = ModuleType("gradio.themes") + + class __DummyBase: + pass + + gradio.themes.Base = __DummyBase class ServeGradio(LightningWork, abc.ABC): @@ -49,11 +55,12 @@ class ServeGradio(LightningWork, abc.ABC): _start_method = "spawn" - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, theme: Optional[gradio.themes.Base] = None, **kwargs: Any): requires("gradio")(super().__init__(*args, **kwargs)) assert self.inputs assert self.outputs self._model = None + self._theme = theme or ServeGradio.__get_lightning_gradio_theme() self.ready = False @@ -85,6 +92,7 @@ def run(self, *args: Any, **kwargs: Any): examples=self.examples, title=self.title, description=self.description, + theme=self._theme, ).launch( server_name=self.host, server_port=self.port, @@ -93,3 +101,98 @@ def run(self, *args: Any, **kwargs: Any): def configure_layout(self) -> str: return self.url + + @staticmethod + def __get_lightning_gradio_theme(): + return gradio.themes.Default( + primary_hue=gradio.themes.Color( + "#ffffff", + "#e9d5ff", + "#d8b4fe", + "#c084fc", + "#fcfcfc", + "#a855f7", + "#9333ea", + "#8823e1", + "#6b21a8", + "#2c2730", + "#1c1c1c", + ), + secondary_hue=gradio.themes.Color( + "#c3a1e8", + "#e9d5ff", + "#d3bbec", + "#c795f9", + "#9174af", + "#a855f7", + "#9333ea", + "#6700c2", + "#000000", + "#991ef1", + "#33243d", + ), + neutral_hue=gradio.themes.Color( + "#ede9fe", + "#ddd6fe", + "#c4b5fd", + "#a78bfa", + "#fafafa", + "#8b5cf6", + "#7c3aed", + "#6d28d9", + "#6130b0", + "#8a4ce6", + "#3b3348", + ), + ).set( + body_background_fill="*primary_50", + body_background_fill_dark="*primary_950", + body_text_color_dark="*primary_100", + body_text_size="*text_sm", + body_text_color_subdued_dark="*primary_100", + background_fill_primary="*primary_50", + background_fill_primary_dark="*primary_950", + background_fill_secondary="*primary_50", + background_fill_secondary_dark="*primary_950", + border_color_accent="*primary_400", + border_color_accent_dark="*primary_900", + border_color_primary="*primary_600", + border_color_primary_dark="*primary_800", + color_accent="*primary_400", + color_accent_soft="*primary_300", + color_accent_soft_dark="*primary_700", + link_text_color="*primary_500", + link_text_color_dark="*primary_50", + link_text_color_active="*secondary_800", + link_text_color_active_dark="*primary_500", + link_text_color_hover="*primary_400", + link_text_color_hover_dark="*primary_400", + link_text_color_visited="*primary_500", + link_text_color_visited_dark="*secondary_100", + block_background_fill="*primary_50", + block_background_fill_dark="*primary_900", + block_border_color_dark="*primary_800", + checkbox_background_color="*primary_50", + checkbox_background_color_dark="*primary_50", + checkbox_background_color_focus="*primary_100", + checkbox_background_color_focus_dark="*primary_100", + checkbox_background_color_hover="*primary_400", + checkbox_background_color_hover_dark="*primary_500", + checkbox_background_color_selected="*primary_300", + checkbox_background_color_selected_dark="*primary_500", + checkbox_border_color_dark="*primary_200", + checkbox_border_radius="*radius_md", + input_background_fill="*primary_50", + input_background_fill_dark="*primary_900", + input_radius="*radius_xxl", + slider_color="*primary_600", + slider_color_dark="*primary_700", + button_large_radius="*radius_xxl", + button_large_text_size="*text_md", + button_small_radius="*radius_xxl", + button_primary_background_fill_dark="*primary_800", + button_primary_background_fill_hover_dark="*primary_700", + button_primary_border_color_dark="*primary_800", + button_secondary_background_fill="*neutral_200", + button_secondary_background_fill_dark="*primary_600", + ) diff --git a/src/lightning/version.info b/src/lightning/version.info new file mode 100644 index 0000000000000..c175dec033b75 --- /dev/null +++ b/src/lightning/version.info @@ -0,0 +1 @@ +2.1.0dev diff --git a/tests/integrations_app/public/test_gradio.py b/tests/integrations_app/public/test_gradio.py index 1ed4b66e90ad3..7eaf722564543 100644 --- a/tests/integrations_app/public/test_gradio.py +++ b/tests/integrations_app/public/test_gradio.py @@ -28,5 +28,5 @@ def predict(self, *args, **kwargs): assert comp.model == "model" assert comp.predict() == "prediction" gradio_mock.Interface.assert_called_once_with( - fn=ANY, inputs=ANY, outputs=ANY, examples=ANY, title=None, description=None + fn=ANY, inputs=ANY, outputs=ANY, examples=ANY, title=None, description=None, theme=ANY )