Module ktrain.vision.caption.core
Source code
from ...torch_base import TorchBase
from ... import imports as I


class ImageCaptioner(TorchBase):
    """
    Interface to an image captioner
    """

    def __init__(self, model_name='ydshieh/vit-gpt2-coco-en', device=None):
        """
        ```
        ImageCaptioner constructor

        Args:
          model_name(str): name of image captioning model
          device(str): device to use (e.g., 'cuda', 'cpu')
        ```
        """
        if not I.PIL_INSTALLED:
            raise Exception('PIL is not installed. Please install with: pip install pillow>=9.0.1')
        super().__init__(device=device, quantize=False, min_transformers_version='4.12.3')
        self.model_name = model_name
        from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor

        self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name).to(self.torch_device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.extractor = ViTFeatureExtractor.from_pretrained(self.model_name)

    def caption(self, images):
        """
        ```
        Performs image captioning.
        This method supports a single image or a list of images. If the input is a single
        image, the return type is a string. If the input is a list, a list of strings
        is returned.

        Args:
          images: image|list
        Returns:
          a caption string, or a list of caption strings
        ```
        """
        # Convert a single image to a one-element list
        values = [images] if not isinstance(images, list) else images

        # Open images given as file path strings; pass PIL images through unchanged
        values = [I.Image.open(image) if isinstance(image, str) else image for image in values]

        # Extract pixel features and move them to the model's device
        pixels = self.extractor(images=values, return_tensors="pt").pixel_values
        pixels = pixels.to(self.torch_device)

        # Generate caption token sequences with beam search
        import torch

        with torch.no_grad():
            outputs = self.model.generate(pixels, max_length=16, num_beams=4, return_dict_in_generate=True).sequences

        # Decode token sequences into caption strings
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        captions = [caption.strip() for caption in captions]

        # Return a single caption if a single image was passed in
        return captions[0] if not isinstance(images, list) else captions
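A minimal usage sketch. The import path follows the module name above (ktrain may also re-export ImageCaptioner from a shorter path), and the image file names are hypothetical:

```
from ktrain.vision.caption.core import ImageCaptioner

captioner = ImageCaptioner()  # downloads the default ydshieh/vit-gpt2-coco-en weights on first use

print(captioner.caption('photo.jpg'))          # single image path -> one caption string
print(captioner.caption(['a.jpg', 'b.jpg']))   # list of paths -> list of caption strings
```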
Classes
class ImageCaptioner (model_name='ydshieh/vit-gpt2-coco-en', device=None)
Interface to an image captioner.

ImageCaptioner constructor

Args:
  model_name(str): name of image captioning model
  device(str): device to use (e.g., 'cuda', 'cpu')
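As an illustration of the constructor arguments: the device strings come straight from the docstring above, and any alternate model_name would need to be a VisionEncoderDecoder checkpoint with a compatible tokenizer and feature extractor (an assumption based on the loading code, not something the docs state):

```
# Pin inference to the CPU (the docstring names 'cuda' and 'cpu' as examples)
cpu_captioner = ImageCaptioner(device='cpu')

# device=None (the default) presumably lets TorchBase choose a device;
# the model name below is just the documented default, not a recommendation
auto_captioner = ImageCaptioner(model_name='ydshieh/vit-gpt2-coco-en', device=None)
```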
Ancestors

- ktrain.torch_base.TorchBase
Methods
def caption(self, images)
Performs image captioning. This method supports a single image or a list of images. If the input is a single image, the return type is a string; if the input is a list, a list of strings is returned.

Args:
  images: image|list
Returns:
  a caption string, or a list of caption strings
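Because the method only opens string inputs, already-loaded PIL images pass through unchanged and can be mixed with file paths in one list. A small sketch, reusing the captioner from the earlier example (file names are hypothetical):

```
from PIL import Image

img = Image.open('photo.jpg')                 # hypothetical local file
print(captioner.caption(img))                 # a PIL image in -> a caption string out
print(captioner.caption(['other.jpg', img])) # paths and PIL images mixed -> list of strings
```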