diff options
| -rw-r--r-- | 01-knapsack/ks_dp.c | 282 | 
1 files changed, 154 insertions, 128 deletions
| diff --git a/01-knapsack/ks_dp.c b/01-knapsack/ks_dp.c index 25e92cb..1ba27b9 100644 --- a/01-knapsack/ks_dp.c +++ b/01-knapsack/ks_dp.c @@ -1,29 +1,142 @@  #include <stdio.h>  #include <stdlib.h> -#include <limits.h> +#include <string.h> +#include <stdint.h> -int main(int argc, char** argv, char** env) +static int debug = 0; + +typedef uint32_t word_t; +enum { WORD_BITS = sizeof(word_t) * 8 }; + +static inline int bindex(int b) { return b / WORD_BITS; } +static inline int boffset(int b) { return b % WORD_BITS; } + +static inline void set_bit(word_t *bits, int b)  { -   int k;         // capacity -   int n;         // items # -   int *values;   // items values -   int *weights;  // items weights -   int *solution; // solution -   int *matrix;   // solving matrix +   bits[bindex(b)] |= (1 << (boffset(b))); +} + +static inline int get_bit(word_t *bits, int b) +{ +   return ((bits[bindex(b)] & (1 << (boffset(b)))) >> (boffset(b))); +} + +typedef struct _SolverData { +     int v; +     word_t *bits; +} SolverData; + +typedef struct _Solver +{ +   int k;               // capacity +   int n;               // items # +   int *values;         // item values +   int *weights;        // item weights +   int nw;              // # bytes per bits array +   size_t sd_sz;        // solver data size +   SolverData *data;    // solver data +} Solver; + +static void solve(Solver* solver) +{ +   int n, k, nw; +   size_t sd_sz; +   int *values, *weights;     int i, j; +   int v, w; +   char *data_base; +   SolverData *c, *p; + +   sd_sz = solver->sd_sz; +   nw = solver->nw; +   n = solver->n; +   k = solver->k; +   values = solver->values; +   weights = solver->weights; +   data_base = (char *) solver->data; + +   /* SOLVE */ +   for (i = 0; i < n; i++) +     { +        v = values[i]; +        w = weights[i]; +        c = (SolverData *) (data_base + (sd_sz * k)); +        p = (SolverData *) (data_base + (sd_sz * (k - w))); + +        for (j = k; j > 0; j--) +          { +             if (j < w) +               break; +             if ((j >= w) && (c->v < (v + p->v))) +               { +                  c->v = v + p->v; +                  memcpy(c->bits, p->bits, (nw * sizeof(word_t))); +                  set_bit(c->bits, i); +               } +             c = (SolverData *) (((char *) c) - sd_sz); +             p = (SolverData *) (((char *) p) - sd_sz); +          } + +        if (debug) +          { +             printf("i=% 4d : ", i); +             for (j = 0; j <= k; j++) +               { +                  c = (SolverData *) (data_base + (sd_sz * j)); +                  printf("% 4d ", c->v); +               } +             printf("\n"); +          } +     } +} + +static void print(Solver* solver) +{ +   int b; +   int i; +   int v, w; +   SolverData *sol; + +   v = 0; +   w = 0; +   sol = (SolverData *) (((char *) solver->data) + (solver->sd_sz * solver->k)); + +   printf("%d %d\n", sol->v, 1); +   for (i = 0; i < solver->n; i++) +     { +        b = get_bit(sol->bits, i); +        printf("%d ", b); +        if (b) +          { +             v += solver->values[i]; +             w += solver->weights[i]; +          } +     } +   printf("\n"); + +   if (v != sol->v) +     fprintf(stderr, "ERROR: value %d != %d\n", v, sol->v); +   if (w > solver->k) +     fprintf(stderr, "ERROR: weight %d > %d\n", w, solver->k); +} + +int main(int argc, char** argv, char** env) +{ +   FILE *fp; +   Solver solver;    // solver + +   int i;     int *vp, *wp; -   int v, w, tmp; +   SolverData *data;     if(argc < 2)       { -        fprintf(stderr,"input file missing"); +        fprintf(stderr,"input file missing\n");          return EXIT_FAILURE;       } -   FILE *fp; - -   /* printf("%s read %s\n", argv[0], argv[1]); */ +   if (debug) printf("%s read %s\n", argv[0], argv[1]);     fp = fopen(argv[1], "r");     if (fp == NULL)       { @@ -32,151 +145,64 @@ int main(int argc, char** argv, char** env)       }     /* read k and n */ -   if (fscanf(fp, "%d %d\n", &n, &k) != 2) +   if (fscanf(fp, "%d %d\n", &solver.n, &solver.k) != 2)       {          fprintf(stderr, "ERROR: read first line\n");          return EXIT_FAILURE;       } -   /* printf("k:%d n:%d\n", k, n); */ +   if (debug) printf("  K:%d N:%d\n", solver.k, solver.n);     /* allocate */ -   values = calloc(n, sizeof(int)); -   if (!values) +   solver.values = calloc(solver.n, sizeof(int)); +   if (!solver.values)       {          fprintf(stderr, "ERROR: values calloc\n");          return EXIT_FAILURE;       } -   vp = values; +   vp = solver.values; -   weights = calloc(n, sizeof(int)); -   if (!weights) +   solver.weights = calloc(solver.n, sizeof(int)); +   if (!solver.weights)       { -        free(values); +        free(solver.values);          fprintf(stderr, "ERROR: weights calloc\n");          return EXIT_FAILURE;       } -   wp = weights; +   wp = solver.weights; -   solution = calloc(n, sizeof(int)); -   if (!solution) +   solver.nw = (solver.n / WORD_BITS + 1); +   solver.sd_sz = sizeof(SolverData) + (solver.nw * sizeof(word_t)); +   solver.data= calloc((solver.k + 1), solver.sd_sz); +   if (!solver.data)       { -        free(values); -        free(weights); -        fprintf(stderr, "ERROR: solution calloc\n"); +        free(solver.values); +        free(solver.weights); +        fprintf(stderr, "ERROR: solver calloc\n");          return EXIT_FAILURE;       } - -   matrix = calloc(k * n, sizeof(int)); -   if (!matrix) +   for (i = 0; i <= solver.k; i++)       { -        free(values); -        free(weights); -        free(solution); -        fprintf(stderr, "ERROR: matrix calloc\n"); -        return EXIT_FAILURE; +        data = (SolverData *) (((char *)  solver.data) + solver.sd_sz * i); +        data->bits = (word_t *) (((char *) data) + sizeof(SolverData));       }     /* read items */ -   for (i = 0; i < n; i++) +   for (i = 0; i < solver.n; i++)       fscanf(fp, "%d %d\n", vp++, wp++);     fclose(fp); -   /* for (i = 0; i < n; i++) */ -   /*   printf("%d -> v:%d w:%d\n", i, values[i], weights[i]); */ - -   /* SOLVE matrix is [i][j] */ -#define CURRENT   ((i * k) + j) -#define LEFT      (((i - 1) * k) + j) -#define UPLEFT    (((i - 1) * k) + (j - w)) -#define LAST      ((k * n) - 1) -   /* first item */ -   i = 0; -   v = values[i]; -   w = weights[i]; -   for (j = 0; j < k; j++)  // capacities -     { -        if (w <= (j + 1)) -          matrix[CURRENT] = v; -     } -   /* following items */ -   for (i = 1; i < n; i++)       // items -     { -        v = values[i]; -        w = weights[i]; -        for (j = 0; j < k; j++)  // capacities -          { -             // item weight to much -             if (w > j + 1) -               { -                  matrix[CURRENT] = matrix[LEFT]; -               } -             else -               { -                  /* do not go upper than capacity 0 when moving up left */ -                  tmp = v + (((j - w) < 0) ? 0 : matrix[UPLEFT]); -                  if ( tmp > matrix[LEFT]) -                    matrix[CURRENT] = tmp; -                  else -                    matrix[CURRENT] = matrix[LEFT]; -               } -          } -     } - -   /* printf("matrix\n"); */ -   /* for (j = 0; j < k; j++) */ -   /*   { */ -   /*      printf(" %d : ", (j + 1)); */ -   /*      for (i = 0; i < n; i++) */ -   /*        { */ -   /*           printf(" %d", matrix[CURRENT]); */ -   /*        } */ -   /*      printf("\n"); */ -   /*   } */ - -   j = k - 1; -   for (i = n - 1; i >= 0; i--) -     { -        w = weights[i]; -        tmp = matrix[CURRENT]; -        if( (tmp != 0) && ((j - w) >= -1) && ( (i == 0) || (tmp != matrix[LEFT]))) -          { -             solution[i] = 1; -             j -= w; -          } -        else -             solution[i] = 0; -     } - -   v = 0; -   w = 0; -   printf ("%d %d\n", matrix[LAST], 0); -   for (i = 0; i < n; i++) -     { -        printf("%d ", solution[i]); -        if (solution[i]) { -             v += values[i]; -             w += weights[i]; -        } -     } -   printf("\n"); - -   if (v != matrix[LAST]) +   if (debug)       { -        fprintf(stderr, "ERROR: wrong sum of values %d != %d\n", -                v, matrix[LAST]); -        return EXIT_FAILURE; +        printf("     index    value   weight\n"); +        for (i = 0; i < solver.n; i++) +          printf("  % 8d % 8d % 8d\n", i, solver.values[i], solver.weights[i]);       } -   if (w > k) -     { -        fprintf(stderr, "ERROR: wrong sum of weights %d > %d\n", w, k); -        return EXIT_FAILURE; -     } +   solve(&solver); +   print(&solver); -   free(values); -   free(weights); -   free(solution); -   free(matrix); +   free(solver.values); +   free(solver.weights);     return EXIT_SUCCESS;  } | 
