diff --git a/kMeans.py b/kMeans.py index 83e5e1b..5c1212f 100644 --- a/kMeans.py +++ b/kMeans.py @@ -44,7 +44,8 @@ def updata_centroids(self, data, classes): self.centroids[i] = np.mean(class_i_data, axis=0) def fit(self, data): - self.centroids = random.sample(data, self.k) + sample_id = np.random.choice(data.shape[0], size=self.k) + self.centroids = data[sample_id] # iteration iter = 0 while iter < self.max_iter: