|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import re |
| 5 | +import subprocess |
| 6 | +import sys |
| 7 | + |
| 8 | +import boto3 |
| 9 | + |
| 10 | + |
# Absolute path of the folder that contains this Python file
dir_path = os.path.dirname(os.path.realpath(__file__))

# Module-level state shared by the functions of this module
logger = None      # logging.Logger object
config = None      # Config parameters
partitions = None  # Partitions details
| 16 | + |
| 17 | + |
def get_logger(scriptname, levelname, filename):
    """Create (or reuse) and return a configured logging.Logger object.

    - scriptname: name of the module, used as the logger name
    - levelname: log level (DEBUG, INFO, WARNING, ERROR, CRITICAL); any
      other value falls back to DEBUG
    - filename: location of the log file

    The logger writes to the console and to `filename`. Calling this function
    again with the same `scriptname` updates the level and returns the same
    logger without attaching additional handlers.
    """
    logger = logging.getLogger(scriptname)

    # Update log level
    log_levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
        'ERROR': logging.ERROR,
        'CRITICAL': logging.CRITICAL,
    }
    logger.setLevel(log_levels.get(levelname, logging.DEBUG))

    # logging.getLogger() returns the same object for a given name, so only
    # attach handlers once; otherwise each call would add another pair of
    # handlers and every record would be written several times.
    if logger.handlers:
        return logger

    # Create a console handler
    sh = logging.StreamHandler()
    sh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(sh)

    # Create a file handler
    fh = logging.FileHandler(filename)
    fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(fh)

    return logger
| 49 | + |
| 50 | + |
def validate_config(data):
    """Validate the structure of the config.json file content.

    - data: dict loaded from config.json

    Returns None when the structure is valid. Raises AssertionError with a
    message naming the first missing or invalid entry otherwise.
    """

    def check(condition, message):
        # Raise explicitly instead of using the assert statement: asserts are
        # removed when Python runs with -O, which would skip the validation.
        if not condition:
            raise AssertionError(message)

    check('LogLevel' in data, 'Missing "LogLevel" in root')
    check(
        data['LogLevel'] in ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'),
        'root["LogLevel"] is an invalid value',
    )

    check('LogFileName' in data, 'Missing "LogFileName" in root')

    check('SlurmBinPath' in data, 'Missing "SlurmBinPath" in root')

    check('SlurmConf' in data, 'Missing "SlurmConf" in root')
    slurm_conf = data['SlurmConf']
    check(isinstance(slurm_conf, dict), 'root["SlurmConf"] is not a dict')

    # Keys that must be present in root["SlurmConf"], checked in this order
    required_keys = (
        'PrivateData',
        'ResumeProgram',
        'SuspendProgram',
        'ResumeRate',
        'SuspendRate',
        'ResumeTimeout',
        'SuspendTime',
        'TreeWidth',
    )
    for key in required_keys:
        check(key in slurm_conf, 'Missing "%s" in root["SlurmConf"]' % key)
| 74 | + |
| 75 | + |
def validate_partitions(data):
    """Validate the structure of the partitions.json file content.

    - data: dict loaded from partitions.json

    Returns None when the structure is valid. Raises AssertionError with a
    message naming the first missing or invalid entry otherwise.
    """
    name_pattern = '^[a-zA-Z0-9_]+$'

    def check(condition, message):
        # Raise explicitly instead of using the assert statement: asserts are
        # removed when Python runs with -O, which would skip the validation.
        if not condition:
            raise AssertionError(message)

    def check_key(container, key, path):
        check(key in container, 'Missing "%s" in %s' % (key, path))

    def check_type(container, key, expected_type, type_label, path):
        check_key(container, key, path)
        check(
            isinstance(container[key], expected_type),
            '%s["%s"] is not %s' % (path, key, type_label),
        )

    def check_name(container, key, path):
        check_key(container, key, path)
        try:
            matched = re.match(name_pattern, container[key])
        except TypeError:
            # Non-string values are reported as a validation failure
            matched = None
        # The message quotes the pattern that is actually enforced
        check(matched, '%s["%s"] does not match %s' % (path, key, name_pattern))

    check_type(data, 'Partitions', list, 'an array', 'root')

    for i_partition, partition in enumerate(data['Partitions']):
        partition_path = 'root["Partitions"][%s]' % i_partition

        check_name(partition, 'PartitionName', partition_path)
        check_type(partition, 'NodeGroups', list, 'an array', partition_path)

        for i_nodegroup, nodegroup in enumerate(partition['NodeGroups']):
            nodegroup_path = '%s["NodeGroups"][%s]' % (partition_path, i_nodegroup)

            check_name(nodegroup, 'NodeGroupName', nodegroup_path)
            check_type(nodegroup, 'MaxNodes', int, 'a number', nodegroup_path)
            check_key(nodegroup, 'Region', nodegroup_path)
            check_type(nodegroup, 'SlurmSpecifications', dict, 'a dict', nodegroup_path)

            check_key(nodegroup, 'PurchasingOption', nodegroup_path)
            check(
                nodegroup['PurchasingOption'] in ('spot', 'on-demand'),
                '%s["PurchasingOption"] must be spot or on-demand' % nodegroup_path,
            )

            check_type(nodegroup, 'LaunchTemplateSpecification', dict, 'a dict', nodegroup_path)
            check_type(nodegroup, 'LaunchTemplateOverrides', list, 'an array', nodegroup_path)
            check_type(nodegroup, 'SubnetIds', list, 'an array', nodegroup_path)
