Skip to content

Commit

Permalink
Merge pull request #3 from nside/get-improvements
Browse files Browse the repository at this point in the history
Get improvements
  • Loading branch information
nside authored Jul 27, 2023
2 parents 1eec4b8 + e1c7975 commit 15a3be0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Byte-compiled / optimized / DLL files
*.swp
__pycache__/
*.py[cod]
*$py.class
Expand Down
14 changes: 9 additions & 5 deletions sqlite2rest/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@ def get_primary_key(self, table_name):
return column[1] # The 2nd item in the tuple is the column name
return None

def get_records(self, table_name):
self.cursor.execute(f"SELECT * FROM {table_name};")
def get_records(self, table_name, page, per_page):
offset = (page - 1) * per_page
self.cursor.execute(f"SELECT * FROM {table_name} LIMIT ? OFFSET ?;", (per_page, offset))
col_names = [description[0] for description in self.cursor.description]
records = [dict(zip(col_names, record)) for record in self.cursor.fetchall()]
return records

def get_record(self, table_name, key):
primary_key = self.get_primary_key(table_name)
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {primary_key} = ?;", (key,))
col_names = [description[0] for description in self.cursor.description]
record = dict(zip(col_names, self.cursor.fetchone()))
return record
row = self.cursor.fetchone()
if row is None:
return None
col_names = [column[0] for column in self.cursor.description]
return dict(zip(col_names, row))


def create_record(self, table_name, data):
columns = ', '.join(data.keys())
Expand Down
17 changes: 15 additions & 2 deletions sqlite2rest/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@
def setup_routes(app, tables, get_database):
def create_get_records_fn(table_name):
def get_records():
app.logger.info(f'Getting records for table {table_name}')
records = get_database().get_records(table_name)
page = request.args.get('page', default=1, type=int)
per_page = request.args.get('per_page', default=10, type=int)
app.logger.info(f'Getting records for table {table_name}, page {page}, {per_page} per page')
records = get_database().get_records(table_name, page, per_page)
return jsonify(records), 200, {'Content-Type': 'application/json'}
get_records.__name__ = f'get_records_{table_name}'
return get_records

def create_get_record_fn(table_name):
def get_record(id):
app.logger.info(f'Getting record with id {id} from table {table_name}')
record = get_database().get_record(table_name, id)
if record is None:
return jsonify({'message': 'Record not found.'}), 404, {'Content-Type': 'application/json'}
return jsonify(record), 200, {'Content-Type': 'application/json'}
get_record.__name__ = f'get_record_{table_name}'
return get_record

def create_create_record_fn(table_name):
def create_record():
data = request.get_json()
Expand Down Expand Up @@ -37,6 +49,7 @@ def delete_record(id):
return delete_record

for table_name in tables:
app.add_url_rule(f'/{table_name}/<id>', 'get_record_'+table_name, create_get_record_fn(table_name), methods=['GET'])
app.add_url_rule(f'/{table_name}', 'get_records_'+table_name, create_get_records_fn(table_name), methods=['GET'])
app.add_url_rule(f'/{table_name}', 'create_record_'+table_name, create_create_record_fn(table_name), methods=['POST'])
app.add_url_rule(f'/{table_name}/<id>', 'update_record_'+table_name, create_update_record_fn(table_name), methods=['PUT'])
Expand Down
16 changes: 16 additions & 0 deletions tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,19 @@ def test_delete(self):
response = self.client.delete('/Artist/3')
self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.data), {'message': 'Record deleted.'})

def test_get_single_record(self):
# First, create a record to get
self.client.post('/Artist', json={'ArtistId': 4, 'Name': 'Test Artist'})

# Then, get the record
response = self.client.get('/Artist/4')
self.assertEqual(response.status_code, 200)
artist = json.loads(response.data)
self.assertEqual(artist, {'ArtistId': 4, 'Name': 'Test Artist'})

def test_get_single_record_not_found(self):
# Try to get a record that does not exist
response = self.client.get('/Artist/999')
self.assertEqual(response.status_code, 404)
self.assertEqual(json.loads(response.data), {'message': 'Record not found.'})

0 comments on commit 15a3be0

Please sign in to comment.