DartでDNSクエリを投げてみる

ゴールデンウィーク中に書いているコードの一環として、DNSクエリを一から組み立ててみようと思い、雑に実装したコードを残しておきます。
仕様どおりにきちんと実装したものではないので解説はしませんが、DartでUDP通信を試したい人や、別の実装の参考にしたい人には役立つかもしれません。

実行するには `dart main.dart zuki.dev A` のように、ドメイン名とレコードタイプ（A / AAAA / NS / TXT のいずれか）を指定します。

import 'dart:async';
import 'dart:convert';
import 'dart:io';
import 'dart:math' show Random;
import 'dart:typed_data';

/// A DNS resource-record type that this client can query for and decode.
///
/// Each instance pairs the numeric TYPE code used on the wire ([flag]) with a
/// [decode] function that renders the record's RDATA as a string.
class RecordType {
  /// Numeric TYPE code carried in questions and resource records.
  final int flag;

  /// Renders the RDATA beginning at `offset` in `data` as a string.
  ///
  /// `data` should span the whole DNS message: NS names may contain
  /// compression pointers, which are offsets from the start of `data`.
  final String Function(ByteData data, int offset) decode;

  RecordType._(this.flag, this.decode);

  /// All record types this client can decode.
  static final List<RecordType> values = List.unmodifiable([A, AAAA, NS, TXT]);

  /// Returns the type whose wire code is [flag].
  ///
  /// Throws an [Exception] if [flag] is not the code of one of [values].
  static RecordType fromFlag(int flag) {
    return values.firstWhere(
      (type) => type.flag == flag,
      orElse: () => throw Exception('Unknown record type: $flag'),
    );
  }

  /// Returns the type whose mnemonic is [name] (`A`, `AAAA`, `NS` or `TXT`).
  ///
  /// The match is case-sensitive; any other name throws an [Exception].
  static RecordType fromName(String name) {
    switch (name) {
      case 'A':
        return A;
      case 'AAAA':
        return AAAA;
      case 'NS':
        return NS;
      case 'TXT':
        return TXT;
      default:
        throw Exception('Unknown record type: ${name}');
    }
  }

  /// IPv4 address record: RDATA is four raw octets.
  static final A = RecordType._(
    0x01,
    (data, offset) => _decodeAddress(data, offset, 4),
  );

  /// IPv6 address record: RDATA is sixteen raw octets.
  static final AAAA = RecordType._(
    0x1c,
    (data, offset) => _decodeAddress(data, offset, 16),
  );

  /// Name-server record: RDATA is a (possibly compressed) domain name.
  static final NS = RecordType._(0x02, _decodeDomainName);

  /// Text record.
  static final TXT = RecordType._(
    0x10,
    (data, offset) {
      // RDATA may hold several length-prefixed <character-string>s, but
      // decode is not given RDLENGTH, so only the first one is returned.
      final length = data.getUint8(offset);
      return _decodeText(data, offset + 1, length);
    },
  );

  /// Formats `byteCount` raw address bytes at `offset` as an IP string.
  static String _decodeAddress(ByteData data, int offset, int byteCount) {
    final octets = Uint8List.fromList(
      List.generate(byteCount, (index) => data.getUint8(offset + index)),
    );
    return InternetAddress.fromRawAddress(octets).address;
  }

  /// Decodes `length` bytes at `start` as UTF-8.
  ///
  /// `start` is relative to `data` itself, not to its underlying buffer, so
  /// this is correct for views with a non-zero `offsetInBytes`. Malformed
  /// sequences become U+FFFD instead of throwing.
  static String _decodeText(ByteData data, int start, int length) {
    return utf8.decode(
      Uint8List.sublistView(data, start, start + length),
      allowMalformed: true,
    );
  }

