Skip to content

Commit 0ff21df

Browse files
add select feature
1 parent 9c1d1c7 commit 0ff21df

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

App.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def load_settings():
1616
st.markdown('## Auto Tagging for Fashion Retail Using Multi-label Image Classification')
1717

1818
# Sidebar
19+
global sidebar
1920
sidebar = st.sidebar
2021

2122
with open("assets/test.jpg", "rb") as file:
@@ -35,13 +36,15 @@ def main():
3536
if uploaded_file is not None:
3637
image = Image.open(uploaded_file)
3738

39+
num_of_tags = sidebar.radio("Number of Tags", (3, 4, 5), index=2)
40+
3841
col1, col2 = st.columns(2)
3942
with col1:
4043
st.markdown('<p style="text-align: center;">Your Image</p><hr>', unsafe_allow_html=True)
4144
st.image(image, caption="Fashion Image")
4245
with col2:
4346
st.markdown('<p style="text-align: center;">Image Tags</p><hr>', unsafe_allow_html=True)
44-
tags = predictor.predict(image=image)
47+
tags = predictor.predict(image=image, num_of_tags=num_of_tags)
4548
tags = st_tags(label="Tags:", text="Add more", value=tags)
4649

4750
if __name__=="__main__":

utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ def load_model(self):
5555
model.load_state_dict(torch.load(f="./model/parameters.pth", map_location=torch.device('cpu')))
5656
self.model = model
5757

58-
def predict(self, image):
58+
def predict(self, image, num_of_tags=5):
5959
image_transformed = self.transform(image)
6060
batch_image = torch.unsqueeze(image_transformed, 0)
6161
self.model.eval()
6262

6363
# Inference
6464
output = self.model(batch_image).detach().numpy()
6565
preds = output[0]
66-
top5 = np.sort(preds)[::-1][min(4, len(preds)-1)]
67-
preds[preds < top5] = 0
68-
preds[preds >= top5] = 1
66+
top_n_tags = np.sort(preds)[::-1][min(num_of_tags - 1, len(preds)-1)]
67+
preds[preds < top_n_tags] = 0
68+
preds[preds >= top_n_tags] = 1
6969
tags = self.binarizer.inverse_transform(np.array([preds]))[0][:]
7070
return tags

0 commit comments

Comments
 (0)