Skip to content

Commit

Permalink
Setup fourier cfg and fixed small path issues in utils and train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AndReGeist committed Mar 4, 2024
1 parent d0b2be0 commit 9cf7dcc
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 124 deletions.
2 changes: 2 additions & 0 deletions hitchhiking_rotations/cfgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from .cfg_cube_image_to_pose import get_cfg_cube_image_to_pose
from .cfg_pcd_to_pose import get_cfg_pcd_to_pose
from .cfg_pose_to_cube_image import get_cfg_pose_to_cube_image
from .cfg_pose_to_fourier import get_cfg_pose_to_fourier

44 changes: 24 additions & 20 deletions hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,61 @@
# See LICENSE file in the project root for details.
#
def get_cfg_pose_to_fourier(device, nb, nf):
shared_trainer_cfg = {
cfg = {
"_target_": "hitchhiking_rotations.utils.Trainer",
"lr": 0.001,
"lr": 0.01,
"optimizer": "SGD",
"logger": "${logger}",
"verbose": "${verbose}",
"device": device,
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:passthrough}",
"loss": "${u:l2}",
}

return {
"verbose": False,
"batch_size": 32,
"epochs": 5,
"batch_size": 64,
"epochs": 10,
"training_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset",
"mode": "train",
"dataset_size": 800,
"device": device,
"nb": nb,
"nf": nf,
"device": device,
},
"test_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset",
"mode": "test",
"dataset_size": 1000,
"device": device,
"nb": nb,
"nf": nf,
"device": device,
},
"val_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset",
"mode": "val",
"dataset_size": 800,
"device": device,
"nb": nb,
"nf": nf,
"device": device,
},
"model9": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 12288, "output_dim": 9},
# Maybe here we have to also change the logger - but the l2 metric may do it
"model9": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 9, "output_dim": 1},
"model6": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 6, "output_dim": 1},
"model4": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 4, "output_dim": 1},
"model3": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 3, "output_dim": 1},
"logger": {
"_target_": "hitchhiking_rotations.utils.OrientationLogger",
"metrics": ["l2"],
},
"trainers": {
"r9_l1": {
**shared_trainer_cfg,
**{
"preprocess_input": "${u:flatten}",
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:procrustes_to_rotmat}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:l1}",
"model": "${model9}",
},
}
"r9_l2": {**cfg, **{"preprocess_input": "${u:flatten}", "model": "${model9}"}},
"r6_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_gramschmidt_f}", "model": "${model6}"}},
"quat_c_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_canonical}", "model": "${model4}"}},
"quat_rf_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_rand_flip}", "model": "${model4}"}},
"euler_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_euler}", "model": "${model3}"}},
"rotvec_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_rotvec}", "model": "${model3}"}},
},
}
2 changes: 1 addition & 1 deletion hitchhiking_rotations/datasets/cube_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __del__(self):


if __name__ == "__main__":
dg = DataGenerator(64, 64)
dg = CubeDataGenerator(64, 64)
img = dg.render_img(np.array([0, 0, 0, 1]))

i1 = Image.fromarray(img)
Expand Down
91 changes: 0 additions & 91 deletions hitchhiking_rotations/datasets/exp4_fourier_data.py

This file was deleted.

97 changes: 89 additions & 8 deletions hitchhiking_rotations/datasets/fourier_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,95 @@
#
# Copyright (c) 2024, MPI-IS, Jonas Frey, Rene Geist, Mikel Zhobro.
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
import os
from os.path import join

import numpy as np
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import roma

from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle, load_pickle

class PoseToFourierDataset(Dataset):
def __init__(self, mode, nb, nf, device):
pass
"""
Loads data from fourier dataset
"""
def __init__(self, mode, dataset_size, device, nb, nf):

path = join(HITCHHIKING_ROOT_DIR, "assets", "datasets", "fourier_dataset",
f"fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl")

if os.path.exists(path):
dic = load_pickle(path)
quats, features = dic["quats"], dic["features"]
print(f"Loaded fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl: {path}")
else:
quats, features = create_data(N_points=dataset_size, nb=nb, seed=nf)
dic = {"quats": quats, "features": features}
save_pickle(dic, path)
print(f"Saved fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl: {path}")

