Skip to content

Commit

Permalink
fixed points added
Browse files Browse the repository at this point in the history
  • Loading branch information
JustGlowing committed Aug 22, 2024
1 parent df1ff85 commit 18074c0
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions minisom.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def pca_weights_init(self, data):
c2*pc[pc_order[1]]

def train(self, data, num_iteration,
random_order=False, verbose=False, use_epochs=False):
random_order=False, verbose=False,
use_epochs=False, fixed_points=None):
"""Trains the SOM.
Parameters
Expand All @@ -473,6 +474,12 @@ def train(self, data, num_iteration,
If True the SOM will be trained for num_iteration epochs.
In one epoch the weights are updated len(data) times and
the learning rate is constat throughout a single epoch.
fixed_points : dict (default=None)
A dictionary k : (c_1, c_2), that will force the
training algorithm to use the neuron with coordinates
(c_1, c_2) as winner for the samples k instead of
the best matching unit.
"""
self._check_iteration_number(num_iteration)
self._check_input_len(data)
Expand All @@ -488,9 +495,31 @@ def get_decay_rate(iteration_index, data_len):
else:
def get_decay_rate(iteration_index, data_len):
return int(iteration_index)

if fixed_points:
for k in fixed_points.keys():
if not isinstance(k, int):
raise TypeError(f'fixed points indexes must ' +
'be integers.')
if k >= len(data) or k < 0:
raise ValueError(f'an index of a fixed point ' +
'cannot be grater than len(data)' +
' or less than 0.')
if fixed_points[k][0] >= self._weights.shape[0] or \
fixed_points[k][1] >= self._weights.shape[1]:
raise ValueError(f'coordinates for fixed point' +
' are out of boundaries.')
if fixed_points[k][0] < 0 or \
fixed_points[k][1] < 0:
raise ValueError(f'coordinates cannot be negative.')
else:
fixed_points = {}

for t, iteration in enumerate(iterations):
decay_rate = get_decay_rate(t, len(data))
self.update(data[iteration], self.winner(data[iteration]),
self.update(data[iteration],
fixed_points.get(iteration,
self.winner(data[iteration])),
decay_rate, num_iteration)
if verbose:
print('\n quantization error:', self.quantization_error(data))
Expand Down Expand Up @@ -910,6 +939,22 @@ def test_train_use_epochs(self):
som.train(data, 10, use_epochs=True)
assert q1 > som.quantization_error(data)

def test_train_fixed_points(self):
som = MiniSom(5, 5, 2, sigma=1.0, learning_rate=0.5, random_seed=1)
data = array([[4, 2], [3, 1]])
som.train(data, 10, fixed_points={0: (0, 0)})
with self.assertRaises(ValueError):
som.train(data, 10, fixed_points={0: (5, 0)})
with self.assertRaises(ValueError):
som.train(data, 10, fixed_points={2: (0, 0)})
with self.assertRaises(ValueError):
som.train(data, 10, fixed_points={0: (-1, 0)})
with self.assertRaises(ValueError):
som.train(data, 10, fixed_points={-1: (0, 0)})
with self.assertRaises(TypeError):
som.train(data, 10, fixed_points={'oops': (0, 0)})


def test_use_epochs_variables(self):
len_data = 100000
num_epochs = 100
Expand Down

0 comments on commit 18074c0

Please sign in to comment.