Skip to content

Commit 378f4f3

Browse files
authored
Add native "Stealth" infotext support (#279)
1 parent 9277598 commit 378f4f3

File tree

6 files changed

+286
-36
lines changed

6 files changed

+286
-36
lines changed

modules/gradio_extensons.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import gradio as gr
2+
from gradio import processing_utils
3+
import warnings
4+
import PIL.ImageOps
25

36
from modules import scripts, ui_tempdir, patches
47

@@ -74,10 +77,75 @@ def Blocks_get_config_file(self, *args, **kwargs):
7477
return config
7578

7679

80+
def Image_upload_handler(self, x):
    """Return the uploaded image converted to RGB.

    If ``x`` is a dict that has an ``'image'`` entry, that entry is converted
    with ``.convert('RGB')`` and returned by itself. Any other value is handed
    back untouched.
    """
    if not isinstance(x, dict) or 'image' not in x:
        return x
    return x['image'].convert('RGB')
86+
87+
def Image_custom_preprocess(self, x):
    """Decode a base64 image payload, keeping the sketch mask alongside it.

    ``None`` and any payload that is not a string (after unpacking a sketch
    dict) are returned as-is. A string payload is decoded, converted to
    ``self.image_mode``, optionally resized/cropped to ``self.shape``,
    inverted and mirrored according to the component's settings, and finally
    passed through ``self._format_image``. For the sketch tool on an
    upload/webcam source the result is a dict with ``"image"`` and ``"mask"``
    entries; otherwise it is the formatted image alone.
    """
    if x is None:
        return x

    sketching = self.tool == "sketch" and self.source in ["upload", "webcam"]

    mask = ""
    if sketching and isinstance(x, dict):
        payload = x
        x = payload["image"]
        mask = payload["mask"]

    if not isinstance(x, str):
        return x

    decoded = processing_utils.decode_base64_to_image(x)

    # Mode conversion can emit warnings; silence them for this call only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        decoded = decoded.convert(self.image_mode)

    if self.shape is not None:
        decoded = processing_utils.resize_and_crop(decoded, self.shape)

    if self.invert_colors:
        decoded = PIL.ImageOps.invert(decoded)

    wants_mirror = (
        self.source == "webcam"
        and self.mirror_webcam is True
        and self.tool != "color-sketch"
    )
    if wants_mirror:
        decoded = PIL.ImageOps.mirror(decoded)

    if not sketching:
        return self._format_image(decoded)

    decoded_mask = None if mask is None else processing_utils.decode_base64_to_image(mask)
    return {
        "image": self._format_image(decoded),
        "mask": self._format_image(decoded_mask),
    }
127+
128+
def Image_init_extension(self, *args, **kwargs):
    """Run the original ``Image.__init__`` and hook up the inpaint-mask component.

    Every component is initialised through ``original_Image_init``. Only the
    component whose ``elem_id`` is ``'img2maskimg'`` additionally gets an
    upload listener bound to ``Image_upload_handler`` and has its
    ``preprocess`` replaced by ``Image_custom_preprocess``.

    Returns whatever ``original_Image_init`` returned.
    """
    res = original_Image_init(self, *args, **kwargs)

    # Only apply to inpaint with mask component for now
    if getattr(self, 'elem_id', None) != 'img2maskimg':
        return res

    bound_upload = Image_upload_handler.__get__(self, gr.Image)
    self.upload(fn=bound_upload, inputs=self, outputs=self)
    self.preprocess = Image_custom_preprocess.__get__(self, gr.Image)

    return res
142+
143+
77144
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
78145
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
79146
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
80147
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
148+
original_Image_init = patches.patch(__name__, obj=gr.components.Image, field="__init__", replacement=Image_init_extension)
81149

82150

83151
ui_tempdir.install_ui_tempdir_override()

modules/images.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import json
2222
import hashlib
2323

24-
from modules import sd_samplers, shared, script_callbacks, errors
24+
from modules import sd_samplers, shared, script_callbacks, errors, stealth_infotext
2525
from modules.paths_internal import roboto_ttf_file
2626
from modules.shared import opts
2727

@@ -276,6 +276,9 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None, force_RGBA=
276276
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
277277
"""
278278

279+
if not force_RGBA and im.mode == 'RGBA':
280+
im = im.convert('RGB')
281+
279282
upscaler_name = upscaler_name or opts.upscaler_for_img2img
280283

281284
def resize(im, w, h):
@@ -715,6 +718,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
715718
pnginfo[pnginfo_section_name] = info
716719

717720
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
721+
if opts.enable_pnginfo:
722+
stealth_infotext.add_stealth_pnginfo(params)
723+
718724
script_callbacks.before_image_saved_callback(params)
719725

720726
image = params.image
@@ -787,44 +793,53 @@ def _atomically_save_image(image_to_save, filename_without_extension, extension)
787793

788794

789795
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
790-
items = (image.info or {}).copy()
796+
"""Read generation info from an image, checking standard metadata first, then stealth info if needed."""
791797

792-
geninfo = items.pop('parameters', None)
798+
def read_standard():
799+
items = (image.info or {}).copy()
793800

794-
if "exif" in items:
795-
exif_data = items["exif"]
796-
try:
797-
exif = piexif.load(exif_data)
798-
except OSError:
799-
# memory / exif was not valid so piexif tried to read from a file
800-
exif = None
801-
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
802-
try:
803-
exif_comment = piexif.helper.UserComment.load(exif_comment)
804-
except ValueError:
805-
exif_comment = exif_comment.decode('utf8', errors="ignore")
806-
807-
if exif_comment:
808-
geninfo = exif_comment
809-
elif "comment" in items: # for gif
810-
if isinstance(items["comment"], bytes):
811-
geninfo = items["comment"].decode('utf8', errors="ignore")
812-
else:
813-
geninfo = items["comment"]
801+
geninfo = items.pop('parameters', None)
802+
803+
if "exif" in items:
804+
exif_data = items["exif"]
805+
try:
806+
exif = piexif.load(exif_data)
807+
except OSError:
808+
# memory / exif was not valid so piexif tried to read from a file
809+
exif = None
810+
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
811+
try:
812+
exif_comment = piexif.helper.UserComment.load(exif_comment)
813+
except ValueError:
814+
exif_comment = exif_comment.decode('utf8', errors="ignore")
815+
816+
if exif_comment:
817+
geninfo = exif_comment
818+
elif "comment" in items: # for gif
819+
if isinstance(items["comment"], bytes):
820+
geninfo = items["comment"].decode('utf8', errors="ignore")
821+
else:
822+
geninfo = items["comment"]
814823

815-
for field in IGNORED_INFO_KEYS:
816-
items.pop(field, None)
824+
for field in IGNORED_INFO_KEYS:
825+
items.pop(field, None)
817826

818-
if items.get("Software", None) == "NovelAI":
819-
try:
820-
json_info = json.loads(items["Comment"])
821-
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
822-
823-
geninfo = f"""{items["Description"]}
824-
Negative prompt: {json_info["uc"]}
825-
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
826-
except Exception:
827-
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
827+
if items.get("Software", None) == "NovelAI":
828+
try:
829+
json_info = json.loads(items["Comment"])
830+
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
831+
832+
geninfo = f"""{items["Description"]}
833+
Negative prompt: {json_info["uc"]}
834+
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
835+
except Exception:
836+
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
837+
838+
return geninfo, items
839+
840+
geninfo, items = read_standard()
841+
if geninfo is None:
842+
geninfo = stealth_infotext.read_info_from_image_stealth(image)
828843

829844
return geninfo, items
830845

modules/infotext_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,12 @@ def connect_paste_params_buttons():
188188
def send_image_and_dimensions(x):
189189
if isinstance(x, Image.Image):
190190
img = x
191+
if img.mode == 'RGBA':
192+
img = img.convert('RGB')
191193
else:
192194
img = image_from_url_text(x)
195+
if img is not None and img.mode == 'RGBA':
196+
img = img.convert('RGB')
193197

194198
if shared.opts.send_size and isinstance(img, Image.Image):
195199
w = img.width

modules/shared_options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@
431431
It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.
432432
"""),
433433
"enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"),
434+
"stealth_pnginfo_option": OptionInfo("Alpha", "Stealth infotext mode", gr.Radio, {"choices": ["Alpha", "RGB", "None"]}).info("Ignored if infotext is disabled"),
434435
"save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"),
435436

