From 18074c08ca3e9ab01f385d6794cfcb6b5fff2c73 Mon Sep 17 00:00:00 2001 From: JustGlowing Date: Thu, 22 Aug 2024 13:45:13 +0100 Subject: [PATCH] fixed points added --- minisom.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/minisom.py b/minisom.py index c85a8b3..5a685c0 100644 --- a/minisom.py +++ b/minisom.py @@ -448,7 +448,8 @@ def pca_weights_init(self, data): c2*pc[pc_order[1]] def train(self, data, num_iteration, - random_order=False, verbose=False, use_epochs=False): + random_order=False, verbose=False, + use_epochs=False, fixed_points=None): """Trains the SOM. Parameters @@ -473,6 +474,12 @@ def train(self, data, num_iteration, If True the SOM will be trained for num_iteration epochs. In one epoch the weights are updated len(data) times and the learning rate is constat throughout a single epoch. + + fixed_points : dict (default=None) + A dictionary k : (c_1, c_2), that will force the + training algorithm to use the neuron with coordinates + (c_1, c_2) as winner for the samples k instead of + the best matching unit. """ self._check_iteration_number(num_iteration) self._check_input_len(data) @@ -488,9 +495,31 @@ def get_decay_rate(iteration_index, data_len): else: def get_decay_rate(iteration_index, data_len): return int(iteration_index) + + if fixed_points: + for k in fixed_points.keys(): + if not isinstance(k, int): + raise TypeError(f'fixed points indexes must ' + + 'be integers.') + if k >= len(data) or k < 0: + raise ValueError(f'an index of a fixed point ' + + 'cannot be grater than len(data)' + + ' or less than 0.') + if fixed_points[k][0] >= self._weights.shape[0] or \ + fixed_points[k][1] >= self._weights.shape[1]: + raise ValueError(f'coordinates for fixed point' + + ' are out of boundaries.') + if fixed_points[k][0] < 0 or \ + fixed_points[k][1] < 0: + raise ValueError(f'coordinates cannot be negative.') + else: + fixed_points = {} + for t, iteration in enumerate(iterations): decay_rate = get_decay_rate(t, len(data)) - self.update(data[iteration], self.winner(data[iteration]), + self.update(data[iteration], + fixed_points.get(iteration, + self.winner(data[iteration])), decay_rate, num_iteration) if verbose: print('\n quantization error:', self.quantization_error(data)) @@ -910,6 +939,22 @@ def test_train_use_epochs(self): som.train(data, 10, use_epochs=True) assert q1 > som.quantization_error(data) + def test_train_fixed_points(self): + som = MiniSom(5, 5, 2, sigma=1.0, learning_rate=0.5, random_seed=1) + data = array([[4, 2], [3, 1]]) + som.train(data, 10, fixed_points={0: (0, 0)}) + with self.assertRaises(ValueError): + som.train(data, 10, fixed_points={0: (5, 0)}) + with self.assertRaises(ValueError): + som.train(data, 10, fixed_points={2: (0, 0)}) + with self.assertRaises(ValueError): + som.train(data, 10, fixed_points={0: (-1, 0)}) + with self.assertRaises(ValueError): + som.train(data, 10, fixed_points={-1: (0, 0)}) + with self.assertRaises(TypeError): + som.train(data, 10, fixed_points={'oops': (0, 0)}) + + def test_use_epochs_variables(self): len_data = 100000 num_epochs = 100