Skip to content

Commit 238249f

Browse files
authored
Merge pull request #11 from malaval/plugin-v2
Plugin v2
2 parents 933542a + b7de0ce commit 238249f

21 files changed

+1614
-905
lines changed

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy of this
44
software and associated documentation files (the "Software"), to deal in the Software

README.md

+566-49
Large diffs are not rendered by default.

common.py

+299
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import json
2+
import logging
3+
import os
4+
import re
5+
import subprocess
6+
import sys
7+
8+
import boto3
9+
10+
11+
dir_path = os.path.dirname(os.path.realpath(__file__)) # Folder where resides the Python files
12+
13+
logger = None # Global variable for the logging.Logger object
14+
config = None # Global variable for the config parameters
15+
partitions = None # Global variable that stores partitions details
16+
17+
18+
# Create and return a logging.Logger object
19+
# - scriptname: name of the module
20+
# - levelname: log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
21+
# - filename: location of the log file
22+
def get_logger(scriptname, levelname, filename):
23+
24+
logger = logging.getLogger(scriptname)
25+
26+
# Update log level
27+
log_levels = {
28+
'DEBUG': logging.DEBUG,
29+
'INFO': logging.INFO,
30+
'WARNING': logging.WARNING,
31+
'ERROR': logging.ERROR,
32+
'CRITICAL': logging.CRITICAL
33+
}
34+
logger.setLevel(log_levels.get(levelname, logging.DEBUG))
35+
36+
# Create a console handler
37+
sh = logging.StreamHandler()
38+
sh_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
39+
sh.setFormatter(sh_formatter)
40+
logger.addHandler(sh)
41+
42+
# Create a file handler
43+
fh = logging.FileHandler(filename)
44+
fh_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
45+
fh.setFormatter(fh_formatter)
46+
logger.addHandler(fh)
47+
48+
return logger
49+
50+
51+
# Validate the structure of the config.json file content
52+
# - data: dict loaded from config.json
53+
def validate_config(data):
54+
55+
assert 'LogLevel' in data, 'Missing "LogLevel" in root'
56+
assert data['LogLevel'] in ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'), 'root["LogLevel"] is an invalid value'
57+
58+
assert 'LogFileName' in data, 'Missing "LogFileName" in root'
59+
60+
assert 'SlurmBinPath' in data, 'Missing "SlurmBinPath" in root'
61+
62+
assert 'SlurmConf' in data, 'Missing "SlurmConf" in root'
63+
slurm_conf = data['SlurmConf']
64+
assert isinstance(slurm_conf, dict), 'root["SlurmConf"] is not a dict'
65+
66+
assert 'PrivateData' in slurm_conf, 'Missing "PrivateData" in root["SlurmConf"]'
67+
assert 'ResumeProgram' in slurm_conf, 'Missing "ResumeProgram" in root["SlurmConf"]'
68+
assert 'SuspendProgram' in slurm_conf, 'Missing "SuspendProgram" in root["SlurmConf"]'
69+
assert 'ResumeRate' in slurm_conf, 'Missing "ResumeRate" in root["SlurmConf"]'
70+
assert 'SuspendRate' in slurm_conf, 'Missing "SuspendRate" in root["SlurmConf"]'
71+
assert 'ResumeTimeout' in slurm_conf, 'Missing "ResumeTimeout" in root["SlurmConf"]'
72+
assert 'SuspendTime' in slurm_conf, 'Missing "SuspendTime" in root["SlurmConf"]'
73+
assert 'TreeWidth' in slurm_conf, 'Missing "TreeWidth" in root["SlurmConf"]'
74+
75+
76+
# Validate the structure of the partitions.json file content
77+
# - data: dict loaded from partitions.json
78+
def validate_partitions(data):
79+
80+
assert 'Partitions' in data, 'Missing "Partitions" in root'
81+
assert isinstance(data['Partitions'], list), 'root["Partitions"] is not an array'
82+
83+
for i_partition, partition in enumerate(data['Partitions']):
84+
assert 'PartitionName' in partition, 'Missing "PartitionName" in root["Partitions"][%s]' %i_partition
85+
assert re.match('^[a-zA-Z0-9_]+$', partition['PartitionName']), 'root["Partitions"][%s]["PartitionName"] does not match ^[a-zA-Z0-9-]+$' %i_partition
86+
87+
assert 'NodeGroups' in partition, 'Missing "NodeGroups" in root["Partitions"][%s]' %i_partition
88+
assert isinstance(partition['NodeGroups'], list), 'root["Partitions"][%s]["NodeGroups"] is not an array' %i_partition
89+
90+
for i_nodegroup, nodegroup in enumerate(partition['NodeGroups']):
91+
assert 'NodeGroupName' in nodegroup, 'Missing "NodeGroupName" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
92+
assert re.match('^[a-zA-Z0-9_]+$', nodegroup['NodeGroupName']), 'root["Partitions"][%s]["NodeGroups"][%s]["NodeGroupName"] does not match ^[a-zA-Z0-9-]+$' %(i_partition, i_nodegroup)
93+
94+
assert 'MaxNodes' in nodegroup, 'Missing "MaxNodes" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
95+
assert isinstance(nodegroup['MaxNodes'], int), 'root["Partitions"][%s]["NodeGroups"][%s]["MaxNodes"] is not a number' %(i_partition, i_nodegroup)
96+
97+
assert 'Region' in nodegroup, 'Missing "Region" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
98+
99+
assert 'SlurmSpecifications' in nodegroup, 'Missing "SlurmSpecifications" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
100+
assert isinstance(nodegroup['SlurmSpecifications'], dict), 'root["Partitions"][%s]["NodeGroups"][%s]["SlurmSpecifications"] is not a dict' %(i_partition, i_nodegroup)
101+
102+
assert 'PurchasingOption' in nodegroup, 'Missing "PurchasingOption" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
103+
assert nodegroup['PurchasingOption'] in ('spot', 'on-demand'), 'root["Partitions"][%s]["NodeGroups"][%s]["PurchasingOption"] must be spot or on-demand' %(i_partition, i_nodegroup)
104+
105+
assert 'LaunchTemplateSpecification' in nodegroup, 'Missing "LaunchTemplateSpecification" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
106+
assert isinstance(nodegroup['LaunchTemplateSpecification'], dict), 'root["Partitions"][%s]["NodeGroups"][%s]["LaunchTemplateSpecification"] is not a dict' %(i_partition, i_nodegroup)
107+
108+
assert 'LaunchTemplateOverrides' in nodegroup, 'Missing "LaunchTemplateOverrides" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
109+
assert isinstance(nodegroup['LaunchTemplateOverrides'], list), 'root["Partitions"][%s]["NodeGroups"][%s]["LaunchTemplateOverrides"] is not a dict' %(i_partition, i_nodegroup)
110+
111+
assert 'SubnetIds' in nodegroup, 'Missing "SubnetIds" in root["Partitions"][%s]["NodeGroups"][%s]' %(i_partition, i_nodegroup)
112+
assert isinstance(nodegroup['SubnetIds'], list), 'root["Partitions"][%s]["NodeGroups"][%s]["SubnetIds"] is not a dict' %(i_partition, i_nodegroup)
113+
114+
115+
# Create and return logger, config, and partitions variables
116+
def get_common(scriptname):
117+
118+
global logger
119+
global config
120+
global partitions
121+
122+
# Load configuration parameters from ./config.json and merge with default values
123+
try:
124+
config_filename = '%s/config.json' %dir_path
125+
with open(config_filename, 'r') as f:
126+
config = json.load(f)
127+
except Exception as e:
128+
config = {'JsonLoadError': str(e)}
129+
130+
# Populate default values if unspecified
131+
if not 'LogFileName' in config:
132+
config['LogFileName'] = '%s/aws_plugin.log' %dir_path
133+
if not 'LogLevel' in config:
134+
config['LogLevel'] = 'DEBUG'
135+
136+
# Make sure that SlurmBinPath ends with a /
137+
if 'SlurmBinPath' in config and not config['SlurmBinPath'].endswith('/'):
138+
config['SlurmBinPath'] += '/'
139+
140+
# Create a logger
141+
logger = get_logger(scriptname, config['LogLevel'], config['LogFileName'])
142+
logger.debug('Config: %s' %json.dumps(config, indent=4))
143+
144+
# Validate the structure of config.json
145+
if 'JsonLoadError' in config:
146+
logger.critical('Failed to load %s - %s' %(config['LogFileName'], config['JsonLoadError']))
147+
sys.exit(1)
148+
try:
149+
validate_config(config)
150+
except Exception as e:
151+
logger.critical('File config.json is invalid - %s' %e)
152+
sys.exit(1)
153+
154+
# Load partitions details from ./partitions.json
155+
partitions_filename = '%s/partitions.json' %dir_path
156+
try:
157+
with open(partitions_filename, 'r') as f:
158+
partitions_json = json.load(f)
159+
except Exception as e:
160+
logger.critical('Failed to load %s - %s' %(partitions_filename, e))
161+
sys.exit(1)
162+
163+
# Validate the structure of partitions.json
164+
try:
165+
validate_partitions(partitions_json)
166+
except Exception as e:
167+
logger.critical('File partition.json is invalid - %s' %e)
168+
sys.exit(1)
169+
finally:
170+
partitions = partitions_json['Partitions']
171+
logger.debug('Partitions: %s' %json.dumps(partitions_json, indent=4))
172+
173+
return logger, config, partitions
174+
175+
176+
# Return the name of a node [partition_name]-[nodegroup_name][id]
177+
# - partition: can either be a string, or a dict with dict['PartitionName'] = partition_name
178+
# - nodegroup: can either be a string, or a dict with dict['NodeGroupName'] = nodegroup_name
179+
# - id: optional id
180+
def get_node_name(partition, nodegroup, node_id=''):
181+
182+
if isinstance(partition, dict):
183+
partition_name = partition['PartitionName']
184+
else:
185+
partition_name = partition
186+
187+
if isinstance(nodegroup, dict):
188+
nodegroup_name = nodegroup['NodeGroupName']
189+
else:
190+
nodegroup_name = nodegroup
191+
192+
return '%s-%s%s' %(partition_name, nodegroup_name, node_id)
193+
194+
195+
# Return the name of a node [partition_name]-[nodegroup_name][id]
196+
# - partition: can either be a string, or a dict with dict['PartitionName'] = partition_name
197+
# - nodegroup: can either be a string, or a dict with dict['NodeGroupName'] = nodegroup_name
198+
# - nb_nodes: optional number of nodes
199+
def get_node_range(partition, nodegroup, nb_nodes=None):
200+
201+
if nb_nodes is None:
202+
nb_nodes = nodegroup['MaxNodes']
203+
204+
if nb_nodes > 1:
205+
return '%s[0-%s]' %(get_node_name(partition, nodegroup), nb_nodes-1)
206+
else:
207+
return '%s0' %(get_node_name(partition, nodegroup))
208+
209+
210+
# Run scontrol and return output
211+
# - command: name of the command such as scontrol
212+
# - arguments: array
213+
def run_scommand(command, arguments):
214+
215+
scommand_path = '%s%s' %(config['SlurmBinPath'], command)
216+
cmd = [scommand_path] + arguments
217+
logger.debug('Command %s: %s' %(command, ' '.join(cmd)))
218+
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
219+
lines = proc.communicate()[0].splitlines()
220+
return [line.decode() for line in lines]
221+
222+
223+
# Use 'scontrol show hostnames' to expand the hostlist and return a list of node names
224+
# - hostlist: argument passed to SuspendProgram or ResumeProgram
225+
def expand_hostlist(hostlist):
226+
227+
try:
228+
arguments = ['show', 'hostnames', hostlist]
229+
return run_scommand('scontrol', arguments)
230+
except Exception as e:
231+
logger.critical('Failed to expand hostlist - %s' %e)
232+
sys.exit(1)
233+
234+
235+
# Take a list of node names in input and return a dict with result[partition_name][nodegroup_name] = list of node ids
236+
def parse_node_names(node_names):
237+
result = {}
238+
for node_name in node_names:
239+
240+
# For each node: extract partition name, node group name and node id
241+
pattern = '^([a-zA-Z0-9_]+)-([a-zA-Z0-9_]+)([0-9]+)$'
242+
match = re.match(pattern, node_name)
243+
if match:
244+
partition_name, nodegroup_name, node_id = match.groups()
245+
246+
# Add to result
247+
if not partition_name in result:
248+
result[partition_name] = {}
249+
if not nodegroup_name in result[partition_name]:
250+
result[partition_name][nodegroup_name] = []
251+
result[partition_name][nodegroup_name].append(node_id)
252+
253+
return result
254+
255+
256+
# Return a pointer in partitions to a specific partition and node group
257+
def get_partition_nodegroup(partition_name, nodegroup_name):
258+
259+
for partition in partitions:
260+
if partition['PartitionName'] == partition_name:
261+
for nodegroup in partition['NodeGroups']:
262+
if nodegroup['NodeGroupName'] == nodegroup_name:
263+
return nodegroup
264+
265+
# Return None if it does not exist
266+
return None
267+
268+
269+
# Use 'scontrol update node' to update nodes
270+
def update_node(node_name, parameters):
271+
272+
parameters_split = parameters.split(' ')
273+
arguments = ['update', 'nodename=%s' %node_name] + parameters_split
274+
run_scommand('scontrol', arguments)
275+
276+
277+
# Call sinfo and return node status for a list of nodes
278+
def get_node_state(hostlist):
279+
280+
try:
281+
cmd = [scontrol_path, '-n', ','.join(hostlist), '-N', '-o', '"%N %t"']
282+
return run_scommand('sinfo', arguments)
283+
except Exception as e:
284+
logger.critical('Failed to retrieve node state - %s' %e)
285+
sys.exit(1)
286+
287+
288+
# Return boto3 client
289+
def get_ec2_client(nodegroup):
290+
291+
if 'ProfileName' in nodegroup:
292+
try:
293+
session = boto3.session.Session(region_name=nodegroup['Region'], profile_name=nodegroup['ProfileName'])
294+
return session.client('ec2')
295+
except Exception as e:
296+
logger.critical('Failed to create a EC2 client - %s' %e)
297+
sys.exit(1)
298+
else:
299+
return boto3.client('ec2', region_name=nodegroup['Region'])