436437
"add_model_name_to_info": OptionInfo(True, "Add model name to infotext"),

modules/stealth_infotext.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import gzip
2+
3+
from modules.script_callbacks import ImageSaveParams
4+
from modules import shared
5+
6+
7+
def add_stealth_pnginfo(params: ImageSaveParams):
    """Embed the infotext into the image pixels when stealth mode is enabled.

    Nothing is written when the ``stealth_pnginfo_option`` setting is empty or
    ``'None'``, when the target filename does not end in ``.png``, or when
    ``params.pnginfo`` is missing or has no ``'parameters'`` entry. Otherwise
    the work is delegated to ``add_data`` with compression turned on.
    """
    option = shared.opts.data.get('stealth_pnginfo_option', 'Alpha')
    if not option or option == 'None':
        return

    is_png = params.filename.endswith('.png')
    if is_png and params.pnginfo is not None and 'parameters' in params.pnginfo:
        add_data(params, str(option), True)
16+
17+
def prepare_data(params, mode='Alpha', compressed=True):
    """Serialise infotext into the stealth bit string.

    The result is a string of ``'0'``/``'1'`` characters laid out as:

    1. a 15-character signature, 8 bits per byte: ``stealth_png*`` for
       alpha-channel mode or ``stealth_rgb*`` for RGB mode, ending in
       ``info`` (plain UTF-8) or ``comp`` (gzip-compressed);
    2. the payload length in bits, as a zero-padded 32-digit binary number;
    3. the payload bits.

    Args:
        params: The infotext to embed.
        mode: ``'Alpha'`` or ``'RGB'``. The comparison is case-insensitive;
            the previous exact match against lower-case ``'alpha'`` never
            succeeded for the ``'Alpha'`` value used by the settings, so
            alpha-mode images were tagged with the RGB signature.
        compressed: gzip-compress the payload when true.

    Returns:
        The concatenated signature, length and payload bit string.
    """
    channel_tag = 'png' if str(mode).lower() == 'alpha' else 'rgb'
    kind_tag = 'comp' if compressed else 'info'
    signature = f"stealth_{channel_tag}{kind_tag}"

    payload = params.encode('utf-8')
    if compressed:
        payload = gzip.compress(payload)

    binary_signature = ''.join(format(byte, '08b') for byte in signature.encode('utf-8'))
    binary_param = ''.join(format(byte, '08b') for byte in payload)
    binary_param_len = format(len(binary_param), '032b')
    return binary_signature + binary_param_len + binary_param