self.features = torch.from_numpy(features).to(device)
self.quats = torch.from_numpy(quats).to(device)

def __len__(self):
return len(self.features)

def __getitem__(self, idx):
return x, y
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx]

class random_fourier_function():

def __init__(self, n_basis, seed, A0=0., L=1.):
np.random.seed(seed)
self.L = L
self.n_basis = n_basis
self.A0 = A0
self.A = np.random.normal(size=n_basis)
self.B = np.random.normal(size=n_basis)
self.matrix = np.random.normal(size=(1, 9))

def __call__(self, x):
fFs = self.A0 / 2
for k in range(len(self.A)):
fFs = (fFs + self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) +
self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L))
return fFs

def create_data(N_points, nb, seed):
"""
Create data from fourier series.
Args:
N_points: Number of random rotations to generate
nb: Number of fourier basis that form the target function
seed: Used to randomly initialize fourier function coefficients
Returns:
rots: Random rotations
features: Target function evaluated at rots
"""
np.random.seed(seed)
rots = Rotation.random(N_points)
inputs = rots.as_matrix().reshape(N_points, -1)
four_func = random_fourier_function(nb, seed)
features = np.apply_along_axis(four_func, 1, inputs)
return rots.as_quat().astype(np.float32), features.astype(np.float32)

def plot_fourier_data(rotations, features):
import pandas as pd
import seaborn as sns

data = np.c_[rotations, features]
df = pd.DataFrame(data)
sns.set(style="ticks")

g = sns.PairGrid(df, diag_sharey=True)

g.map_upper(sns.scatterplot, s=15)
g.map_lower(sns.kdeplot)
g.map_diag(sns.kdeplot, lw=2)
g.set(xlim=(-1.2, 1.2), ylim=(-1.2, 1.2))
plt.show()

if __name__ == "__main__":
create_data(N_points=100, nb=2, seed=5)
1 change: 1 addition & 0 deletions hitchhiking_rotations/utils/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def save_pickle(cfg, path: str):
cfg (dict): Configuration
path (str): File path
"""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as file:
pickle.dump(cfg, file, protocol=pickle.HIGHEST_PROTOCOL)

Expand Down
11 changes: 7 additions & 4 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle
from hitchhiking_rotations.cfgs import get_cfg_pcd_to_pose, get_cfg_cube_image_to_pose, get_cfg_pose_to_cube_image
from hitchhiking_rotations.cfgs import (get_cfg_pcd_to_pose, get_cfg_cube_image_to_pose, get_cfg_pose_to_cube_image,
get_cfg_pose_to_fourier)

import numpy as np
import argparse
Expand All @@ -14,7 +15,7 @@

parser = argparse.ArgumentParser()

fourier_choices = ["pose_to_fourier_{idx}" for idx in range(1, 8)]
fourier_choices = [f"pose_to_fourier_{idx}" for idx in range(1, 7)]

parser.add_argument(
"--experiment",
Expand All @@ -23,7 +24,9 @@
default="pose_to_cube_image",
help="Experiment Configuration",
)
parser.add_argument("--seed", type=int, default=0, help="number of seeds")
parser.add_argument("--seed", type=int, default=0,
help="Random seed used during training, " +
"for pose_to_fourier the seed is used to select the target function.")
args = parser.parse_args()

s = args.seed
Expand All @@ -39,7 +42,7 @@
cfg_exp = get_cfg_pose_to_cube_image(device)

elif args.experiment.find("pose_to_fourier") != -1:
cfg_exp = get_cfg_pose_to_fourier(device, nf=seed, nb=int(arg.experiment.split("_")[-1]))
cfg_exp = get_cfg_pose_to_fourier(device, nf=s, nb=int(args.experiment.split("_")[-1]))

elif args.experiment == "pcd_to_pose":
cfg_exp = get_cfg_pcd_to_pose(device)
Expand Down

0 comments on commit 9cf7dcc

Please sign in to comment.