  /// Reads a possibly-compressed domain name starting at `offset`.
  ///
  /// Labels are joined with `.`. Reading stops at the zero-length root label
  /// or at the end of `data`. Compression pointers (top two bits set) are
  /// followed. A pointer that revisits an earlier target throws a
  /// [FormatException] instead of looping forever on a crafted message.
  static String _decodeDomainName(ByteData data, int offset) {
    final labels = <String>[];
    final visitedPointers = <int>{};
    var current = offset;

    while (current < data.lengthInBytes) {
      final length = data.getUint8(current);
      if (length == 0) {
        break;
      }

      if ((length & 0xc0) == 0xc0) {
        // 14-bit offset from the start of the message.
        final pointer = ((length & 0x3f) << 8) | data.getUint8(current + 1);
        if (!visitedPointers.add(pointer)) {
          throw FormatException(
            'Compression pointer loop in domain name',
            null,
            pointer,
          );
        }
        current = pointer;
        continue;
      }

      labels.add(_decodeText(data, current + 1, length));
      current += length + 1;
    }

    return labels.join('.');
  }
}

/// DNS CLASS values that may be placed in a question.
enum ClassCode {
  /// The Internet class.
  IN(flag: 0x01);

  /// Numeric CLASS code written to the wire.
  final int flag;

  const ClassCode({required this.flag});
}

/// A minimal DNS client that sends one question over UDP and prints the
/// decoded answers.
class DNS {
  /// Resolver addresses, tried in order until one of them answers.
  List<String> nameServers;

  /// UDP port used when contacting [nameServers].
  int port;

  /// Creates a client that queries [nameServers] on [port].
  DNS({
    this.nameServers = const ['1.1.1.1', '1.0.0.1'],
    this.port = 53,
  });

  /// Queries [nameServers] for the [type] records of [name] and prints them.
  ///
  /// Servers are tried in order. A server that fails, or that does not send a
  /// matching reply within [timeout], is skipped. The returned future
  /// completes after a reply has been printed or after every server has
  /// failed. The socket used for each attempt is always closed.
  Future<void> request(
    String name,
    RecordType type, {
    Duration timeout = const Duration(seconds: 5),
  }) async {
    if (nameServers.isEmpty) {
      print('Error: no name servers configured');
      return;
    }

    // A random ID (checked on the reply) makes blind spoofing harder than
    // a fixed value would.
    final transactionId = Random.secure().nextInt(0x10000);
    final query = _buildQuery(name, type, ClassCode.IN, transactionId);

    for (final server in nameServers) {
      RawDatagramSocket? socket;
      try {
        final address = InternetAddress(server);
        socket = await RawDatagramSocket.bind(
          address.type == InternetAddressType.IPv6
              ? InternetAddress.anyIPv6
              : InternetAddress.anyIPv4,
          0,
        );
        socket.send(query, address, port);
        print('Sent datagram: $name to $server:$port');

        final response =
            await _awaitReply(socket, address, transactionId).timeout(timeout);
        final message = _parseResponse(response);
        print('Received message:\n$message');
        return;
      } on TimeoutException {
        print('Error: no reply from $server:$port within $timeout');
      } catch (e) {
        print('Error: $e');
      } finally {
        socket?.close();
      }
    }
  }

  /// Waits on [socket] for a datagram from [server] (on [port]) whose header
  /// carries [transactionId]. Any other datagram is ignored.
  ///
  /// Throws a [SocketException] if the socket stream ends first.
  Future<Uint8List> _awaitReply(
    RawDatagramSocket socket,
    InternetAddress server,
    int transactionId,
  ) async {
    await for (final event in socket) {
      if (event != RawSocketEvent.read) {
        continue;
      }
      final datagram = socket.receive();
      if (datagram == null || !_isFrom(datagram, server)) {
        continue;
      }
      final data = datagram.data;
      // The transaction ID is the first big-endian 16-bit header field.
      if (data.length >= 2 && ((data[0] << 8) | data[1]) == transactionId) {
        return data;
      }
    }
    throw const SocketException('Socket closed before a reply arrived');
  }

  /// Whether [datagram] came from [server] on [port].
  bool _isFrom(Datagram datagram, InternetAddress server) {
    if (datagram.port != port) {
      return false;
    }
    final received = datagram.address.rawAddress;
    final expected = server.rawAddress;
    if (received.length != expected.length) {
      return false;
    }
    for (var i = 0; i < received.length; i++) {
      if (received[i] != expected[i]) {
        return false;
      }
    }
    return true;
  }