24+
25+
def add_data(params, mode='Alpha', compressed=True):
    """Write the stealth bit string into the least-significant bits of the image.

    Pixels are visited column by column (``x`` outer, ``y`` inner). In alpha
    mode the image first gets ``putalpha(255)`` and one bit is stored per
    pixel in the alpha channel; in any other mode three bits are stored per
    pixel in the red, green and blue channels. Writing stops as soon as all
    bits are placed, or silently when the image runs out of pixels.

    Args:
        params: Object exposing ``image`` and ``pnginfo['parameters']``.
        mode: ``'Alpha'`` or ``'RGB'`` (compared case-insensitively).
        compressed: Forwarded to ``prepare_data``.
    """
    binary_data = prepare_data(params.pnginfo['parameters'], mode, compressed)
    use_alpha = str(mode).lower() == 'alpha'
    if use_alpha:
        params.image.putalpha(255)

    width, height = params.image.size
    pixels = params.image.load()
    total_bits = len(binary_data)
    index = 0

    for x in range(width):
        for y in range(height):
            if index >= total_bits:
                return

            if use_alpha:
                r, g, b, a = pixels[x, y]
                pixels[x, y] = (r, g, b, (a & ~1) | int(binary_data[index]))
                index += 1
                continue

            # Keep any channel beyond RGB (e.g. alpha) untouched, so RGB mode
            # no longer fails to unpack four-channel pixels.
            r, g, b, *extra = pixels[x, y]
            r = (r & ~1) | int(binary_data[index])
            if index + 1 < total_bits:
                g = (g & ~1) | int(binary_data[index + 1])
            if index + 2 < total_bits:
                b = (b & ~1) | int(binary_data[index + 2])
            pixels[x, y] = (r, g, b, *extra)
            index += 3