generate_conf.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/python3
2+
import common
3+
4+
5+
logger, config, partitions = common.get_common('generate_conf')
6+
7+
filename = 'slurm.conf.aws'
8+
9+
# This script generates a file to append to slurm.conf
10+
with open(filename, 'w') as f:
11+
12+
# Write Slurm configuration parameters
13+
for item, value in config['SlurmConf'].items():
14+
f.write('%s=%s\n' %(item, value))
15+
f.write('\n')
16+
17+
for partition in partitions:
18+
partition_nodes = ()
19+
20+
for nodegroup in partition['NodeGroups']:
21+
nodes = common.get_node_range(partition, nodegroup)
22+
partition_nodes += nodes,
23+
24+
nodegroup_specs = ()
25+
for key, value in nodegroup['SlurmSpecifications'].items():
26+
nodegroup_specs += '%s=%s' %(key, value),
27+
28+
# Write a line for each node group
29+
line = 'NodeName=%s State=CLOUD %s' %(nodes, ' '.join(nodegroup_specs))
30+
f.write('%s\n' %line)
31+
32+
# Write a line for each partition
33+
line = 'PartitionName=%s Nodes=%s Default=No MaxTime=INFINITE State=UP' %(partition['PartitionName'], ','.join(partition_nodes))
34+
f.write('%s\n\n' %line)
35+
36+
logger.info('Output file: %s' %filename)

gres.conf

-1
This file was deleted.

imgs/slurm-burst.png

-101 KB
Binary file not shown.

imgs/slurm-cf.png

-67.3 KB
Binary file not shown.

imgs/slurm-submit.gif

-24.7 MB
Binary file not shown.

lambda/hpc_worker.zip

-14.7 MB
Binary file not shown.

lambda/test.json

-13
This file was deleted.

0 commit comments

Comments
 (0)