summaryrefslogtreecommitdiff
path: root/szilagyi/_dataset/__init__.py
blob: c91abc384da78e5d04abc6825bc3febd213ca3b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import csv
import os
import re


# Absolute path of the directory containing this module; load() scans it for
# the SWI_<n>.csv data files.
ROOT = os.path.split(os.path.abspath(__file__))[0]


class Vector(complex):
	"""A 2-D point stored as a complex number that also indexes like a pair.

	``v[0]`` is the real part and ``v[1]`` the imaginary part.  Negative
	indices ``-2`` / ``-1`` work as they would for a 2-tuple, and ``len(v)``,
	``tuple(v)`` and ``x, y = v`` all behave as for ``(v.real, v.imag)``.

	Note: arithmetic is inherited from ``complex`` and therefore returns plain
	``complex`` values, not ``Vector`` instances.
	"""

	def __len__(self):
		# complex.__bool__ is inherited and still decides truthiness, so a
		# zero vector stays falsy even though len() is 2.
		return 2

	def __iter__(self):
		yield self.real
		yield self.imag

	def __getitem__(self, index):
		"""Return the real part for index 0 / -2, the imaginary part for 1 / -1.

		:raises IndexError: for any other index.
		"""
		if index == 0 or index == -2:
			return self.real
		elif index == 1 or index == -1:
			return self.imag
		raise IndexError("Vector index out of range: %r" % (index,))


def load(directory=None):
	"""Load every ``SWI_<n>.csv`` file found in *directory*.

	A file is picked up only if its whole name matches ``SWI_<n>.csv``, where
	``<n>`` is an optionally negative integer; names such as
	``SWI_1.csv.bak`` are ignored.  Each file is read as CSV with two numeric
	columns per row, which become the real and imaginary parts of a
	``Vector``.  Empty rows (e.g. an extra blank line at the end) are skipped.

	:param directory: directory to scan; ``None`` (the default) means
		``ROOT``, the directory containing this module.
	:returns: list of ``(n, points)`` pairs sorted by ``n``, where ``points``
		is the list of ``Vector`` values read from that file, in row order.
	:raises ValueError: if a non-empty row does not have exactly two columns,
		or a value cannot be parsed as a float.
	"""
	if directory is None:
		directory = ROOT
	# fullmatch (not match) so trailing junk such as ".bak" is rejected.
	pattern = re.compile(r"SWI_(-?\d+)\.csv")

	def _read(reader, filename):
		for row in reader:
			# csv yields [] for an empty line; skip it instead of crashing.
			if not row:
				continue
			if len(row) != 2:
				raise ValueError(
					"%s:%d: expected 2 columns, got %d"
					% (filename, reader.line_num, len(row))
				)
			x, y = row
			yield Vector(float(x), float(y))

	def _load(filename):
		# newline="" is required by the csv module; utf-8-sig also tolerates
		# a leading byte-order mark.
		with open(filename, newline="", encoding="utf-8-sig") as fd:
			return list(_read(csv.reader(fd), filename))

	def _files():
		for filename in os.listdir(directory):
			match = pattern.fullmatch(filename)
			if match:
				yield int(match.group(1)), os.path.join(directory, filename)

	# Sorting the full (n, path) tuples keeps the order deterministic even if
	# two names map to the same n (e.g. "SWI_5.csv" and "SWI_05.csv").
	return [(n, _load(path)) for n, path in sorted(_files())]