56+
57+
def read_info_from_image_stealth(image):
    """Recover infotext hidden in the least-significant bits of an image.

    The bit stream is read column by column (``x`` outer, ``y`` inner) and is
    expected to hold a 15-character signature, a 32-bit payload length in
    bits, then the payload. The RGB channels (three bits per pixel,
    signatures ``stealth_rgbinfo`` / ``stealth_rgbcomp``) are probed first;
    for RGBA images the alpha channel (one bit per pixel, signatures
    ``stealth_pnginfo`` / ``stealth_pngcomp``) is probed next. A ``comp``
    signature means the payload is gzip-compressed.

    Args:
        image: An object exposing ``mode``, ``size`` and ``load()`` like a
            PIL image.

    Returns:
        The decoded infotext, or ``None`` when the image mode is not RGB/RGBA,
        no signature is found, the data is truncated, or decoding fails.
        (Previously those cases raised ``UnboundLocalError`` because the
        result variable was never initialised.)
    """
    if image.mode not in ('RGB', 'RGBA'):
        return None

    # (channel indices, {signature: payload is gzip-compressed})
    candidates = [((0, 1, 2), {'stealth_rgbinfo': False, 'stealth_rgbcomp': True})]
    if image.mode == 'RGBA':
        candidates.append(((3,), {'stealth_pnginfo': False, 'stealth_pngcomp': True}))

    for channels, signatures in candidates:
        bits = _iter_stealth_bits(image, channels)
        signature_bits = _take_stealth_bits(bits, _STEALTH_SIGNATURE_BITS)
        if signature_bits is None:
            continue
        signature = _stealth_bits_to_bytes(signature_bits).decode('utf-8', errors='ignore')
        if signature in signatures:
            # A confirmed signature settles the channel; do not fall through.
            return _decode_stealth_payload(bits, signatures[signature])

    return None


# Every signature has the same length as 'stealth_pnginfo'.
_STEALTH_SIGNATURE_BITS = len('stealth_pnginfo') * 8


def _iter_stealth_bits(image, channels):
    """Yield the lowest bit of each listed channel, walking pixels column by column."""
    width, height = image.size
    pixels = image.load()
    for x in range(width):
        for y in range(height):
            pixel = pixels[x, y]
            for channel in channels:
                yield pixel[channel] & 1


def _take_stealth_bits(bits, count):
    """Consume exactly ``count`` bits as a '0'/'1' string; ``None`` if the stream ends early."""
    if count <= 0:
        return ''
    taken = []
    for bit in bits:
        taken.append('1' if bit else '0')
        if len(taken) == count:
            return ''.join(taken)
    return None


def _stealth_bits_to_bytes(bit_string):
    """Pack a '0'/'1' string into bytes, eight bits at a time."""
    return bytes(int(bit_string[i:i + 8], 2) for i in range(0, len(bit_string), 8))


def _decode_stealth_payload(bits, compressed):
    """Read the 32-bit length and the payload that follows; ``None`` on any failure."""
    length_bits = _take_stealth_bits(bits, 32)
    if length_bits is None:
        return None

    payload_bits = _take_stealth_bits(bits, int(length_bits, 2))
    if not payload_bits:
        return None

    raw = _stealth_bits_to_bytes(payload_bits)
    try:
        if compressed:
            return gzip.decompress(raw).decode('utf-8')
        return raw.decode('utf-8', errors='ignore')
    except Exception:
        # Reading is best-effort: corrupt or coincidental data is treated as
        # "no infotext" rather than an error.
        return None

modules/ui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def select_img2img_tab(tab):
881881
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
882882
with ResizeHandleRow(equal_height=False):
883883
with gr.Column(variant='panel'):
884-
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
884+
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil", image_mode="RGBA")
885885

886886
with gr.Column(variant='panel'):
887887
html = gr.HTML()

0 commit comments

Comments
 (0)