| 113 | + |
| 114 | + |
def get_common(scriptname):
    """Create and return the logger, config, and partitions variables.

    - scriptname: name of the calling module, used as the logger name

    Loads config.json and partitions.json from the folder of this file,
    stores the results in the module-level `logger`, `config` and
    `partitions` variables, and returns them as a tuple
    (logger, config, partitions). Logs a critical message and exits with
    status 1 when either file cannot be loaded or is invalid.
    """
    global logger
    global config
    global partitions

    # Load configuration parameters from ./config.json. A load failure is
    # recorded in the dict and reported once the logger is available.
    config_filename = '%s/config.json' % dir_path
    try:
        with open(config_filename, 'r') as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError('root is not a dict')
    except Exception as e:
        config = {'JsonLoadError': str(e)}

    # Populate default values if unspecified
    if 'LogFileName' not in config:
        config['LogFileName'] = '%s/aws_plugin.log' % dir_path
    if 'LogLevel' not in config:
        config['LogLevel'] = 'DEBUG'

    # Make sure that SlurmBinPath ends with a /
    if 'SlurmBinPath' in config and not config['SlurmBinPath'].endswith('/'):
        config['SlurmBinPath'] += '/'

    # Create a logger
    logger = get_logger(scriptname, config['LogLevel'], config['LogFileName'])
    logger.debug('Config: %s' % json.dumps(config, indent=4))

    # Validate the structure of config.json
    if 'JsonLoadError' in config:
        # Report the file that failed to load, not the log file
        logger.critical('Failed to load %s - %s' % (config_filename, config['JsonLoadError']))
        sys.exit(1)
    try:
        validate_config(config)
    except Exception as e:
        logger.critical('File config.json is invalid - %s' % e)
        sys.exit(1)

    # Load partitions details from ./partitions.json
    partitions_filename = '%s/partitions.json' % dir_path
    try:
        with open(partitions_filename, 'r') as f:
            partitions_json = json.load(f)
    except Exception as e:
        logger.critical('Failed to load %s - %s' % (partitions_filename, e))
        sys.exit(1)

    # Validate the structure of partitions.json
    try:
        validate_partitions(partitions_json)
    except Exception as e:
        logger.critical('File partitions.json is invalid - %s' % e)
        sys.exit(1)

    # Only reached when validation succeeded, so the key lookup is safe and
    # cannot replace the exit above with another exception.
    partitions = partitions_json['Partitions']
    logger.debug('Partitions: %s' % json.dumps(partitions_json, indent=4))

    return logger, config, partitions
| 174 | + |
| 175 | + |
def get_node_name(partition, nodegroup, node_id=''):
    """Return the name of a node: [partition_name]-[nodegroup_name][id].

    - partition: either the partition name, or a dict whose 'PartitionName'
      entry holds the partition name
    - nodegroup: either the node group name, or a dict whose 'NodeGroupName'
      entry holds the node group name
    - node_id: optional id appended to the name
    """
    partition_name = partition['PartitionName'] if isinstance(partition, dict) else partition
    nodegroup_name = nodegroup['NodeGroupName'] if isinstance(nodegroup, dict) else nodegroup

    return '%s-%s%s' % (partition_name, nodegroup_name, node_id)
| 193 | + |
| 194 | + |
def get_node_range(partition, nodegroup, nb_nodes=None):
    """Return the node range expression of a node group.

    The result is [partition_name]-[nodegroup_name][0-(nb_nodes-1)] when
    there is more than one node, and [partition_name]-[nodegroup_name]0
    otherwise.

    - partition: either the partition name, or a dict whose 'PartitionName'
      entry holds the partition name
    - nodegroup: either the node group name, or a dict whose 'NodeGroupName'
      entry holds the node group name
    - nb_nodes: optional number of nodes; when omitted, nodegroup['MaxNodes']
      is used
    """
    node_count = nodegroup['MaxNodes'] if nb_nodes is None else nb_nodes

    if not node_count > 1:
        return '%s0' % (get_node_name(partition, nodegroup))

    return '%s[0-%s]' % (get_node_name(partition, nodegroup), node_count - 1)
