|
1 | 1 | import logging
|
| 2 | +import os |
| 3 | +import re |
2 | 4 | import subprocess
|
3 | 5 | import tempfile
|
4 | 6 | from contextlib import contextmanager
|
| 7 | +from pathlib import Path |
5 | 8 | from urllib.parse import quote, urlunparse, urlparse
|
6 | 9 |
|
| 10 | +import boto3 |
| 11 | +from botocore.config import Config as Boto3Config |
7 | 12 | from django import forms
|
8 | 13 | from django.conf import settings
|
9 | 14 | from django.utils.translation import gettext as _
|
@@ -115,3 +120,70 @@ def fetch(self):
|
115 | 120 | yield local_path.name
|
116 | 121 |
|
117 | 122 | local_path.cleanup()
|
| 123 | + |
| 124 | + |
@register_backend(DataSourceTypeChoices.AMAZON_S3)
class S3Backend(DataBackend):
    """
    Data backend which downloads objects from an Amazon S3 bucket.

    The data source URL is interpreted in path style: the first path segment is the
    bucket name and anything after it is used as a key prefix, e.g.
    ``https://s3.us-east-2.amazonaws.com/my-bucket/some/prefix``.
    """
    parameters = {
        'aws_access_key_id': forms.CharField(
            label=_('AWS access key ID'),
            widget=forms.TextInput(attrs={'class': 'form-control'})
        ),
        # NOTE(review): rendered with a plain TextInput, so the secret is displayed
        # in clear text in the form -- consider whether PasswordInput is preferable.
        'aws_secret_access_key': forms.CharField(
            label=_('AWS secret access key'),
            widget=forms.TextInput(attrs={'class': 'form-control'})
        ),
    }

    # Captures the region from a regional endpoint hostname (s3.<region>.amazonaws.com)
    REGION_REGEX = r's3\.([a-z0-9-]+)\.amazonaws\.com'

    @contextmanager
    def fetch(self):
        """
        Download every object under the configured prefix into a temporary directory.

        Yields:
            The path of the temporary directory. Objects are stored beneath it using
            their full S3 key as the relative path. The directory is removed when the
            context exits, including when the caller's block or a download raises.
        """
        # Build the S3 configuration
        s3_config = Boto3Config(
            proxies=settings.HTTP_PROXIES,
        )

        # Initialize the S3 resource and bucket
        s3 = boto3.resource(
            's3',
            region_name=self._region_name,
            aws_access_key_id=self.params.get('aws_access_key_id'),
            aws_secret_access_key=self.params.get('aws_secret_access_key'),
            config=s3_config
        )
        bucket = s3.Bucket(self._bucket_name)

        # The context manager guarantees cleanup on every exit path
        with tempfile.TemporaryDirectory() as local_root:
            root = os.path.normpath(local_root)

            # Download all files within the specified path
            for obj in bucket.objects.filter(Prefix=self._remote_path):
                # Keys ending in "/" are folder placeholders; there is no file to write
                if obj.key.endswith('/'):
                    continue

                # Refuse keys (absolute, or containing "..") which would resolve to a
                # location outside of the temporary directory
                destination = os.path.normpath(os.path.join(root, obj.key))
                if not destination.startswith(root + os.sep):
                    logging.getLogger(__name__).warning(
                        "Skipping S3 object with unsafe key: %s", obj.key
                    )
                    continue

                # Build local path
                Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)
                bucket.download_file(obj.key, destination)

            yield local_root

    @property
    def _region_name(self):
        """
        Return the region parsed from the URL's network location, or None if it does
        not begin with a regional S3 endpoint matching REGION_REGEX.
        """
        domain = urlparse(self.url).netloc
        if m := re.match(self.REGION_REGEX, domain):
            return m.group(1)
        return None

    @property
    def _bucket_name(self):
        """Return the first segment of the URL path, used as the bucket name."""
        url_path = urlparse(self.url).path.lstrip('/')
        return url_path.split('/')[0]

    @property
    def _remote_path(self):
        """
        Return the portion of the URL path following the bucket name (used as the key
        prefix), or an empty string if the URL names only the bucket.
        """
        url_path = urlparse(self.url).path.lstrip('/')
        if '/' in url_path:
            return url_path.split('/', 1)[1]
        return ''
0 commit comments