Skip to content

Commit

Permalink
update to make token generation continuous
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenson committed Dec 28, 2023
1 parent f9ebf22 commit b36be6d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 16 deletions.
38 changes: 26 additions & 12 deletions midi-save/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from datetime import datetime
import asyncio
import os
import json

loop = asyncio.get_event_loop()
from_path = '/tmp/midi-sequence.mid'
# from_path = '/tmp/midi-sequence.mid'

def get_file_size(file_path):
return os.path.getsize(file_path)
Expand All @@ -17,10 +18,10 @@ def open_file(file_path):
content = file.read()
return content

def upload_to_s3():
def upload_to_s3(from_path, to_path):
bucket_name = 'midi-archive'
current_date = datetime.now().date()
to_path = f'neural-net/{current_date}_sequence.mid'
# current_date = datetime.now().date()
# to_path = f'neural-net/{current_date}_sequence.mid'

# Upload the file
s3_client = boto3.client('s3')
Expand All @@ -37,13 +38,13 @@ def load_tokenizer():
TOKENIZER_PARAMS = {
"pitch_range": (21, 109),
"beat_res": {(0, 4): 8, (4, 12): 4},
"num_velocities": 2,
"special_tokens": ["PAD", "BOS", "EOS", "MASK"],
"num_velocities": 4,
# "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
"use_chords": True,
"use_rests": True,
"use_tempos": True,
"use_time_signatures": False,
"use_programs": False,
# "use_programs": False,
"num_tempos": 4, # nb of tempo bins
"tempo_range": (40, 250), # (min, max),
"one_token_stream": True,
Expand All @@ -54,17 +55,24 @@ def load_tokenizer():

tokenizer = Structured(config)
return tokenizer

def save_tokens_to_json(tokens, filepath):
    """Serialize the token sequence to a JSON file at *filepath*."""
    with open(filepath, "w") as outfile:
        json.dump(tokens, outfile)

def save_tokens_to_midi(tokens):
def save_tokens_to_midi(tokens, filepath):
    """Decode a token sequence back into MIDI and write the file to *filepath*."""
    midi = load_tokenizer()(tokens)
    midi.dump(filepath)

def lambda_handler(event, context):
midi_filepath = '/tmp/midi-sequence.mid'
json_filepath = '/tmp/midi-sequence.json'
tokens = event["body"]
save_tokens_to_midi(tokens)
save_tokens_to_midi(tokens, midi_filepath)
save_tokens_to_json(tokens, json_filepath)

size = get_file_size(from_path)
size = get_file_size(midi_filepath)
print(f'The size of the file is {size} bytes')
if size < 50:
# didn't generate a viable MIDI file
Expand All @@ -75,7 +83,13 @@ def lambda_handler(event, context):
'body': 'could not generate a viable MIDI file'
}

upload_to_s3()
current_date = datetime.now().date()
# this version will get over-written daily
upload_to_s3(midi_filepath, 'neural-net/model-prediction.mid')
# this one will be archived
upload_to_s3(midi_filepath, f'neural-net/{current_date}_sequence.mid')
# this sequence will be used to prompt subsequent token generation
upload_to_s3(json_filepath, 'neural-net/token_sequence.json')

return {
'statusCode': 200,
Expand Down
42 changes: 38 additions & 4 deletions neural-net/lambda_function.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
import onnxruntime
import numpy as np
import boto3
from botocore.exceptions import ClientError
import logging
import json

vocab_size = 352 # hard-coding for now, but should update to import preconfigured tokenizer
vocab_size = 354 # hard-coding for now, but won't change often

def get_previous_tokens():
    """Fetch the most recently archived token sequence from S3 and return it parsed from JSON."""
    local_path = '/tmp/previous_token_sequence.json'
    # pull down the sequence saved by the previous generation run
    download_from_s3('neural-net/token_sequence.json', local_path)
    with open(local_path, 'r') as handle:
        return json.load(handle)

def download_from_s3(file_key, save_path):
    """Download *file_key* from the midi-archive S3 bucket to *save_path*.

    Returns True on success, False if the download failed. The ClientError
    is logged rather than raised so the caller can decide how to recover.
    """
    bucket_name = 'midi-archive'
    s3_client = boto3.client('s3')
    try:
        s3_client.download_file(bucket_name, file_key, save_path)
    except ClientError as e:
        logging.error(e)
        return False
    return True

def softmax(x):
    """Numerically stable softmax over the 1-D array of logits *x*.

    Subtracting the max before exponentiating avoids overflow (which the
    original form hits for large logits, yielding nan); the result is
    mathematically identical to exp(x) / exp(x).sum().
    """
    shifted = np.exp(x - np.max(x))
    return shifted / shifted.sum()

def generate_tokens(session):
input_shape = session.get_inputs()[0].shape
batch_size, block_size = input_shape
num_tokens = 64 * block_size
num_tokens = 8192 # translates roughly to 3-4min
# randint should be replaced with some actual MIDI data so it's not gibberish to start
context = np.random.randint(vocab_size, size=block_size)
prev_tokens = get_previous_tokens()
print(f'prev_tokens {prev_tokens}')

context = prev_tokens[-block_size:]
# context2 = np.random.randint(vocab_size, size=block_size)
# print(f'context1 {context1}')
# print(f'context2 {context2}')
# return
# context = seed_sequence
outputs = []
for _ in range(num_tokens):
for i in range(num_tokens):
logits = session.run(None, {'input': [context]})[0]
last_logit = logits[0, -1, :] # grab last timestep
# 0.9999 factor is a hacky solution to probability adding up to > 1 due to float math
Expand All @@ -24,6 +57,7 @@ def generate_tokens(session):
outputs.append(token)
context = np.append(context, token)
context = context[-block_size:]
print(f'token i: {i}')
return outputs

def lambda_handler(event, context):
Expand Down

0 comments on commit b36be6d

Please sign in to comment.