1
1
#include " DNSServer.h"
2
2
#include < lwip/def.h>
3
3
#include < Arduino.h>
4
+ #include < memory>
4
5
5
6
#ifdef DEBUG_ESP_PORT
6
7
#define DEBUG_OUTPUT DEBUG_ESP_PORT
7
8
#else
8
9
#define DEBUG_OUTPUT Serial
9
10
#endif
10
11
12
+ #define DNS_HEADER_SIZE sizeof (DNSHeader)
13
+
11
14
DNSServer::DNSServer ()
12
15
{
13
16
_ttl = lwip_htonl (60 );
@@ -50,108 +53,154 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
50
53
domainName.remove (0 , 4 );
51
54
}
52
55
53
- void DNSServer::processNextRequest ( )
56
+ void DNSServer::respondToRequest ( uint8_t *buffer, size_t length )
54
57
{
55
- size_t packetSize = _udp.parsePacket ();
58
+ DNSHeader *dnsHeader;
59
+ uint8_t *query, *start;
60
+ const char *matchString;
61
+ size_t remaining, labelLength, queryLength;
62
+ uint16_t qtype, qclass;
63
+
64
+ dnsHeader = (DNSHeader *)buffer;
56
65
57
- if (packetSize >= sizeof (DNSHeader))
58
- {
59
- uint8_t * buffer = reinterpret_cast <uint8_t *>(malloc (packetSize));
60
- if (buffer == NULL ) return ;
66
+ // Must be a query for us to do anything with it
67
+ if (dnsHeader->QR != DNS_QR_QUERY)
68
+ return ;
61
69
62
- _udp.read (buffer, packetSize);
70
+ // If operation is anything other than query, we don't do it
71
+ if (dnsHeader->OPCode != DNS_OPCODE_QUERY)
72
+ return replyWithError (dnsHeader, DNSReplyCode::NotImplemented);
73
+
74
+ // Only support requests containing single queries - everything else
75
+ // is badly defined
76
+ if (dnsHeader->QDCount != lwip_htons (1 ))
77
+ return replyWithError (dnsHeader, DNSReplyCode::FormError);
78
+
79
+ // We must return a FormError in the case of a non-zero ARCount to
80
+ // be minimally compatible with EDNS resolvers
81
+ if (dnsHeader->ANCount != 0 || dnsHeader->NSCount != 0
82
+ || dnsHeader->ARCount != 0 )
83
+ return replyWithError (dnsHeader, DNSReplyCode::FormError);
84
+
85
+ // Even if we're not going to use the query, we need to parse it
86
+ // so we can check the address type that's being queried
87
+
88
+ query = start = buffer + DNS_HEADER_SIZE;
89
+ remaining = length - DNS_HEADER_SIZE;
90
+ while (remaining != 0 && *start != 0 ) {
91
+ labelLength = *start;
92
+ if (labelLength + 1 > remaining)
93
+ return replyWithError (dnsHeader, DNSReplyCode::FormError);
94
+ remaining -= (labelLength + 1 );
95
+ start += (labelLength + 1 );
96
+ }
63
97
64
- DNSHeader* dnsHeader = reinterpret_cast <DNSHeader*>(buffer);
98
+ // 1 octet labelLength, 2 octet qtype, 2 octet qclass
99
+ if (remaining < 5 )
100
+ return replyWithError (dnsHeader, DNSReplyCode::FormError);
65
101
66
- if (dnsHeader->QR == DNS_QR_QUERY &&
67
- dnsHeader->OPCode == DNS_OPCODE_QUERY &&
68
- requestIncludesOnlyOneQuestion (dnsHeader) &&
69
- (_domainName == " *" || getDomainNameWithoutWwwPrefix (buffer, packetSize) == _domainName)
70
- )
71
- {
72
- replyWithIP (buffer, packetSize);
73
- }
74
- else if (dnsHeader->QR == DNS_QR_QUERY)
75
- {
76
- replyWithCustomCode (buffer, packetSize);
102
+ start += 1 ; // Skip the 0 length label that we found above
103
+
104
+ memcpy (&qtype, start, sizeof (qtype));
105
+ start += 2 ;
106
+ memcpy (&qclass, start, sizeof (qclass));
107
+ start += 2 ;
108
+
109
+ queryLength = start - query;
110
+
111
+ if (qclass != lwip_htons (DNS_QCLASS_ANY)
112
+ && qclass != lwip_htons (DNS_QCLASS_IN))
113
+ return replyWithError (dnsHeader, DNSReplyCode::NonExistentDomain,
114
+ query, queryLength);
115
+
116
+ if (qtype != lwip_htons (DNS_QTYPE_A)
117
+ && qtype != lwip_htons (DNS_QTYPE_ANY))
118
+ return replyWithError (dnsHeader, DNSReplyCode::NonExistentDomain,
119
+ query, queryLength);
120
+
121
+ // If we have no domain name configured, just return an error
122
+ if (_domainName == " " )
123
+ return replyWithError (dnsHeader, _errorReplyCode,
124
+ query, queryLength);
125
+
126
+ // If we're running with a wildcard we can just return a result now
127
+ if (_domainName == " *" )
128
+ return replyWithIP (dnsHeader, query, queryLength);
129
+
130
+ matchString = _domainName.c_str ();
131
+
132
+ start = query;
133
+
134
+ // If there's a leading 'www', skip it
135
+ if (*start == 3 && strncasecmp (" www" , (char *) start + 1 , 3 ) == 0 )
136
+ start += 4 ;
137
+
138
+ while (*start != 0 ) {
139
+ labelLength = *start;
140
+ start += 1 ;
141
+ while (labelLength > 0 ) {
142
+ if (tolower (*start) != *matchString)
143
+ return replyWithError (dnsHeader, _errorReplyCode,
144
+ query, queryLength);
145
+ ++start;
146
+ ++matchString;
147
+ --labelLength;
77
148
}
149
+ if (*start == 0 && *matchString == ' \0 ' )
150
+ return replyWithIP (dnsHeader, query, queryLength);
78
151
79
- free (buffer);
152
+ if (*matchString != ' .' )
153
+ return replyWithError (dnsHeader, _errorReplyCode,
154
+ query, queryLength);
155
+ ++matchString;
80
156
}
81
- }
82
157
83
- bool DNSServer::requestIncludesOnlyOneQuestion (const DNSHeader* dnsHeader)
84
- {
85
- return lwip_ntohs (dnsHeader->QDCount ) == 1 &&
86
- dnsHeader->ANCount == 0 &&
87
- dnsHeader->NSCount == 0 &&
88
- dnsHeader->ARCount == 0 ;
158
+ return replyWithError (dnsHeader, _errorReplyCode,
159
+ query, queryLength);
89
160
}
90
161
91
- String DNSServer::getDomainNameWithoutWwwPrefix ( const uint8_t * buffer, size_t packetSize )
162
+ void DNSServer::processNextRequest ( )
92
163
{
93
- String parsedDomainName;
94
-
95
- const uint8_t * pos = buffer + sizeof (DNSHeader);
96
- const uint8_t * end = buffer + packetSize;
97
-
98
- // to minimize reallocations due to concats below
99
- // we reserve enough space that a median or average domain
100
- // name size cold be easily contained without a reallocation
101
- // - max size would be 253, in 2013, average is 11 and max was 42
102
- //
103
- parsedDomainName.reserve (32 );
104
-
105
- uint8_t labelLength = *pos;
106
-
107
- while (true )
108
- {
109
- if (labelLength == 0 )
110
- {
111
- // no more labels
112
- downcaseAndRemoveWwwPrefix (parsedDomainName);
113
- return parsedDomainName;
114
- }
164
+ size_t currentPacketSize;
115
165
116
- // append next label
117
- for (int i = 0 ; i < labelLength && pos < end; i++)
118
- {
119
- pos++;
120
- parsedDomainName += static_cast <char >(*pos);
121
- }
166
+ currentPacketSize = _udp.parsePacket ();
167
+ if (currentPacketSize == 0 )
168
+ return ;
122
169
123
- if (pos >= end)
124
- {
125
- // malformed packet, return an empty domain name
126
- parsedDomainName = " " ;
127
- return parsedDomainName;
128
- }
129
- else
130
- {
131
- // next label
132
- pos++;
133
- labelLength = *pos;
134
-
135
- // if there is another label, add delimiter
136
- if (labelLength != 0 )
137
- {
138
- parsedDomainName += " ." ;
139
- }
140
- }
141
- }
170
+ // The DNS RFC requires that DNS packets be less than 512 bytes in size,
171
+ // so just discard them if they are larger
172
+ if (currentPacketSize > MAX_DNS_PACKETSIZE)
173
+ return ;
174
+
175
+ // If the packet size is smaller than the DNS header, then someone is
176
+ // messing with us
177
+ if (currentPacketSize < DNS_HEADER_SIZE)
178
+ return ;
179
+
180
+ std::unique_ptr<uint8_t []> buffer (new (std::nothrow) uint8_t [currentPacketSize]);
181
+
182
+ if (buffer == NULL )
183
+ return ;
184
+
185
+ _udp.read (buffer.get (), currentPacketSize);
186
+ respondToRequest (buffer.get (), currentPacketSize);
142
187
}
143
188
144
- void DNSServer::replyWithIP (uint8_t * buffer, size_t packetSize)
189
+ void DNSServer::replyWithIP (DNSHeader *dnsHeader,
190
+ unsigned char * query,
191
+ size_t queryLength)
145
192
{
146
- DNSHeader* dnsHeader = reinterpret_cast <DNSHeader*>(buffer);
147
-
148
193
dnsHeader->QR = DNS_QR_RESPONSE;
149
- dnsHeader->ANCount = dnsHeader->QDCount ;
150
- dnsHeader->QDCount = dnsHeader->QDCount ;
151
- // dnsHeader->RA = 1;
194
+ dnsHeader->QDCount = lwip_htons (1 );
195
+ dnsHeader->ANCount = lwip_htons (1 );
196
+ dnsHeader->NSCount = 0 ;
197
+ dnsHeader->ARCount = 0 ;
198
+
199
+ // _dnsHeader->RA = 1;
152
200
153
201
_udp.beginPacket (_udp.remoteIP (), _udp.remotePort ());
154
- _udp.write (buffer, packetSize);
202
+ _udp.write ((unsigned char *) dnsHeader, sizeof (DNSHeader));
203
+ _udp.write (query, queryLength);
155
204
156
205
_udp.write ((uint8_t )192 ); // answer name is a pointer
157
206
_udp.write ((uint8_t )12 ); // pointer to offset at 0x00c
@@ -169,27 +218,32 @@ void DNSServer::replyWithIP(uint8_t* buffer, size_t packetSize)
169
218
_udp.write ((uint8_t )4 );
170
219
_udp.write (_resolvedIP, sizeof (_resolvedIP));
171
220
_udp.endPacket ();
172
-
173
- #ifdef DEBUG_ESP_DNS
174
- DEBUG_OUTPUT.printf (" DNS responds: %s for %s\n " ,
175
- IPAddress (_resolvedIP).toString ().c_str (), getDomainNameWithoutWwwPrefix (buffer, packetSize).c_str () );
176
- #endif
177
221
}
178
222
179
- void DNSServer::replyWithCustomCode (uint8_t * buffer, size_t packetSize)
223
+ void DNSServer::replyWithError (DNSHeader *dnsHeader,
224
+ DNSReplyCode rcode,
225
+ unsigned char *query,
226
+ size_t queryLength)
180
227
{
181
- if (packetSize < sizeof (DNSHeader))
182
- {
183
- return ;
184
- }
185
-
186
- DNSHeader* dnsHeader = reinterpret_cast <DNSHeader*>(buffer);
187
-
188
228
dnsHeader->QR = DNS_QR_RESPONSE;
189
- dnsHeader->RCode = (unsigned char )_errorReplyCode;
190
- dnsHeader->QDCount = 0 ;
229
+ dnsHeader->RCode = (unsigned char ) rcode;
230
+ if (query)
231
+ dnsHeader->QDCount = lwip_htons (1 );
232
+ else
233
+ dnsHeader->QDCount = 0 ;
234
+ dnsHeader->ANCount = 0 ;
235
+ dnsHeader->NSCount = 0 ;
236
+ dnsHeader->ARCount = 0 ;
191
237
192
238
_udp.beginPacket (_udp.remoteIP (), _udp.remotePort ());
193
- _udp.write (buffer, sizeof (DNSHeader));
239
+ _udp.write ((unsigned char *)dnsHeader, sizeof (DNSHeader));
240
+ if (query != NULL )
241
+ _udp.write (query, queryLength);
194
242
_udp.endPacket ();
195
243
}
244
+
245
+ void DNSServer::replyWithError (DNSHeader *dnsHeader,
246
+ DNSReplyCode rcode)
247
+ {
248
+ replyWithError (dnsHeader, rcode, NULL , 0 );
249
+ }
0 commit comments