Gowalla Preprocessing: An Example of Using Fast K-Means
Use k-means to cut the geo information (latitude and longitude) into slots.
Load the Gowalla data, including the geo location of each check-in.
import numpy as np
import pandas as pd
df = pd.read_csv("/data/gowalla_admin.csv")
Shuffle the data with df.sample(frac=1.), then retrieve the geo information as a NumPy array.
df = df.sample(frac=1.)
# as_matrix() was removed in pandas 1.0; to_numpy() is its replacement
geo_array = df[["latitude", "longitude"]].to_numpy()
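The shuffle-and-extract step can be tried on a small synthetic frame before touching the real file. This is an illustrative sketch, not the notebook's code: the toy coordinates are made up, `random_state=0` is added only so the shuffle is reproducible, and the cast to `float32` is an assumption (single precision is common on GPUs; the notebook itself keeps the default dtype).

```python
import numpy as np
import pandas as pd

# Toy stand-in for the Gowalla check-ins (coordinates made up for illustration)
toy_df = pd.DataFrame({
    "latitude": [52.51, 55.60, 39.96, 28.52, 59.32],
    "longitude": [13.32, 13.02, -86.01, -81.34, 18.08],
})

# frac=1.0 returns every row in random order; random_state fixes that order
toy_df = toy_df.sample(frac=1.0, random_state=0)

# to_numpy() is the current replacement for the removed as_matrix()
toy_geo = toy_df[["latitude", "longitude"]].to_numpy(dtype=np.float32)
print(toy_geo.shape)  # (5, 2)
```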
Use kmeans_core: set k=200, pass in the array, and set the batch size to 800,000.
from ray.kmean_torch import kmeans_core
# k=200 clusters over the (latitude, longitude) array, 800,000 rows per batch
km = kmeans_core(200, geo_array, batch_size=8e5)
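In the run log, every epoch reports `iter:8` and a progress bar of `9/9`, i.e. nine iterations per epoch. Assuming each iteration processes one contiguous batch of `batch_size` rows (the library's actual batching may differ), the arithmetic is a ceiling division, and nine batches of 800,000 rows imply an array of more than 6,400,000 and at most 7,200,000 rows. The helpers `iter_batches` and `n_batches` below are hypothetical names used only for this sketch.

```python
import math
import numpy as np

def iter_batches(array, batch_size):
    """Yield consecutive row slices of at most batch_size rows."""
    batch_size = int(batch_size)  # 8e5 is a float; slicing needs an int
    for start in range(0, len(array), batch_size):
        yield array[start:start + batch_size]

def n_batches(n_rows, batch_size):
    """Number of batches needed to cover n_rows rows."""
    return math.ceil(n_rows / int(batch_size))

# Any row count in (6_400_000, 7_200_000] gives the 9 iterations seen in the log
print(n_batches(6_400_001, 8e5))  # 9
print(n_batches(7_200_000, 8e5))  # 9
print(n_batches(7_200_001, 8e5))  # 10

# The generator on a tiny array: 10 rows in batches of 4
toy = np.zeros((10, 2))
print([len(b) for b in iter_batches(toy, 4)])  # [4, 4, 2]
```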
Run the calculation. Each epoch takes about 3 seconds on an NVIDIA GTX 1070 GPU.
km.run()
🔥[epoch:0 iter:8]🔥 🔥k:200 🔥distance:4665.021: 100%|██████████| 9/9 [00:03<00:00, 2.91it/s]
🔥[epoch:1 iter:8]🔥 🔥k:200 🔥distance:4668.177: 100%|██████████| 9/9 [00:03<00:00, 2.96it/s]
🔥[epoch:2 iter:8]🔥 🔥k:200 🔥distance:5084.151: 100%|██████████| 9/9 [00:03<00:00, 2.95it/s]
🔥[epoch:3 iter:8]🔥 🔥k:200 🔥distance:5084.281: 100%|██████████| 9/9 [00:03<00:00, 2.94it/s]
🔥[epoch:4 iter:8]🔥 🔥k:200 🔥distance:5084.367: 100%|██████████| 9/9 [00:03<00:00, 2.95it/s]
🔥[epoch:5 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.96it/s]
🔥[epoch:6 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.96it/s]
🔥[epoch:7 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.97it/s]
🔥[epoch:8 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.95it/s]
🔥[epoch:9 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.96it/s]
🔥[epoch:10 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.94it/s]
🔥[epoch:11 iter:8]🔥 🔥k:200 🔥distance:5084.365: 100%|██████████| 9/9 [00:03<00:00, 2.95it/s]
🔥[epoch:12 iter:8]🔥 🔥k:200 🔥distance:5084.366: 100%|██████████| 9/9 [00:03<00:00, 2.90it/s]
0%| | 0/9 [00:00<?, ?it/s]
Centroids is not shifting anymore
100%|██████████| 9/9 [00:03<00:00, 2.97it/s]
tensor([ 74, 102, 63, ..., 158, 108, 172], device='cuda:0')
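The log shows the reported distance levelling off after a few epochs and the run ending with "Centroids is not shifting anymore". A plain-NumPy sketch of that idea follows: Lloyd's k-means where each epoch walks the data in batches, accumulates per-cluster coordinate sums and counts, updates the centroids once, and stops when no centroid moves more than a tolerance. It is not the `kmeans_core` source; the random-row initialisation, the tolerance `tol` and the function name `batched_kmeans` are assumptions made for illustration.

```python
import numpy as np

def batched_kmeans(x, k, batch_size, epochs=50, tol=1e-6, seed=0):
    """Lloyd's k-means whose per-epoch statistics are accumulated batch by batch.

    Returns (centroids, labels). Illustrative sketch only, not kmeans_core.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=np.float64)
    bs = int(batch_size)
    # Assumed initialisation: k distinct rows picked at random
    cent = x[rng.choice(len(x), size=k, replace=False)].copy()

    def assign(batch):
        # Squared Euclidean distance from each row to each centroid: (rows, k)
        d = ((batch[:, None, :] - cent[None, :, :]) ** 2).sum(axis=2)
        return d.argmin(axis=1)

    for epoch in range(epochs):
        sums = np.zeros_like(cent)
        counts = np.zeros(k)
        for start in range(0, len(x), bs):
            batch = x[start:start + bs]
            idx = assign(batch)
            np.add.at(sums, idx, batch)          # per-cluster coordinate sums
            counts += np.bincount(idx, minlength=k)
        new_cent = cent.copy()                   # empty clusters keep their old centroid
        filled = counts > 0
        new_cent[filled] = sums[filled] / counts[filled][:, None]
        shift = np.abs(new_cent - cent).max()    # largest centroid move this epoch
        cent = new_cent
        if shift < tol:
            print(f"centroids stopped shifting after epoch {epoch}")
            break

    # Final pass: label every row with its nearest (final) centroid
    labels = np.concatenate([assign(x[s:s + bs]) for s in range(0, len(x), bs)])
    return cent, labels


# Usage on three made-up, well-separated blobs
rng = np.random.default_rng(1)
blob_centres = np.array([[0.0, 0.0], [10.0, 10.0], [-10.0, 10.0]])
pts = np.concatenate([c + rng.normal(scale=0.5, size=(100, 2)) for c in blob_centres])
cent, labels = batched_kmeans(pts, k=3, batch_size=64)
print(cent.shape, labels.shape)  # (3, 2) (300,)
```

Accumulating sums across batches and updating once per epoch gives the same centroids as full-batch Lloyd's algorithm, while only one batch-by-k distance matrix is held in memory at a time.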
For comparison, one epoch took 38 seconds to run on the CPU.
from ray.kmean_torch import kmeans_core
# a separate name keeps the converged GPU result in km for the steps below
km_cpu = kmeans_core(200, geo_array, batch_size=8e5, epochs=1)
km_cpu.run()
🔥[epoch:0 iter:8]🔥 🔥k:200 🔥distance:4665.269: 100%|██████████| 9/9 [00:38<00:00, 4.23s/it]
100%|██████████| 9/9 [00:36<00:00, 4.07s/it]
tensor([ 171, 152, 27, ..., 117, 14, 46])
# move the cluster index of every row and the centroid coordinates to NumPy
idx_array = km.idx.cpu().numpy()
cen = km.cent.cpu().numpy()
np.save("/data/centroids.npy", cen)
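Because the centroids are saved to `/data/centroids.npy`, new check-ins could later be mapped to the same slots without rerunning k-means. A minimal sketch, assuming nearest-centroid assignment by squared Euclidean distance on raw latitude/longitude matches how the slots were cut; the helper `assign_slots` is hypothetical, the centroids are made up, and a temporary file stands in for `/data/centroids.npy`.

```python
import os
import tempfile
import numpy as np

def assign_slots(coords, centroids):
    """Return the index of the nearest centroid for each (lat, lon) row."""
    coords = np.asarray(coords, dtype=np.float64)
    centroids = np.asarray(centroids, dtype=np.float64)
    d = ((coords[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d.argmin(axis=1)

# Made-up centroids, saved and reloaded the same way as the notebook's np.save call
toy_cen = np.array([[52.5, 13.4], [40.7, -74.0], [35.7, 139.7]], dtype=np.float32)
path = os.path.join(tempfile.mkdtemp(), "centroids.npy")
np.save(path, toy_cen)
loaded = np.load(path)

new_checkins = np.array([[52.52, 13.41], [35.68, 139.76], [40.71, -74.01]])
print(assign_slots(new_checkins, loaded))  # [0 2 1]
```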
Mapping It Back to the DataFrame
df["idx"] = idx_array
# one label string per centroid, built from plain Python floats
centroids = ["lati:%s longi%s" % (la, lo) for la, lo in cen.tolist()]
# look up each row's centroid label by its cluster index
df["cent"] = df["idx"].map(lambda x: centroids[x])
df.sample(10)
| | user | time | latitude | longitude | loc_id | name | admin1 | admin2 | cc | idx | cent |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 3355801 | 46101 | 2010-07-29T16:08:11Z | 52.513544 | 13.319991 | 127070 | Hansaviertel | Berlin | NaN | DE | 90 | lati:51.475975036621094 longi13.54339599609375 |
| 4753318 | 105757 | 2010-05-22T12:46:45Z | 55.602003 | 13.024511 | 1153250 | Malmoe | Skane | Malmo | SE | 74 | lati:55.60551452636719 longi13.091133117675781 |
| 2008736 | 17272 | 2010-03-12T10:19:50Z | 50.809053 | 8.810520 | 679546 | Marburg an der Lahn | Hesse | Regierungsbezirk Giessen | DE | 53 | lati:49.7291374206543 longi8.50769329071045 |
| 3944538 | 67736 | 2010-04-06T18:43:10Z | 39.958952 | -86.009593 | 171442 | Fishers | Indiana | Hamilton County | US | 118 | lati:39.82094955444336 longi-86.19813537597656 |
| 1504032 | 10388 | 2009-11-13T22:40:32Z | 28.518201 | -81.344147 | 50220 | Conway | Florida | Orange County | US | 164 | lati:28.502212524414062 longi-81.45722198486328 |
| 5586680 | 129560 | 2010-06-30T11:40:20Z | 59.319182 | 18.076099 | 1359960 | Stockholm | Stockholm | Stockholms Kommun | SE | 50 | lati:59.325557708740234 longi18.073034286499023 |
| 82823 | 282 | 2010-10-16T14:11:36Z | 32.786529 | -96.794764 | 50853 | Dallas | Texas | Dallas County | US | 28 | lati:32.839134216308594 longi-96.77941131591797 |
| 2903954 | 36019 | 2010-05-10T01:02:01Z | 40.588843 | -111.940675 | 27137 | West Jordan | Utah | Salt Lake County | US | 192 | lati:40.87124252319336 longi-111.69757080078125 |
| 2191735 | 19338 | 2010-06-14T15:47:05Z | 25.334211 | 51.467593 | 854133 | Ar Rayyan | Baladiyat ar Rayyan | NaN | QA | 189 | lati:25.092052459716797 longi48.042118072509766 |
| 1766141 | 13081 | 2010-07-06T17:36:14Z | 56.350999 | 13.733063 | 95606 | Bjarnum | Skane | Hassleholms Kommun | SE | 13 | lati:56.30799865722656 longi14.415539741516113 |
df.to_csv("/data/gowalla_admin_slotted.csv")
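Mapping centroids back with a per-row function stores each centroid as a string and calls Python once per row. If numeric centroid columns are more useful downstream, NumPy fancy indexing does the lookup in one vectorized step. A sketch on made-up data; the column names `cent_lat` and `cent_lon` are hypothetical and not part of the original notebook.

```python
import numpy as np
import pandas as pd

# Made-up centroid array (k=3) and per-row cluster indices, for illustration only
toy_cen = np.array([[51.48, 13.54], [55.61, 13.09], [39.82, -86.20]])
slots = pd.DataFrame({"idx": [1, 0, 2, 1]})

# Fancy indexing: toy_cen[idx] returns one centroid row per DataFrame row
lookup = toy_cen[slots["idx"].to_numpy()]
slots["cent_lat"] = lookup[:, 0]   # hypothetical column names
slots["cent_lon"] = lookup[:, 1]
print(slots)
```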