Skip to content

Commit

Permalink
Merge branch 'dev' into hallvictoria/df-unit-testing
Browse files Browse the repository at this point in the history
  • Loading branch information
hallvictoria authored Oct 15, 2024
2 parents 1426076 + 7683200 commit 8fbbaf1
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/label.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: "Pull Request Labeler"
on:
pull_request:
pull_request_target:
paths:
- '**/__init__.py'
jobs:
Expand Down
6 changes: 5 additions & 1 deletion azure/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._queue import QueueMessage
from ._servicebus import ServiceBusMessage
from ._sql import SqlRow, SqlRowList
from ._mysql import MySqlRow, MySqlRowList

# Import binding implementations to register them
from . import blob # NoQA
Expand All @@ -37,6 +38,7 @@
from . import durable_functions # NoQA
from . import sql # NoQA
from . import warmup # NoQA
from . import mysql # NoQA


__all__ = (
Expand Down Expand Up @@ -67,6 +69,8 @@
'SqlRowList',
'TimerRequest',
'WarmUpContext',
'MySqlRow',
'MySqlRowList',

# Middlewares
'WsgiMiddleware',
Expand Down Expand Up @@ -98,4 +102,4 @@
'BlobSource'
)

__version__ = '1.21.0b3'
__version__ = '1.22.0b3'
71 changes: 71 additions & 0 deletions azure/functions/_mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import abc
import collections
import json


class BaseMySqlRow(abc.ABC):

@classmethod
@abc.abstractmethod
def from_json(cls, json_data: str) -> 'BaseMySqlRow':
raise NotImplementedError

@classmethod
@abc.abstractmethod
def from_dict(cls, dct: dict) -> 'BaseMySqlRow':
raise NotImplementedError

@abc.abstractmethod
def __getitem__(self, key):
raise NotImplementedError

@abc.abstractmethod
def __setitem__(self, key, value):
raise NotImplementedError

@abc.abstractmethod
def to_json(self) -> str:
raise NotImplementedError


class BaseMySqlRowList(abc.ABC):
pass


class MySqlRow(BaseMySqlRow, collections.UserDict):
"""A MySql Row.
MySqlRow objects are ''UserDict'' subclasses and behave like dicts.
"""

@classmethod
def from_json(cls, json_data: str) -> 'BaseMySqlRow':
"""Create a MySqlRow from a JSON string."""
return cls.from_dict(json.loads(json_data))

@classmethod
def from_dict(cls, dct: dict) -> 'BaseMySqlRow':
"""Create a MySqlRow from a dict object"""
return cls({k: v for k, v in dct.items()})

def to_json(self) -> str:
"""Return the JSON representation of the MySqlRow"""
return json.dumps(dict(self))

def __getitem__(self, key):
return collections.UserDict.__getitem__(self, key)

def __setitem__(self, key, value):
return collections.UserDict.__setitem__(self, key, value)

def __repr__(self) -> str:
return (
f'<MySqlRow at 0x{id(self):0x}>'
)


class MySqlRowList(BaseMySqlRowList, collections.UserList):
"A ''UserList'' subclass containing a list of :class:'~MySqlRow' objects"
pass
5 changes: 4 additions & 1 deletion azure/functions/decorators/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def __init__(self,
**kwargs):
self.path = path
self.connection = connection
self.source = source
if isinstance(source, BlobSource):
self.source = source.value
else:
self.source = source # type: ignore
super().__init__(name=name, data_type=data_type)

@staticmethod
Expand Down
26 changes: 22 additions & 4 deletions azure/functions/decorators/function_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __str__(self):
return self.get_function_json()

def __call__(self, *args, **kwargs):
"""This would allow the Function object to be directly callable and runnable
directly using the interpreter locally.
"""This would allow the Function object to be directly callable
and runnable directly using the interpreter locally.
Example:
@app.route(route="http_trigger")
Expand Down Expand Up @@ -342,8 +342,8 @@ def decorator():
return wrap

