Skip to content

Commit

Permalink
refactored runtime/runtime_test to use the driver
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Oct 22, 2023
1 parent eae9ed0 commit 81df278
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 66 deletions.
1 change: 1 addition & 0 deletions include/shady/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void destroy_driver_config(DriverConfig*);

void parse_driver_arguments(DriverConfig* args, int* pargc, char** argv);

ShadyErrorCodes driver_load_source_files(DriverConfig* args, Module* mod);
ShadyErrorCodes driver_compile(DriverConfig* args, Module* mod);

#endif
2 changes: 2 additions & 0 deletions include/shady/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Device* get_an_device(Runtime*);
const char* get_device_name(Device*);

typedef struct CompilerConfig_ CompilerConfig;
typedef struct Module_ Module;

Program* new_program_from_module(Runtime*, const CompilerConfig*, Module*);
Program* load_program(Runtime*, const CompilerConfig*, const char* program_src);
Program* load_program_from_disk(Runtime*, const CompilerConfig*, const char* path);

Expand Down
16 changes: 16 additions & 0 deletions src/driver/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,22 @@ ShadyErrorCodes parse_file_from_filename(const char* filename, Module* mod) {
return err;
}

ShadyErrorCodes driver_load_source_files(DriverConfig* args, Module* mod) {
if (entries_count_list(args->input_filenames) == 0) {
error_print("Missing input file. See --help for proper usage");
return MissingInputArg;
}

size_t num_source_files = entries_count_list(args->input_filenames);
for (size_t i = 0; i < num_source_files; i++) {
int err = parse_file_from_filename(read_list(const char*, args->input_filenames)[i], mod);
if (err)
return err;
}

return NoError;
}

