Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use aws textract for ocr #935

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,6 +82,7 @@ cohere = [
"cohere>=4.11.2"
]
minimal = [
"boto3==1.33.3",
"openai==1.45.0",
"langchain==0.2.16",
"langchain-anthropic==0.1.23",
Expand All @@ -108,6 +109,7 @@ minimal = [
]
all = [
"black",
"boto3==1.33.3",
"bumpver",
"pip-tools",
"pytest",
Expand Down
100 changes: 76 additions & 24 deletions src/autolabel/transforms/image.py
Original file line number | Diff line number | Diff line change
@@ -1,53 +1,91 @@
from typing import Dict, Any, List
"""Extract text from images using OCR."""

from __future__ import annotations

from typing import Any, ClassVar

from autolabel.transforms.schema import TransformType
from autolabel.transforms import BaseTransform
from autolabel.cache import BaseCache
from autolabel.transforms import BaseTransform
from autolabel.transforms.schema import TransformType


class ImageTransform(BaseTransform):
"""This class is used to extract text from images using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'
"""Extract text from images using OCR.

This transform supports the following image formats: PNG, JPEG, TIFF, JPEG 2000, GIF, WebP, BMP, and PNM
The output columns dictionary for this class should include the keys 'content_column'
and 'metadata_column'.

This transform supports the following image formats: PNG, JPEG, TIFF, JPEG 2000, GIF,
WebP, BMP, and PNM.
"""

COLUMN_NAMES = [
COLUMN_NAMES: ClassVar[list[str]] = [
"content_column",
"metadata_column",
]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
output_columns: dict[str, Any],
file_path_column: str,
lang: str = None,
lang: str | None = None,
) -> None:
"""Initialize the ImageTransform.

Args:
cache: Cache instance to use
output_columns: Dictionary mapping output column names
file_path_column: Column containing image file paths
lang: Optional language for OCR

"""
super().__init__(cache, output_columns)
self.file_path_column = file_path_column
self.lang = lang

try:
from PIL import Image
import pytesseract
from PIL import Image

self.Image = Image
self.pytesseract = pytesseract
self.pytesseract.get_tesseract_version()
except ImportError:
raise ImportError(
"pillow and pytesseract are required to use the image transform with ocr. Please install pillow and pytesseract with the following command: pip install pillow pytesseract"
msg = (
"pillow and pytesseract required to use the image transform with ocr"
"Please install pillow and pytesseract with the following command: "
"pip install pillow pytesseract"
)
except EnvironmentError:
raise EnvironmentError(
"The tesseract engine is required to use the image transform with ocr. Please see https://tesseract-ocr.github.io/tessdoc/Installation.html for installation instructions."
raise ImportError(msg) from None
except OSError:
msg = (
"The tesseract engine is required to use the image transform with ocr. "
"Please see https://tesseract-ocr.github.io/tessdoc/Installation.html "
"for installation instructions."
)
raise OSError(msg) from None

@staticmethod
def name() -> str:
"""Get transform name.

Returns:
Transform type name

"""
return TransformType.IMAGE

def get_image_metadata(self, file_path: str):
def get_image_metadata(self, file_path: str) -> dict[str, Any]:
"""Get metadata from image file.

Args:
file_path: Path to image file

Returns:
Dictionary of image metadata

"""
try:
image = self.Image.open(file_path)
metadata = {
Expand All @@ -59,20 +97,22 @@ def get_image_metadata(self, file_path: str):
"exif": image._getexif(), # Exif metadata
}
return metadata
except Exception as e:
return {"error": str(e)}
except Exception as exc:
return {"error": str(exc)}

async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""This function transforms an image into text using OCR.
async def _apply(self, row: dict[str, Any]) -> dict[str, Any]:
"""Transform an image into text using OCR.

Args:
row (Dict[str, Any]): The row of data to be transformed.
row: The row of data to be transformed

Returns:
Dict[str, Any]: The dict of output columns.
Dictionary of output columns

"""
content = self.pytesseract.image_to_string(
row[self.file_path_column], lang=self.lang
row[self.file_path_column],
lang=self.lang,
)
metadata = self.get_image_metadata(row[self.file_path_column])
transformed_row = {
Expand All @@ -81,12 +121,24 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
}
return transformed_row

def params(self) -> Dict[str, Any]:
def params(self) -> dict[str, Any]:
"""Get transform parameters.

Returns:
Dictionary of parameters

"""
return {
"output_columns": self.output_columns,
"file_path_column": self.file_path_column,
"lang": self.lang,
}

def input_columns(self) -> List[str]:
def input_columns(self) -> list[str]:
"""Get required input columns.

Returns:
List of input column names

"""
return [self.file_path_column]
Loading
Loading