def _get_durable_blueprint(self):
"""Attempt to import the Durable Functions SDK from which DF decorators are
implemented.
"""Attempt to import the Durable Functions SDK from which DF
decorators are implemented.
"""

try:
Expand Down Expand Up @@ -3276,6 +3276,8 @@ def assistant_query_input(self,
arg_name: str,
id: str,
timestamp_utc: str,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState", # noqa: E501
data_type: Optional[
Union[DataType, str]] = None,
**kwargs) \
Expand All @@ -3288,6 +3290,11 @@ def assistant_query_input(self,
:param timestamp_utc: the timestamp of the earliest message in the chat
history to fetch. The timestamp should be in ISO 8601 format - for
example, 2023-08-01T00:00:00Z.
:param chat_storage_connection_setting: The configuration section name
for the table settings for assistant chat storage. The default value is
"AzureWebJobsStorage".
:param collection_name: The table collection name for assistant chat
storage. The default value is "ChatState".
:param id: The ID of the Assistant to query.
:param data_type: Defines how Functions runtime should treat the
parameter value
Expand All @@ -3305,6 +3312,8 @@ def decorator():
name=arg_name,
id=id,
timestamp_utc=timestamp_utc,
chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501
collection_name=collection_name,
data_type=parse_singular_param_to_enum(data_type,
DataType),
**kwargs))
Expand All @@ -3318,6 +3327,8 @@ def assistant_post_input(self, arg_name: str,
id: str,
user_message: str,
model: Optional[str] = None,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState", # noqa: E501
data_type: Optional[
Union[DataType, str]] = None,
**kwargs) \
Expand All @@ -3331,6 +3342,11 @@ def assistant_post_input(self, arg_name: str,
:param user_message: The user message that user has entered for
assistant to respond to.
:param model: The OpenAI chat model to use.
:param chat_storage_connection_setting: The configuration section name
for the table settings for assistant chat storage. The default value is
"AzureWebJobsStorage".
:param collection_name: The table collection name for assistant chat
storage. The default value is "ChatState".
:param data_type: Defines how Functions runtime should treat the
parameter value
:param kwargs: Keyword arguments for specifying additional binding
Expand All @@ -3348,6 +3364,8 @@ def decorator():
id=id,
user_message=user_message,
model=model,
chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501
collection_name=collection_name,
data_type=parse_singular_param_to_enum(data_type,
DataType),
**kwargs))
Expand Down
8 changes: 8 additions & 0 deletions azure/functions/decorators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ def __init__(self,
name: str,
id: str,
timestamp_utc: str,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState",
data_type: Optional[DataType] = None,
**kwargs):
self.id = id
self.timestamp_utc = timestamp_utc
self.chat_storage_connection_setting = chat_storage_connection_setting
self.collection_name = collection_name
super().__init__(name=name, data_type=data_type)


Expand Down Expand Up @@ -165,12 +169,16 @@ def __init__(self, name: str,
id: str,
user_message: str,
model: Optional[str] = None,
chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501
collection_name: Optional[str] = "ChatState",
data_type: Optional[DataType] = None,
**kwargs):
self.name = name
self.id = id
self.user_message = user_message
self.model = model
self.chat_storage_connection_setting = chat_storage_connection_setting
self.collection_name = collection_name
super().__init__(name=name, data_type=data_type)


Expand Down
78 changes: 78 additions & 0 deletions azure/functions/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import collections.abc
import json
import typing

from azure.functions import _mysql as mysql

from . import meta


class MySqlConverter(meta.InConverter, meta.OutConverter,
binding='mysql'):

@classmethod
def check_input_type_annotation(cls, pytype: type) -> bool:
return issubclass(pytype, mysql.BaseMySqlRowList)

@classmethod
def check_output_type_annotation(cls, pytype: type) -> bool:
return issubclass(pytype, (mysql.BaseMySqlRowList, mysql.BaseMySqlRow))

@classmethod
def decode(cls,
data: meta.Datum,
*,
trigger_metadata) -> typing.Optional[mysql.MySqlRowList]:
if data is None or data.type is None:
return None

data_type = data.type

if data_type in ['string', 'json']:
body = data.value

elif data_type == 'bytes':
body = data.value.decode('utf-8')

else:
raise NotImplementedError(
f'Unsupported payload type: {data_type}')

rows = json.loads(body)
if not isinstance(rows, list):
rows = [rows]

return mysql.MySqlRowList(
(None if row is None else mysql.MySqlRow.from_dict(row))
for row in rows)

@classmethod
def encode(cls, obj: typing.Any, *,
expected_type: typing.Optional[type]) -> meta.Datum:
if isinstance(obj, mysql.MySqlRow):
data = mysql.MySqlRowList([obj])

elif isinstance(obj, mysql.MySqlRowList):
data = obj

elif isinstance(obj, collections.abc.Iterable):
data = mysql.MySqlRowList()

for row in obj:
if not isinstance(row, mysql.MySqlRow):
raise NotImplementedError(
f'Unsupported list type: {type(obj)}, \
lists must contain MySqlRow objects')
else:
data.append(row)

else:
raise NotImplementedError(f'Unsupported type: {type(obj)}')

return meta.Datum(
type='json',
value=json.dumps([dict(d) for d in data])
)
8 changes: 4 additions & 4 deletions tests/decorators/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def test_blob_trigger_creation_with_default_specified_source(self):
"name": "req",
"dataType": DataType.UNDEFINED,
"path": "dummy_path",
'source': BlobSource.LOGS_AND_CONTAINER_SCAN,
'source': 'LogsAndContainerScan',
"connection": "dummy_connection"
})

def test_blob_trigger_creation_with_source_as_string(self):
trigger = BlobTrigger(name="req",
path="dummy_path",
connection="dummy_connection",
source=BlobSource.EVENT_GRID,
source="EventGrid",
data_type=DataType.UNDEFINED,
dummy_field="dummy")

Expand All @@ -62,7 +62,7 @@ def test_blob_trigger_creation_with_source_as_string(self):
"name": "req",
"dataType": DataType.UNDEFINED,
"path": "dummy_path",
'source': BlobSource.EVENT_GRID,
'source': 'EventGrid',
"connection": "dummy_connection"
})

Expand All @@ -82,7 +82,7 @@ def test_blob_trigger_creation_with_source_as_enum(self):
"name": "req",
"dataType": DataType.UNDEFINED,
"path": "dummy_path",
'source': BlobSource.EVENT_GRID,
'source': 'EventGrid',
"connection": "dummy_connection"
})

Expand Down
2 changes: 1 addition & 1 deletion tests/decorators/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,7 @@ def test_blob_input_binding():
"type": BLOB_TRIGGER,
"name": "req",
"path": "dummy_path",
"source": BlobSource.EVENT_GRID,
"source": 'EventGrid',
"connection": "dummy_conn"
})

Expand Down
8 changes: 8 additions & 0 deletions tests/decorators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_text_completion_input_valid_creation(self):
def test_assistant_query_input_valid_creation(self):
input = AssistantQueryInput(name="test",
timestamp_utc="timestamp_utc",
chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501
collection_name="ChatState",
data_type=DataType.UNDEFINED,
id="test_id",
type="assistantQueryInput",
Expand All @@ -66,6 +68,8 @@ def test_assistant_query_input_valid_creation(self):
self.assertEqual(input.get_dict_repr(),
{"name": "test",
"timestampUtc": "timestamp_utc",
"chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501
"collectionName": "ChatState",
"dataType": DataType.UNDEFINED,
"direction": BindingDirection.IN,
"type": "assistantQuery",
Expand Down Expand Up @@ -111,6 +115,8 @@ def test_assistant_post_input_valid_creation(self):
input = AssistantPostInput(name="test",
id="test_id",
model="test_model",
chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501
collection_name="ChatState",
user_message="test_message",
data_type=DataType.UNDEFINED,
dummy_field="dummy")
Expand All @@ -120,6 +126,8 @@ def test_assistant_post_input_valid_creation(self):
{"name": "test",
"id": "test_id",
"model": "test_model",
"chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501
"collectionName": "ChatState",
"userMessage": "test_message",
"dataType": DataType.UNDEFINED,
"direction": BindingDirection.IN,
Expand Down
Loading

0 comments on commit 8fbbaf1

Please sign in to comment.