| 208 | + |
| 209 | + |
def run_scommand(command, arguments):
    """Run a Slurm command and return its standard output as a list of lines.

    - command: name of the command such as scontrol
    - arguments: list of arguments passed to the command

    The executable path is config['SlurmBinPath'] followed by `command`.
    A non-zero exit status is logged as an error but does not raise, so the
    caller still receives whatever the command wrote to standard output.
    """
    scommand_path = '%s%s' % (config['SlurmBinPath'], command)
    cmd = [scommand_path] + list(arguments)
    logger.debug('Command %s: %s' % (command, ' '.join(cmd)))

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    output = proc.communicate()[0]

    # communicate() waits for the process to end, so returncode is set here
    if proc.returncode != 0:
        logger.error('Command %s exited with status %s' % (command, proc.returncode))

    return [line.decode() for line in output.splitlines()]
| 221 | + |
| 222 | + |
def expand_hostlist(hostlist):
    """Expand a hostlist with 'scontrol show hostnames'.

    - hostlist: argument passed to SuspendProgram or ResumeProgram

    Returns the list of node names. Logs a critical message and exits with
    status 1 if the command cannot be run.
    """
    try:
        node_names = run_scommand('scontrol', ['show', 'hostnames', hostlist])
    except Exception as e:
        logger.critical('Failed to expand hostlist - %s' % e)
        sys.exit(1)

    return node_names
| 233 | + |
| 234 | + |
def parse_node_names(node_names):
    """Group node ids by partition name and node group name.

    - node_names: list of node names shaped like
      [partition_name]-[nodegroup_name][id]

    Returns a dict with result[partition_name][nodegroup_name] = list of node
    ids (as strings). Names that do not match the expected shape are ignored.

    NOTE(review): the node group part is matched greedily, so only the last
    digit of the name is taken as the id ('aws-node12' gives node group
    'node1' and id '2') - confirm this is intended for groups of 10+ nodes.
    """
    node_pattern = re.compile('^([a-zA-Z0-9_]+)-([a-zA-Z0-9_]+)([0-9]+)$')
    grouped = {}

    for name in node_names:
        parsed = node_pattern.match(name)
        if parsed is None:
            continue

        # Extract partition name, node group name and node id
        partition_name, nodegroup_name, node_id = parsed.groups()
        grouped.setdefault(partition_name, {}).setdefault(nodegroup_name, []).append(node_id)

    return grouped
| 254 | + |
| 255 | + |
def get_partition_nodegroup(partition_name, nodegroup_name):
    """Return the node group dict stored in partitions for the given names.

    - partition_name: value of 'PartitionName' to look for
    - nodegroup_name: value of 'NodeGroupName' to look for

    Returns the first matching node group, i.e. the object held in
    partitions itself and not a copy, or None if it does not exist.
    """
    matches = (
        nodegroup
        for partition in partitions
        if partition['PartitionName'] == partition_name
        for nodegroup in partition['NodeGroups']
        if nodegroup['NodeGroupName'] == nodegroup_name
    )

    # The generator is lazy, so the search stops at the first match
    return next(matches, None)
| 267 | + |
| 268 | + |
def update_node(node_name, parameters):
    """Update a node with 'scontrol update nodename=<node_name> <parameters>'.

    - node_name: name of the node to update
    - parameters: whitespace-separated list of parameters, each passed to
      scontrol as a separate argument
    """
    # split() without a separator ignores leading, trailing and repeated
    # whitespace, so no empty argument is passed to scontrol
    arguments = ['update', 'nodename=%s' % node_name] + parameters.split()
    run_scommand('scontrol', arguments)
| 275 | + |
| 276 | + |
def get_node_state(hostlist):
    """Call sinfo and return the node state for a list of nodes.

    - hostlist: list of node names

    Returns the lines printed by 'sinfo -n <nodes> -N -o "%N %t"'. Logs a
    critical message and exits with status 1 if the command cannot be run.

    NOTE(review): sinfo is not called with a no-header option, so the output
    presumably starts with a header line - confirm what callers expect.
    """
    try:
        # The format is passed as a plain argument: no shell is involved, so
        # surrounding quotes would be sent to sinfo as literal characters.
        arguments = ['-n', ','.join(hostlist), '-N', '-o', '%N %t']
        return run_scommand('sinfo', arguments)
    except Exception as e:
        logger.critical('Failed to retrieve node state - %s' % e)
        sys.exit(1)
| 286 | + |
| 287 | + |
def get_ec2_client(nodegroup):
    """Return a boto3 EC2 client for a node group.

    - nodegroup: dict with a 'Region' entry and an optional 'ProfileName'
      entry

    When 'ProfileName' is present, the client is created from a session that
    uses this profile; otherwise boto3.client() is used. Logs a critical
    message and exits with status 1 if the client cannot be created.
    """
    # Both paths share the same error handling, so a failure without a
    # profile is logged as well instead of raising to the caller.
    try:
        if 'ProfileName' in nodegroup:
            session = boto3.session.Session(
                region_name=nodegroup['Region'],
                profile_name=nodegroup['ProfileName'],
            )
            return session.client('ec2')

        return boto3.client('ec2', region_name=nodegroup['Region'])
    except Exception as e:
        logger.critical('Failed to create a EC2 client - %s' % e)
        sys.exit(1)
0 commit comments