Skip to content

Commit 1ba444f

Browse files
Add CSV export to debug_analysis (#1253)
* save as csv * fix flake8 error * default to json export --------- Co-authored-by: Lucas Wilkinson <[email protected]>
1 parent 404e8a4 commit 1ba444f

File tree

1 file changed

+41
-5
lines changed

1 file changed

+41
-5
lines changed

Diff for: src/deepsparse/debug_analysis.py

+41-5
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def parse_args():
214214
parser.add_argument(
215215
"-x",
216216
"--export_path",
217-
help="Store results into a JSON file",
217+
help="Store results into a JSON or CSV file",
218218
type=str,
219219
default=None,
220220
)
@@ -419,10 +419,46 @@ def main():
419419
print(construct_layer_statistics(result))
420420

421421
if args.export_path:
422-
# Export results
423-
print("Saving analysis results to JSON file at {}".format(args.export_path))
424-
with open(args.export_path, "w") as out:
425-
json.dump(result, out, indent=2)
422+
if ".csv" in args.export_path:
423+
top_level_items_skip = ["iteration_times", "layer_info"]
424+
top_level_items_dict = {
425+
k: v for k, v in result.items() if k not in top_level_items_skip
426+
}
427+
428+
def construct_csv_layer_info(li):
429+
def flatten(parent_k, sub_d):
430+
return {f"{parent_k}_{k}": v for k, v in sub_d.items()}
431+
432+
csv_li = {}
433+
for k, v in li.items():
434+
if k not in ["sub_layer_info"]:
435+
csv_li.update({k: v} if type(v) is not dict else flatten(k, v))
436+
return csv_li
437+
438+
csv_layer_infos = [
439+
{
440+
**top_level_items_dict,
441+
**construct_csv_layer_info(li),
442+
}
443+
for li in result["layer_info"]
444+
]
445+
446+
# Export results
447+
import csv
448+
449+
print("Saving analysis results to CSV file at {}".format(args.export_path))
450+
with open(args.export_path, "w") as out:
451+
writer = csv.DictWriter(
452+
out, fieldnames=csv_layer_infos[0].keys(), extrasaction="ignore"
453+
)
454+
writer.writeheader()
455+
for data in csv_layer_infos:
456+
writer.writerow(data)
457+
else:
458+
# Export results
459+
print("Saving analysis results to JSON file at {}".format(args.export_path))
460+
with open(args.export_path, "w") as out:
461+
json.dump(result, out, indent=2)
426462

427463

428464
if __name__ == "__main__":

0 commit comments

Comments
 (0)