ShadyErrorCodes driver_compile(DriverConfig* args, Module* mod) {
info_print("Parsed program successfully: \n");
log_module(INFO, &args->config, mod);
Expand Down
12 changes: 1 addition & 11 deletions src/driver/slim.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,7 @@ int main(int argc, char** argv) {
IrArena* arena = new_ir_arena(default_arena_config());
Module* mod = new_module(arena, "my_module"); // TODO name module after first filename, or perhaps the last one

if (entries_count_list(args.input_filenames) == 0) {
error_print("Missing input file. See --help for proper usage");
exit(MissingInputArg);
}

size_t num_source_files = entries_count_list(args.input_filenames);
for (size_t i = 0; i < num_source_files; i++) {
int err = parse_file_from_filename(read_list(const char*, args.input_filenames)[i], mod);
if (err)
return err;
}
driver_load_source_files(&args, mod);

driver_compile(&args, mod);
info_print("Done\n");
Expand Down
1 change: 1 addition & 0 deletions src/runtime/runtime_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct Device_ {
struct Program_ {
Runtime* runtime;
const CompilerConfig* base_config;
/// owns the module, may be NULL if module is owned by someone else
IrArena* arena;
Module* module;
};
Expand Down
49 changes: 27 additions & 22 deletions src/runtime/runtime_program.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,49 @@
#include <assert.h>
#include <string.h>

static Program* load_program_internal(Runtime* runtime, const CompilerConfig* base_config, const char* program_src, const char* program_path) {
Program* new_program_from_module(Runtime* runtime, const CompilerConfig* base_config, Module* mod) {
Program* program = calloc(1, sizeof(Program));
program->runtime = runtime;
program->base_config = base_config;

CompilerConfig config = default_compiler_config();
ArenaConfig arena_config = default_arena_config();
program->arena = new_ir_arena(arena_config);
CHECK(program->arena != NULL, return false);

program->module = new_module(program->arena, "my_module");

if (!program_src) {
assert(program_path);
bool ok = read_file(program_path, NULL, &program_src);
assert(ok);
} else {
assert(!program_path);
}
CHECK(parse_file(SrcShadyIR, strlen(program_src), program_src, program->module) == CompilationNoError, return false);
if (program_path)
free(program_src);
program->arena = NULL;
program->module = mod;

// TODO split the compilation pipeline into generic and non-generic parts
append_list(Program*, runtime->programs, program);
return program;
}

Program* load_program(Runtime* runtime, const CompilerConfig* base_config, const char* program_src) {
return load_program_internal(runtime, base_config, program_src, NULL);
IrArena* arena = new_ir_arena(default_arena_config());
Module* module = new_module(arena, "my_module");

int err = parse_file(SrcShadyIR, strlen(program_src), program_src, module);
if (err != NoError) {
return NULL;
}

Program* program = new_program_from_module(runtime, base_config, module);
program->arena = arena;
return program;
}

Program* load_program_from_disk(Runtime* runtime, const CompilerConfig* base_config, const char* path) {
return load_program_internal(runtime, base_config, NULL, path);
IrArena* arena = new_ir_arena(default_arena_config());
Module* module = new_module(arena, "my_module");

int err = parse_file_from_filename(path, module);
if (err != NoError) {
return NULL;
}

Program* program = new_program_from_module(runtime, base_config, module);
program->arena = arena;
return program;
}

void unload_program(Program* program) {
// TODO iterate over the specialized stuff
destroy_ir_arena(program->arena);
if (program->arena) // if the program owns an arena
destroy_ir_arena(program->arena);
free(program);
}
49 changes: 16 additions & 33 deletions src/runtime/runtime_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "log.h"
#include "portability.h"
#include "list.h"
#include "util.h"

#include <stdlib.h>
Expand All @@ -14,14 +13,14 @@

static const char* default_shader =
"@EntryPoint(\"compute\") @WorkgroupSize(SUBGROUP_SIZE, 1, 1) fn main(uniform i32 a, uniform ptr global i32 b) {\n"
" debug_printf(\"hi %d 0x%lx\\n\", a, reinterpret[u64](b));"
" val rb = reinterpret[u64](b);\n"
" debug_printf(\"hi %d 0x%lx\\n\", a, rb);\n"
" return ();\n"
"}";

typedef struct {
CompilerConfig compiler_config;
DriverConfig driver_config;
RuntimeConfig runtime_config;
struct List* input_filenames;
size_t device;
} Args;

Expand Down Expand Up @@ -59,60 +58,44 @@ static void parse_runtime_arguments(int* pargc, char** argv, Args* args) {
int main(int argc, char* argv[]) {
set_log_level(INFO);
Args args = {
.input_filenames = new_list(const char*),
.compiler_config = default_compiler_config(),
.driver_config = default_driver_config(),
};
args.runtime_config = (RuntimeConfig) {
.use_validation = true,
.dump_spv = true,
};
parse_runtime_arguments(&argc, argv, &args);
parse_common_args(&argc, argv);
parse_compiler_config_args(&args.compiler_config, &argc, argv);
parse_input_files(args.input_filenames, &argc, argv);
parse_compiler_config_args(&args.driver_config.config, &argc, argv);
parse_input_files(args.driver_config.input_filenames, &argc, argv);

info_print("Shady runtime test starting...\n");

Runtime* runtime = initialize_runtime(args.runtime_config);
Device* device = get_device(runtime, args.device);
assert(device);
const char* shader = NULL;

// Read the files
size_t num_source_files = entries_count_list(args.input_filenames);
LARRAY(const char*, read_files, num_source_files);
for (size_t i = 0; i < num_source_files; i++) {
char* input_file_contents;

bool ok = read_file(read_list(const char*, args.input_filenames)[i], NULL, &input_file_contents);
assert(ok);
if (input_file_contents == NULL) {
error_print("file does not exist\n");
exit(InputFileDoesNotExist);
}
read_files[i] = (char*)input_file_contents;
}
destroy_list(args.input_filenames);

// TODO handle multiple input files properly !
assert(num_source_files < 2);
if (num_source_files == 1)
shader = read_files[0];
if (!shader)
shader = default_shader;
IrArena* arena = new_ir_arena(default_arena_config());
Module* module = new_module(arena, "my_module");

int err = driver_load_source_files(&args.driver_config, module);
if (err)
return err;
Program* program = new_program_from_module(runtime, &args.driver_config.config, module);

int32_t stuff[] = { 42, 42, 42, 42 };
Buffer* buffer = allocate_buffer_device(device, sizeof(stuff));
copy_to_buffer(buffer, 0, stuff, sizeof(stuff));
copy_from_buffer(buffer, 0, stuff, sizeof(stuff));

Program* program = load_program(runtime, &args.compiler_config, shader);

int32_t a0 = 42;
uint64_t a1 = get_buffer_device_pointer(buffer);
wait_completion(launch_kernel(program, device, "main", 1, 1, 1, 2, (void*[]) { &a0, &a1 }));

destroy_buffer(buffer);

shutdown_runtime(runtime);
destroy_ir_arena(arena);
destroy_driver_config(&args.driver_config);
return 0;
}

0 comments on commit 81df278

Please sign in to comment.