  /// Encodes a single-question query for [name] / [type] / [classCode].
  ///
  /// Labels are written with their UTF-8 byte length. Empty labels (for
  /// example from a trailing dot) are skipped. Throws an [ArgumentError] if a
  /// label exceeds the 63-byte limit.
  Uint8List _buildQuery(
    String name,
    RecordType type,
    ClassCode classCode, [
    int transactionId = 0x1234,
  ]) {
    final qname = <int>[];
    for (final label in name.split('.')) {
      if (label.isEmpty) {
        continue;
      }
      final bytes = utf8.encode(label);
      if (bytes.length > 63) {
        throw ArgumentError.value(
            name, 'name', 'label "$label" is longer than 63 bytes');
      }
      qname
        ..add(bytes.length)
        ..addAll(bytes);
    }

    return Uint8List.fromList([
      (transactionId >> 8) & 0xff, transactionId & 0xff, // Transaction ID
      0x01, 0x00, // Flags: standard query, recursion desired
      0x00, 0x01, // Questions: 1
      0x00, 0x00, // Answer RRs: 0
      0x00, 0x00, // Authority RRs: 0
      0x00, 0x00, // Additional RRs: 0
      ...qname,
      0x00, // Root label terminates QNAME
      (type.flag >> 8) & 0xff, type.flag & 0xff, // QTYPE
      (classCode.flag >> 8) & 0xff, classCode.flag & 0xff, // QCLASS
    ]);
  }

  /// Decodes the answer section of [response] into one line per record.
  ///
  /// Records whose type [RecordType.fromFlag] cannot map (for example CNAME)
  /// are skipped rather than ending the parse. Throws a [FormatException] if
  /// the message is shorter than the 12-byte header.
  String _parseResponse(Uint8List response) {
    // sublistView honours the list's offset and length within its buffer.
    final data = ByteData.sublistView(response);
    if (data.lengthInBytes < 12) {
      throw const FormatException(
          'DNS response is shorter than its 12-byte header');
    }

    final flags = data.getUint16(2);
    final questions = data.getUint16(4);
    final answerRRs = data.getUint16(6);
    // Header: ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT (2 bytes each).
    var offset = 12;

    final rcode = flags & 0x0f;
    if (rcode != 0) {
      print('Response code: $rcode');
    }

    for (var i = 0; i < questions; i++) {
      offset = _skipDomainName(data, offset) + 4; // QTYPE + QCLASS
    }

    final records = <String>[];
    print('Answer RRs: $answerRRs');
    for (var i = 0; i < answerRRs; i++) {
      offset = _skipDomainName(data, offset);
      final type = data.getUint16(offset);
      final dataLength = data.getUint16(offset + 8);
      offset += 10; // TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2)

      RecordType? recordType;
      try {
        recordType = RecordType.fromFlag(type);
      } catch (_) {
        // fromFlag throws for types this client cannot decode; skip such a
        // record instead of abandoning the remaining answers.
        recordType = null;
      }

      if (recordType == null) {
        print('Skipping unsupported record type: $type');
      } else {
        records.add(recordType.decode(data, offset));
      }
      offset += dataLength;
    }

    return records.join('\n');
  }

  /// Returns the offset just past the (possibly compressed) name at [offset].
  ///
  /// A compression pointer occupies two bytes and always ends the name, so it
  /// is not followed.
  int _skipDomainName(ByteData data, int offset) {
    var cursor = offset;
    while (true) {
      final length = data.getUint8(cursor);
      if (length == 0) {
        return cursor + 1;
      }
      if ((length & 0xc0) == 0xc0) {
        return cursor + 2;
      }
      cursor += length + 1;
    }
  }
}

/// Command-line entry point: `dart main.dart <name> [type]`.
///
/// `type` defaults to `A` and is upper-cased before being passed to
/// [RecordType.fromName]. Usage errors are reported on stderr with exit
/// code 64 instead of an uncaught exception.
Future<void> main(List<String> args) async {
  if (args.isEmpty) {
    stderr.writeln('Usage: dart main.dart <name> [A|AAAA|NS|TXT]');
    exitCode = 64; // EX_USAGE (sysexits.h)
    return;
  }

  final name = args[0];
  RecordType type;
  try {
    type = RecordType.fromName(args.length > 1 ? args[1].toUpperCase() : 'A');
  } on Exception catch (e) {
    stderr.writeln(e);
    exitCode = 64;
    return;
  }

  final dns = DNS();
  await dns.request(name, type);
}