Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add model path argument #18

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def config():
parser.add_argument("--model", default="clip-flant5-xxl", type=str)
parser.add_argument("--question", default=None, type=str)
parser.add_argument("--answer", default=None, type=str)
parser.add_argument("--model_path", default=None, type=str)
return parser.parse_args()


Expand All @@ -24,7 +25,7 @@ def main():
if not os.path.exists(args.root_dir):
os.makedirs(args.root_dir)

score_func = t2v_metrics.get_score_model(model=args.model, device=args.device, cache_dir=args.cache_dir)
score_func = t2v_metrics.get_score_model(model=args.model, model_path=args.model_path, device=args.device, cache_dir=args.cache_dir)

kwargs = {}
if args.question is not None:
Expand Down
4 changes: 2 additions & 2 deletions t2v_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
def list_all_models():
return list_all_vqascore_models() + list_all_clipscore_models() + list_all_itmscore_models()

def get_score_model(model='clip-flant5-xxl', device='cuda', cache_dir=HF_CACHE_DIR, **kwargs):
def get_score_model(model='clip-flant5-xxl', model_path=None, device='cuda', cache_dir=HF_CACHE_DIR, **kwargs):
if model in list_all_vqascore_models():
return VQAScore(model, device=device, cache_dir=cache_dir, **kwargs)
return VQAScore(model, model_path=model_path, device=device, cache_dir=cache_dir, **kwargs)
elif model in list_all_clipscore_models():
return CLIPScore(model, device=device, cache_dir=cache_dir, **kwargs)
elif model in list_all_itmscore_models():
Expand Down
8 changes: 4 additions & 4 deletions t2v_metrics/models/vqascore_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
def list_all_vqascore_models():
return [model for models in ALL_VQA_MODELS for model in models]

def get_vqascore_model(model_name, device='cuda', cache_dir=HF_CACHE_DIR, **kwargs):
def get_vqascore_model(model_name, model_path=None, device='cuda', cache_dir=HF_CACHE_DIR, **kwargs):
assert model_name in list_all_vqascore_models()
if model_name in CLIP_T5_MODELS:
return CLIPT5Model(model_name, device=device, cache_dir=cache_dir, **kwargs)
return CLIPT5Model(model_name, model_path=model_path, device=device, cache_dir=cache_dir, **kwargs)
elif model_name in LLAVA_MODELS:
return LLaVAModel(model_name, device=device, cache_dir=cache_dir, **kwargs)
return LLaVAModel(model_name, model_path=model_path, device=device, cache_dir=cache_dir, **kwargs)
elif model_name in LLAVA16_MODELS:
return LLaVA16Model(model_name, device=device, cache_dir=cache_dir, **kwargs)
return LLaVA16Model(model_name, model_path=model_path, device=device, cache_dir=cache_dir, **kwargs)
elif model_name in InstructBLIP_MODELS:
return InstructBLIPModel(model_name, device=device, cache_dir=cache_dir, **kwargs)
elif model_name in GPT4V_MODELS:
Expand Down
7 changes: 4 additions & 3 deletions t2v_metrics/models/vqascore_models/clip_t5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,16 @@ def format_answer(answer, conversation_style='plain'):
},
}



class CLIPT5Model(VQAScoreModel):
"""A wrapper for the CLIP-FlanT5 or CLIP-T5 models"""
def __init__(self,
model_name='clip-flant5-xxl',
device='cuda',
cache_dir=HF_CACHE_DIR):
cache_dir=HF_CACHE_DIR,
model_path=None):
assert model_name in CLIP_T5_MODELS
if model_path is not None:
CLIP_T5_MODELS[model_name]['model']['path'] = model_path
super().__init__(model_name=model_name,
device=device,
cache_dir=cache_dir)
Expand Down
5 changes: 4 additions & 1 deletion t2v_metrics/models/vqascore_models/llava16_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ class LLaVA16Model(VQAScoreModel):
def __init__(self,
model_name='llava-v1.6-13b',
device='cuda',
cache_dir=HF_CACHE_DIR):
cache_dir=HF_CACHE_DIR,
model_path=None):
assert model_name in LLAVA16_MODELS
if model_path is not None:
LLAVA16_MODELS[model_name]['model']['path'] = model_path
super().__init__(model_name=model_name,
device=device,
cache_dir=cache_dir)
Expand Down
5 changes: 4 additions & 1 deletion t2v_metrics/models/vqascore_models/llava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ class LLaVAModel(VQAScoreModel):
def __init__(self,
model_name='llava-v1.5-13b',
device='cuda',
cache_dir=HF_CACHE_DIR):
cache_dir=HF_CACHE_DIR,
model_path=None):
assert model_name in LLAVA_MODELS
if model_path is not None:
LLAVA_MODELS[model_name]['model']['path'] = model_path
super().__init__(model_name=model_name,
device=device,
cache_dir=cache_dir)
Expand Down
4 changes: 3 additions & 1 deletion t2v_metrics/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ class Score(nn.Module):
def __init__(self,
model: str,
device: str='cuda',
model_path: str=None,
cache_dir: str=HF_CACHE_DIR,
**kwargs):
"""Initialize the ScoreModel
"""
super().__init__()
assert model in self.list_all_models()
self.device = device
self.model = self.prepare_scoremodel(model, device, cache_dir, **kwargs)
self.model = self.prepare_scoremodel(model, model_path, device, cache_dir, **kwargs)

@abstractmethod
def prepare_scoremodel(self,
model: str,
model_path: str,
device: str,
cache_dir: str,
**kwargs):
Expand Down
2 changes: 2 additions & 0 deletions t2v_metrics/vqascore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
class VQAScore(Score):
def prepare_scoremodel(self,
model='clip-flant5-xxl',
model_path=None,
device='cuda',
cache_dir=HF_CACHE_DIR,
**kwargs):
return get_vqascore_model(
model,
model_path=model_path,
device=device,
cache_dir=cache_dir,
**kwargs
Expand Down